import torch
from torch_geometric.transforms import LinearTransformation
class RandomShear(object):
    r"""Shears node positions by randomly sampled factors :math:`s` within a
    given interval, *e.g.*, resulting in the transformation matrix

    .. math::
        \begin{bmatrix}
            1      & s_{xy} & s_{xz} \\
            s_{yx} & 1      & s_{yz} \\
            s_{zx} & s_{zy} & 1      \\
        \end{bmatrix}

    for three-dimensional positions.

    Args:
        shear (float or int): maximum shearing factor defining the range
            :math:`(-\mathrm{shear}, +\mathrm{shear})` to sample from.
    """

    def __init__(self, shear):
        # Store the magnitude only, so a negative argument still yields a
        # well-ordered sampling interval (-shear, +shear).
        self.shear = abs(shear)

    def __call__(self, data):
        """Sample a random shear matrix and apply it to :obj:`data`.

        Args:
            data: object exposing a ``pos`` tensor whose last dimension is
                used as the size of the (square) transformation matrix.

        Returns:
            Whatever :class:`LinearTransformation` returns for :obj:`data`
            (presumably the transformed data object -- confirm against
            ``LinearTransformation``).
        """
        dim = data.pos.size(-1)

        # ``new_empty`` inherits dtype and device from ``data.pos``; every
        # entry is then drawn uniformly from the shear interval.
        matrix = data.pos.new_empty(dim, dim).uniform_(-self.shear, self.shear)

        # Overwrite the diagonal with ones so that only the off-diagonal
        # entries carry shear factors.
        eye = torch.arange(dim, dtype=torch.long)
        matrix[eye, eye] = 1

        return LinearTransformation(matrix)(data)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.shear)