import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from ..inits import uniform


class GraphConv(MessagePassing):
r"""The graph neural network operator from the `"Weisfeiler and Leman Go
Neural: Higher-order Graph Neural Networks"
<https://arxiv.org/abs/1810.02244>`_ paper
.. math::
\mathbf{x}^{\prime}_i = \mathbf{\Theta}_1 \mathbf{x}_i +
\sum_{j \in \mathcal{N}(i)} \mathbf{\Theta}_2 \mathbf{x}_j.
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
aggr (string, optional): The aggregation scheme to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
(default: :obj:`"add"`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
    def __init__(self, in_channels, out_channels, aggr='add', bias=True,
                 **kwargs):
        super(GraphConv, self).__init__(aggr=aggr, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Theta_2: applied to neighbor features before aggregation.
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        # Theta_1: applied to the root (central) node features.
        self.lin = torch.nn.Linear(in_channels, out_channels, bias=bias)

        self.reset_parameters()
    def reset_parameters(self):
        # Uniform init in [-1/sqrt(in_channels), 1/sqrt(in_channels)],
        # matching the fan-in of the neighbor weight.
        uniform(self.in_channels, self.weight)
        self.lin.reset_parameters()
    def forward(self, x, edge_index, edge_weight=None, size=None):
        """Runs one round of message passing (see the class docstring)."""
        # Pre-transform all node features with Theta_2; the raw features
        # ``x`` are forwarded as well so that ``update`` can apply Theta_1.
        h = torch.matmul(x, self.weight)
        return self.propagate(edge_index, size=size, x=x, h=h,
                              edge_weight=edge_weight)
    def message(self, h_j, edge_weight):
        # Scale each neighbor's transformed features by its edge weight,
        # if edge weights are given.
        return h_j if edge_weight is None else edge_weight.view(-1, 1) * h_j
    def update(self, aggr_out, x):
        # Add the Theta_1-transformed root features to the aggregated
        # neighbor messages.
        return aggr_out + self.lin(x)
    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
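

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of driving the layer on a toy graph, assuming 3 nodes,
# 4 directed edges, and 16-dimensional inputs; every shape and value below
# is an arbitrary choice for demonstration. It also cross-checks the output
# against a dense form of the docstring equation, with the optional edge
# weights folded in:
#     x'_i = Theta_1 x_i + sum_{j in N(i)} e_{j,i} * Theta_2 x_j
if __name__ == '__main__':
    conv = GraphConv(in_channels=16, out_channels=32, aggr='add')

    x = torch.randn(3, 16)                     # [num_nodes, in_channels]
    edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes j
                               [1, 0, 2, 1]])  # target nodes i
    edge_weight = torch.tensor([1.0, 1.0, 0.5, 0.5])

    out = conv(x, edge_index, edge_weight)
    print(out.shape)  # torch.Size([3, 32])

    # Dense cross-check: adj[i, j] holds the weight of edge j -> i, so
    # adj @ (x @ Theta_2) reproduces the weighted neighbor sum.
    adj = torch.zeros(3, 3)
    adj[edge_index[1], edge_index[0]] = edge_weight
    dense = conv.lin(x) + adj @ (x @ conv.weight)
    assert torch.allclose(out, dense, atol=1e-6)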