import torch
from torch.nn import Parameter
from ..inits import uniform
class DenseGraphConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GraphConv`.

    Dense counterpart of the graph convolution: neighbor features are
    transformed by :obj:`weight` and aggregated according to the dense
    adjacency tensor, then a linear transformation of the node's own features
    (:obj:`lin`) is added.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        aggr (str, optional): The neighborhood aggregation scheme
            (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
            (default: :obj:`"add"`)
        bias (bool, optional): If set to :obj:`False`, the root linear
            transformation has no additive bias. (default: :obj:`True`)
    """

    def __init__(self, in_channels, out_channels, aggr='add', bias=True):
        assert aggr in ['add', 'mean', 'max']
        super(DenseGraphConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggr = aggr

        # Transformation applied to the (aggregated) neighbor features.
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        # Transformation applied to each node's own features.
        self.lin = torch.nn.Linear(in_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes :obj:`weight` and the root linear layer."""
        uniform(self.in_channels, self.weight)
        self.lin.reset_parameters()

    def forward(self, x, adj, mask=None):
        r"""
        Args:
            x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B
                \times N \times F}`, with batch-size :math:`B`, (maximum)
                number of nodes :math:`N` for each graph, and feature
                dimension :math:`F`. A two-dimensional tensor is treated as
                a single graph.
            adj (Tensor): Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B
                \times N \times N}`. The adjacency tensor is broadcastable in
                the batch dimension, resulting in a shared adjacency matrix for
                the complete batch.
            mask (BoolTensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)

        Returns:
            Tensor: Output node features of shape
            :math:`B \times N \times F_{out}`. Rows of masked-out nodes are
            zero.
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        N = adj.size(-1)

        if self.aggr == 'max':
            out = self._max_aggregate(x, adj)
        else:
            out = torch.matmul(adj, x)
            out = torch.matmul(out, self.weight)
            if self.aggr == 'mean':
                # Clamp so that nodes without neighbors do not divide by zero.
                out = out / adj.sum(dim=-1, keepdim=True).clamp(min=1)

        out = out + self.lin(x)

        if mask is not None:
            # `-1` (instead of the batch size of `adj`) keeps this valid when
            # `adj` is shared across the batch but `mask` is given per graph.
            out = out * mask.view(-1, N, 1).to(x.dtype)

        return out

    def _max_aggregate(self, x, adj):
        r"""Element-wise maximum over each node's weighted neighbor features.

        Builds a :math:`B \times N \times N \times F_{out}` message tensor, so
        memory grows quadratically with the number of nodes. Nodes without any
        neighbor receive an all-zero aggregate.
        """
        h = torch.matmul(x, self.weight)
        no_edge = adj == 0

        # messages[b, i, j] = adj[b, i, j] * h[b, j]
        messages = adj.unsqueeze(-1) * h.unsqueeze(-3)
        # Non-edges must never win the maximum.
        messages = messages.masked_fill(no_edge.unsqueeze(-1), float('-inf'))
        out = messages.max(dim=-2)[0]

        # Rows that consist solely of non-edges would otherwise be `-inf`.
        isolated = no_edge.all(dim=-1, keepdim=True)
        return out.masked_fill(isolated, 0.)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)