import re
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_mean
from torch_geometric.nn import voxel_grid
from torch_geometric.nn.pool.consecutive import consecutive_cluster
[docs]class GridSampling(object):
r"""Clusters points into voxels with size :attr:`size`.
Args:
size (float or [float] or Tensor): Size of a voxel (in each dimension).
start (float or [float] or Tensor, optional): Start coordinates of the
grid (in each dimension). If set to :obj:`None`, will be set to the
minimum coordinates found in :obj:`data.pos`.
(default: :obj:`None`)
end (float or [float] or Tensor, optional): End coordinates of the grid
(in each dimension). If set to :obj:`None`, will be set to the
maximum coordinates found in :obj:`data.pos`.
(default: :obj:`None`)
"""
def __init__(self, size, start=None, end=None):
self.size = size
self.start = start
self.end = end
def __call__(self, data):
num_nodes = data.num_nodes
if 'batch' not in data:
batch = data.pos.new_zeros(num_nodes, dtype=torch.long)
else:
batch = data.batch
cluster = voxel_grid(data.pos, batch, self.size, self.start, self.end)
cluster, perm = consecutive_cluster(cluster)
for key, item in data:
if bool(re.search('edge', key)):
raise ValueError(
'GridSampling does not support coarsening of edges')
if torch.is_tensor(item) and item.size(0) == num_nodes:
if key == 'y':
item = F.one_hot(item)
item = scatter_add(item, cluster, dim=0)
data[key] = item.argmax(dim=-1)
elif key == 'batch':
data[key] = item[perm]
else:
data[key] = scatter_mean(item, cluster, dim=0)
return data
def __repr__(self):
return '{}(size={})'.format(self.__class__.__name__, self.size)