from torch.nn import BatchNorm1d


class BatchNorm(BatchNorm1d):
r"""Applies batch normalization over a batch of node features as described
in the `"Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariate Shift" <https://arxiv.org/abs/1502.03167>`_
paper
.. math::
\mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
\textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
\odot \gamma + \beta
    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        momentum (float, optional): The value used for the running mean and
            running variance computation. (default: :obj:`0.1`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        track_running_stats (bool, optional): If set to :obj:`True`, this
            module tracks the running mean and variance, and when set to
            :obj:`False`, this module does not track such statistics and
            always uses batch statistics in both training and eval modes.
            (default: :obj:`True`)
    """
    def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm, self).__init__(in_channels, eps, momentum, affine,
                                        track_running_stats)

    def forward(self, x):
        """Applies batch normalization to the node feature matrix
        :obj:`x`."""
        return super(BatchNorm, self).forward(x)

    def __repr__(self):
        return ('{}({}, eps={}, momentum={}, affine={}, '
                'track_running_stats={})').format(self.__class__.__name__,
                                                  self.num_features, self.eps,
                                                  self.momentum, self.affine,
                                                  self.track_running_stats)
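

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the module above). It assumes a
# node feature matrix of shape [num_nodes, in_channels], the layout used
# throughout PyTorch Geometric; the tensor sizes below are made up for the
# example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    x = torch.randn(100, 16)          # 100 nodes with 16-channel features
    norm = BatchNorm(in_channels=16)

    norm.train()
    out = norm(x)                     # normalized with batch statistics
    assert out.shape == (100, 16)

    # Re-derive the docstring formula by hand, per channel; gamma and beta
    # are the module's learnable affine parameters (weight and bias).
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    expected = (x - mean) / (var + norm.eps).sqrt() * norm.weight + norm.bias
    assert torch.allclose(out, expected, atol=1e-5)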