import torch
[docs]class Cartesian(object):
r"""Saves the relative Cartesian coordinates of linked nodes in its edge
attributes.
Args:
norm (bool, optional): If set to :obj:`False`, the output will not be
normalized to the interval :math:`{[0, 1]}^D`.
(default: :obj:`True`)
max_value (float, optional): If set and :obj:`norm=True`, normalization
will be performed based on this value instead of the maximum value
found in the data. (default: :obj:`None`)
cat (bool, optional): If set to :obj:`False`, all existing edge
attributes will be replaced. (default: :obj:`True`)
"""
def __init__(self, norm=True, max_value=None, cat=True):
self.norm = norm
self.max = max_value
self.cat = cat
def __call__(self, data):
(row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
cart = pos[col] - pos[row]
cart = cart.view(-1, 1) if cart.dim() == 1 else cart
if self.norm and cart.numel() > 0:
max_value = cart.abs().max() if self.max is None else self.max
cart = cart / (2 * max_value) + 0.5
if pseudo is not None and self.cat:
pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)
else:
data.edge_attr = cart
return data
def __repr__(self):
return '{}(norm={}, max_value={})'.format(self.__class__.__name__,
self.norm, self.max)