import torch
import torch.nn as nn
from torch_geometric.utils import degree
class GraphSizeNorm(nn.Module):
    r"""Applies Graph Size Normalization over each individual graph in a batch
    of node features as described in the
    `"Benchmarking Graph Neural Networks" <https://arxiv.org/abs/2003.00982>`_
    paper.

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x}_i}{\sqrt{|\mathcal{V}|}}

    where :math:`|\mathcal{V}|` is the number of nodes of the graph that node
    :math:`i` is assigned to.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, batch=None):
        r"""Scales each node's features by the inverse square root of the
        number of nodes in its graph.

        Args:
            x (Tensor): Node features. Dimension 0 indexes the nodes; any
                trailing dimensions are treated as feature dimensions.
            batch (LongTensor, optional): For every node, the index of the
                graph it is assigned to. If :obj:`None`, all nodes are
                treated as belonging to one single graph.
                (default: :obj:`None`)

        Returns:
            Tensor: The normalized node features, with the same shape as
            :obj:`x`.
        """
        if batch is None:
            # No assignment given: put every node into graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Per-graph node count |V| (how often each graph index occurs in
        # `batch`), turned into the factor 1 / sqrt(|V|).
        inv_sqrt_deg = degree(batch, dtype=x.dtype).pow(-0.5)

        # Gather one factor per node and reshape it to [N, 1, ..., 1] so that
        # it broadcasts over all trailing feature dimensions of `x`. For the
        # usual 2-D input this is exactly `.view(-1, 1)`.
        node_scale = inv_sqrt_deg[batch]
        return x * node_scale.view([-1] + [1] * (x.dim() - 1))