from torch_geometric.nn import radius_graph
[docs]class RadiusGraph(object):
r"""Creates edges based on node positions :obj:`pos` to all points within a
given distance.
Args:
r (float): The distance.
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`.
This flag is only needed for CUDA tensors. (default: :obj:`32`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
"""
def __init__(self,
r,
loop=False,
max_num_neighbors=32,
flow='source_to_target'):
self.r = r
self.loop = loop
self.max_num_neighbors = max_num_neighbors
self.flow = flow
def __call__(self, data):
data.edge_attr = None
batch = data.batch if 'batch' in data else None
data.edge_index = radius_graph(data.pos, self.r, batch, self.loop,
self.max_num_neighbors, self.flow)
return data
def __repr__(self):
return '{}(r={})'.format(self.__class__.__name__, self.r)