from torch_geometric.data import Batch
from torch_geometric.utils import scatter_
from .consecutive import consecutive_cluster
from .pool import pool_edge, pool_batch, pool_pos
def _max_pool_x(cluster, x, size=None):
    """Scatter :attr:`x` along dimension 0 using the ``'max'`` reduction.

    Rows of :attr:`x` are grouped by the index vector :attr:`cluster`;
    :attr:`size` is forwarded unchanged to :func:`scatter_` as ``dim_size``
    (:obj:`None` leaves the choice of output size to :func:`scatter_`).
    """
    pooled = scatter_('max', x, cluster, dim=0, dim_size=size)
    return pooled
def max_pool_x(cluster, x, batch, size=None):
    r"""Max-Pools node features according to the clustering defined in
    :attr:`cluster`.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific
            cluster.
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
            B-1\}}^N`, which assigns each node to a specific example.
        size (int, optional): The maximum number of clusters in a single
            example. This property is useful to obtain a batch-wise dense
            representation, *e.g.* for applying FC layers, but should only be
            used if the size of the maximum number of clusters per example is
            known in advance. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`LongTensor`) if :attr:`size` is
        :obj:`None`, else :class:`Tensor`
    """
    if size is not None:
        # Dense mode: the output gets ``size`` slots for each of the
        # ``batch.max() + 1`` examples. ``cluster`` is used as given (it is
        # not passed through ``consecutive_cluster``), and only the pooled
        # features are returned -- no batch vector.
        num_examples = batch.max().item() + 1
        return _max_pool_x(cluster, x, num_examples * size)

    # Sparse mode: remap the cluster vector first, then pool the features and
    # derive the batch vector of the pooled nodes from ``perm``.
    # NOTE(review): presumably ``consecutive_cluster`` relabels cluster ids to
    # a gap-free range and ``perm`` selects one node per cluster -- confirm
    # against ``.consecutive``.
    cluster, perm = consecutive_cluster(cluster)
    x = _max_pool_x(cluster, x)
    batch = pool_batch(perm, batch)
    return x, batch
def max_pool(cluster, data, transform=None):
    r"""Pools and coarsens a graph given by the
    :class:`torch_geometric.data.Data` object according to the clustering
    defined in :attr:`cluster`.
    All nodes within the same cluster will be represented as one node.
    Final node features are defined by the *maximum* features of all nodes
    within the same cluster, node positions are averaged and edge indices are
    defined to be the union of the edge indices of all nodes within the same
    cluster.

    Attributes of :attr:`data` that are :obj:`None` (:obj:`x`, :obj:`batch`,
    :obj:`pos`) are not pooled and stay :obj:`None` in the result.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific
            cluster.
        data (Data): Graph data object.
        transform (callable, optional): A function/transform that takes in the
            coarsened and pooled :obj:`torch_geometric.data.Data` object and
            returns a transformed version. (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Data`
    """
    cluster, perm = consecutive_cluster(cluster)

    # Graphs without node features are valid input: guard ``x`` the same way
    # ``batch`` and ``pos`` are guarded instead of scattering over ``None``.
    x = None if data.x is None else _max_pool_x(cluster, data.x)
    index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)
    batch = None if data.batch is None else pool_batch(perm, data.batch)
    pos = None if data.pos is None else pool_pos(cluster, data.pos)

    data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)

    if transform is not None:
        data = transform(data)

    return data