import torch_scatter
def scatter_(name, src, index, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along dimension :attr:`dim`.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).

    For :obj:`"min"` and :obj:`"max"`, output positions that receive no
    contribution from :attr:`src` are set to :obj:`0`.

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in dimension :attr:`dim`. If set to
            :obj:`None`, a minimal sized output tensor is returned.
            (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    assert name in ['add', 'mean', 'min', 'max']

    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    out = op(src, index, dim, None, dim_size)
    # Some reductions (min/max) return a tuple; the aggregated tensor is
    # its first element.
    out = out[0] if isinstance(out, tuple) else out

    if name in ('min', 'max'):
        # Positions that receive no element may hold a fill/sentinel value
        # rather than 0 (which is why the previous implementation zeroed
        # values beyond +/-10000). Thresholding on `out` also wipes out
        # legitimate results of large magnitude, so instead count the
        # contributions per output position and zero only the empty ones.
        count = torch_scatter.scatter_add(src.new_ones(src.size()), index,
                                          dim, None, dim_size)
        out = out.masked_fill(count == 0, 0)

    return out