import torch
from torch.nn.modules.instancenorm import _InstanceNorm
from torch_scatter import scatter_add
from torch_geometric.utils import degree
class InstanceNorm(_InstanceNorm):
    r"""Applies instance normalization over each individual example in a batch
    of node features as described in the `"Instance Normalization: The Missing
    Ingredient for Fast Stylization" <https://arxiv.org/abs/1607.08022>`_
    paper.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    Mean and variance are computed separately for every example in the batch,
    where examples are identified by the :obj:`batch` vector passed to
    :meth:`forward`.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        momentum (float, optional): The value used for the running mean and
            running variance computation. If set to :obj:`None`, a cumulative
            moving average (plain average over all training batches seen so
            far) is used instead. (default: :obj:`0.1`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`False`)
        track_running_stats (bool, optional): If set to :obj:`True`, this
            module tracks the running mean and variance, and when set to
            :obj:`False`, this module does not track such statistics and always
            uses instance statistics in both training and eval modes.
            (default: :obj:`False`)
    """

    def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False):
        super(InstanceNorm, self).__init__(in_channels, eps, momentum, affine,
                                           track_running_stats)

    def forward(self, x, batch=None):
        r"""Normalizes node features example by example.

        Args:
            x (Tensor): Node feature matrix with one row per node and
                :obj:`in_channels` columns.
            batch (LongTensor, optional): For every row of :obj:`x`, the index
                of the example it belongs to. If :obj:`None`, all rows are
                treated as a single example. (default: :obj:`None`)

        Returns:
            Tensor: The normalized features, with the same shape as :obj:`x`.
            In training mode, or when running statistics are not tracked,
            per-example statistics are used; otherwise the running buffers
            are used.
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # ``batch.max()`` raises on an empty tensor, so guard empty inputs.
        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 0

        if self.training or not self.track_running_stats:
            mean, var, unbiased_var = self._example_stats(x, batch, batch_size)
            # Averaging over zero examples would write NaN into the running
            # buffers, so empty inputs skip the update.
            if self.training and self.track_running_stats and batch_size > 0:
                self._update_running_stats(mean, unbiased_var)
            # Broadcast the per-example statistics back to every node.
            mean, var = mean[batch], var[batch]
        else:
            # Running buffers hold one value per channel; broadcast over rows.
            mean = self.running_mean.view(1, -1)
            var = self.running_var.view(1, -1)

        out = (x - mean) / torch.sqrt(var + self.eps)
        if self.affine:
            out = out * self.weight.view(1, -1) + self.bias.view(1, -1)
        return out

    @staticmethod
    def _example_stats(x, batch, batch_size):
        """Returns per-example mean, biased variance and unbiased variance."""
        count = torch.bincount(batch, minlength=batch_size)
        count = count.to(x.dtype).view(-1, 1)
        # Clamp so examples without any nodes divide by 1 instead of 0.
        norm = count.clamp(min=1)
        shape = (batch_size,) + tuple(x.shape[1:])
        mean = x.new_zeros(shape).index_add(0, batch, x) / norm
        centered = x - mean[batch]
        sq_sum = x.new_zeros(shape).index_add(0, batch, centered * centered)
        var = sq_sum / norm
        # Bessel's correction; single-node examples fall back to dividing by 1.
        unbiased_var = sq_sum / (count - 1).clamp(min=1)
        return mean, var, unbiased_var

    def _update_running_stats(self, mean, unbiased_var):
        """Blends the batch statistics into the running buffers in place.

        Runs under ``torch.no_grad()`` so the buffers never hold a reference
        to the current batch's autograd graph. The update is in place, so the
        buffer tensors keep their identity.
        """
        with torch.no_grad():
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
            if self.momentum is None:
                # Cumulative moving average over all batches seen so far.
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum
            self.running_mean.mul_(1 - factor).add_(mean.mean(dim=0),
                                                    alpha=factor)
            self.running_var.mul_(1 - factor).add_(unbiased_var.mean(dim=0),
                                                   alpha=factor)