from __future__ import division
import warnings
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import degree, segregate_self_loops
from torch_geometric.utils.repeat import repeat
from .data import size_repr
try:
from torch_cluster import neighbor_sampler
except ImportError:
neighbor_sampler = None
class Block(object):
    """Container for the data of a single sampled hop.

    The constructor stores its five arguments unchanged as the attributes
    ``n_id``, ``res_n_id``, ``e_id``, ``edge_index`` and ``size``.
    """

    def __init__(self, n_id, res_n_id, e_id, edge_index, size):
        # Assignment order matters: ``__repr__`` lists the attributes in the
        # order in which they were set.
        self.n_id, self.res_n_id = n_id, res_n_id
        self.e_id, self.edge_index = e_id, edge_index
        self.size = size

    def __repr__(self):
        # Every instance attribute is rendered as ``name=<size_repr(value)>``.
        fields = ', '.join(
            '{}={}'.format(name, size_repr(getattr(self, name)))
            for name in self.__dict__)
        return '{}({})'.format(self.__class__.__name__, fields)
class DataFlow(object):
    """Ordered collection of :obj:`Block` objects built for one mini-batch.

    Blocks are appended hop by hop, starting from the root nodes ``n_id``.
    Indexing and iteration walk the blocks in reverse order of insertion,
    i.e. the block appended last comes first.

    Args:
        n_id: The root nodes of the mini-batch; must provide ``size(0)``.
        flow (string, optional): ``"source_to_target"`` or
            ``"target_to_source"``. (default: ``"source_to_target"``)
    """

    def __init__(self, n_id, flow='source_to_target'):
        self.n_id = n_id
        self.flow = flow
        # Node ids of the most recently appended hop (initially the roots).
        self.__last_n_id__ = n_id
        self.blocks = []

    @property
    def batch_size(self):
        """Number of root nodes, i.e. ``n_id.size(0)``."""
        return self.n_id.size(0)

    def append(self, n_id, res_n_id, e_id, edge_index):
        """Wrap the arguments in a :obj:`Block` and store it as the next hop.

        The block's ``size`` pairs the node count of the previous hop with
        the node count of the new hop; which of the two comes first depends
        on ``flow``.
        """
        prev_count = self.__last_n_id__.size(0)
        new_count = n_id.size(0)
        if self.flow == 'target_to_source':
            size = (prev_count, new_count)
        else:
            size = (new_count, prev_count)
        self.blocks.append(Block(n_id, res_n_id, e_id, edge_index, size))
        self.__last_n_id__ = n_id

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        # Index (or slice) into the blocks as seen in reverse insertion order.
        return list(reversed(self.blocks))[idx]

    def __iter__(self):
        # Iterate over a snapshot so that later changes to ``self.blocks`` do
        # not affect a running iteration; popping yields last-appended first.
        snapshot = list(self.blocks)
        while snapshot:
            yield snapshot.pop()

    def to(self, device):
        """Move every block's ``edge_index`` to ``device``; returns ``self``.
        """
        for hop in self.blocks:
            hop.edge_index = hop.edge_index.to(device)
        return self

    def __repr__(self):
        # Node counts of the roots followed by each hop, joined by an arrow
        # whose direction reflects ``flow``.
        counts = [str(self.n_id.size(0))]
        counts.extend(str(block.n_id.size(0)) for block in self.blocks)
        arrow = '->' if self.flow != 'source_to_target' else '<-'
        return '{}({})'.format(self.__class__.__name__, arrow.join(counts))
class NeighborSampler(object):
    r"""The neighbor sampler from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper which iterates
    over graph nodes in a mini-batch fashion and constructs sampled subgraphs
    of size :obj:`num_hops`.

    It returns a generator of :obj:`DataFlow` that defines the message
    passing flow to the root nodes via a list of :obj:`num_hops` bipartite
    graph objects :obj:`edge_index` and the initial start nodes :obj:`n_id`.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        size (int or float or [int] or [float]): The number of neighbors to
            sample (for each layer). The value of this parameter can be either
            set to be the same for each neighborhood or percentage-based.
        num_hops (int): The number of layers to sample.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        drop_last (bool, optional): If set to :obj:`True`, will drop the last
            incomplete batch if the number of nodes is not divisible by the
            batch size. If set to :obj:`False` and the size of graph is not
            divisible by the batch size, the last batch will be smaller.
            (default: :obj:`False`)
        bipartite (bool, optional): If set to :obj:`False`, will not return a
            generator of :obj:`DataFlow` to mark the computation flow, but
            instead will return a :obj:`torch_geometric.data.Data` object
            holding the subgraph information around each mini-batch.
            If set to :obj:`False`, the :obj:`add_self_loops` option is
            ignored. (default: :obj:`True`)
        add_self_loops (bool, optional): If set to :obj:`True`, will add
            self-loops to each sampled neighborhood. (default: :obj:`False`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)

    Raises:
        ImportError: If :obj:`torch-cluster` could not be imported.
    """

    def __init__(self, data, size, num_hops, batch_size=1, shuffle=False,
                 drop_last=False, bipartite=True, add_self_loops=False,
                 flow='source_to_target'):
        if neighbor_sampler is None:
            raise ImportError('`NeighborSampler` requires `torch-cluster`.')

        self.data = data
        self.size = repeat(size, num_hops)
        self.num_hops = num_hops
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.bipartite = bipartite
        self.add_self_loops = add_self_loops
        self.flow = flow

        self.edge_index = data.edge_index
        # One identifier per edge of the input graph, in input order.
        self.e_id = torch.arange(self.edge_index.size(1))

        if bipartite and add_self_loops:
            # Split off existing self-loops so they can be re-attached
            # explicitly for every sampled node later on.
            tmp = segregate_self_loops(self.edge_index, self.e_id)
            self.edge_index, self.e_id, self.edge_index_loop = tmp[:3]
            # Per-node id of its self-loop edge; -1 marks "no self-loop".
            self.e_id_loop = self.e_id.new_full((data.num_nodes, ), -1)
            self.e_id_loop[tmp[2][0]] = tmp[3]

        assert flow in ['source_to_target', 'target_to_source']
        self.i, self.j = (0, 1) if flow == 'target_to_source' else (1, 0)

        # Sort the edges by their `i` endpoint; `e_assoc` maps a position in
        # the sorted order back to a position in `self.edge_index`.
        edge_index_i, self.e_assoc = self.edge_index[self.i].sort()
        self.edge_index_j = self.edge_index[self.j, self.e_assoc]
        deg = degree(edge_index_i, data.num_nodes, dtype=torch.long)
        # Cumulative degree with a leading zero, handed to `neighbor_sampler`.
        self.cumdeg = torch.cat([deg.new_zeros(1), deg.cumsum(0)])

        # Scratch buffer used to relabel global node ids to local indices.
        self.tmp = torch.empty(data.num_nodes, dtype=torch.long)

    def __get_batches__(self, subset=None):
        r"""Returns a sequence of mini-batches from the initial nodes in
        :obj:`subset`.

        Args:
            subset (LongTensor or BoolTensor, optional): Node indices or a
                node mask. If :obj:`None`, all nodes of the graph are used.

        Returns:
            The chunks produced by :meth:`torch.split` with
            :obj:`self.batch_size`, without the last one if it is incomplete
            and :obj:`self.drop_last` is set.
        """
        if subset is None and not self.shuffle:
            subset = torch.arange(self.data.num_nodes, dtype=torch.long)
        elif subset is None and self.shuffle:
            subset = torch.randperm(self.data.num_nodes)
        else:
            # A mask is converted into the indices of its non-zero entries.
            if subset.dtype == torch.bool or subset.dtype == torch.uint8:
                subset = subset.nonzero().view(-1)
            if self.shuffle:
                subset = subset[torch.randperm(subset.size(0))]

        subsets = torch.split(subset, self.batch_size)
        if self.drop_last and subsets[-1].size(0) < self.batch_size:
            subsets = subsets[:-1]
        assert len(subsets) > 0
        return subsets

    def __produce_bipartite_data_flow__(self, n_id):
        r"""Produces a :obj:`DataFlow` object with a bipartite assignment
        matrix for a given mini-batch :obj:`n_id`."""
        data_flow = DataFlow(n_id, self.flow)

        for hop in range(self.num_hops):
            # The sampler returns positions in the edge list sorted by `i`.
            e_id = neighbor_sampler(n_id, self.cumdeg, self.size[hop])

            new_n_id = self.edge_index_j.index_select(0, e_id)
            # Translate sorted positions into positions in `self.edge_index`.
            e_id = self.e_assoc[e_id]

            if self.add_self_loops:
                new_n_id = torch.cat([new_n_id, n_id], dim=0)
                new_n_id, inv = new_n_id.unique(sorted=False,
                                                return_inverse=True)
                # Positions of the current nodes inside `new_n_id`. An
                # explicit start offset is used because `inv[-0:]` would
                # select everything for an empty `n_id`.
                res_n_id = inv[inv.size(0) - n_id.size(0):]
            else:
                new_n_id = new_n_id.unique(sorted=False)
                res_n_id = None

            edges = [None, None]

            # Relabel the `i` endpoints to indices into `n_id`.
            edge_index_i = self.edge_index[self.i, e_id]
            if self.add_self_loops:
                edge_index_i = torch.cat([edge_index_i, n_id], dim=0)
            self.tmp[n_id] = torch.arange(n_id.size(0))
            edges[self.i] = self.tmp[edge_index_i]

            # Relabel the `j` endpoints to indices into `new_n_id`. This
            # reuses the scratch buffer, so it must come after the lookup
            # above.
            edge_index_j = self.edge_index[self.j, e_id]
            if self.add_self_loops:
                edge_index_j = torch.cat([edge_index_j, n_id], dim=0)
            self.tmp[new_n_id] = torch.arange(new_n_id.size(0))
            edges[self.j] = self.tmp[edge_index_j]

            edge_index = torch.stack(edges, dim=0)

            # Map positions in `self.edge_index` to original edge ids.
            e_id = self.e_id[e_id]
            if self.add_self_loops:
                if self.edge_index_loop.size(1) == self.data.num_nodes:
                    # Only set `e_id` if all self-loops were initially passed
                    # to the graph.
                    e_id = torch.cat([e_id, self.e_id_loop[n_id]])
                else:
                    e_id = None
                    if torch_geometric.is_debug_enabled():
                        warnings.warn(
                            ('Could not add edge identifiers to the DataFlow '
                             'object due to missing initial self-loops. '
                             'Please make sure that your graph already '
                             'contains self-loops in case you want to use '
                             'edge-conditioned operators.'))

            n_id = new_n_id
            data_flow.append(n_id, res_n_id, e_id, edge_index)

        return data_flow

    @staticmethod
    def _unique_edges(edge_index, e_id, num_nodes):
        # Collapses duplicate (row, col) pairs of `edge_index` and keeps one
        # entry of `e_id` per remaining edge. Floor division keeps the
        # decoded rows integral; true division would yield floats.
        idx = edge_index[0] * num_nodes + edge_index[1]
        idx, inv = idx.unique(sorted=False, return_inverse=True)
        edge_index = torch.stack([idx // num_nodes, idx % num_nodes], dim=0)
        e_id = e_id.new_zeros(edge_index.size(1)).scatter_(0, inv, e_id)
        return edge_index, e_id

    def __produce_subgraph__(self, b_id):
        r"""Produces a :obj:`Data` object holding the subgraph data for a given
        mini-batch :obj:`b_id`.

        The returned object carries the relabeled :obj:`edge_index`, the edge
        ids :obj:`e_id`, the node ids :obj:`n_id`, the batch nodes
        :obj:`b_id`, their local indices :obj:`sub_b_id` and
        :obj:`num_nodes`.
        """
        n_ids = [b_id]
        e_ids = []
        edge_indices = []

        for hop in range(self.num_hops):
            # Expand from the nodes discovered in the previous hop.
            e_id = neighbor_sampler(n_ids[-1], self.cumdeg, self.size[hop])
            n_id = self.edge_index_j.index_select(0, e_id)
            n_id = n_id.unique(sorted=False)
            n_ids.append(n_id)
            e_ids.append(self.e_assoc.index_select(0, e_id))
            edge_index = self.data.edge_index.index_select(1, e_ids[-1])
            edge_indices.append(edge_index)

        # Relabel all visited nodes to consecutive local indices.
        n_id = torch.unique(torch.cat(n_ids, dim=0), sorted=False)
        self.tmp[n_id] = torch.arange(n_id.size(0))
        e_id = torch.cat(e_ids, dim=0)
        edge_index = self.tmp[torch.cat(edge_indices, dim=1)]

        num_nodes = n_id.size(0)
        # The same edge may have been sampled in several hops.
        edge_index, e_id = self._unique_edges(edge_index, e_id, num_nodes)

        return Data(edge_index=edge_index, e_id=e_id, n_id=n_id, b_id=b_id,
                    sub_b_id=self.tmp[b_id], num_nodes=num_nodes)

    def __call__(self, subset=None):
        r"""Returns a generator of :obj:`DataFlow` that iterates over the nodes
        in :obj:`subset` in a mini-batch fashion.

        If :obj:`bipartite` was set to :obj:`False`, the generator yields the
        result of :meth:`__produce_subgraph__` for every mini-batch instead.

        Args:
            subset (LongTensor or BoolTensor, optional): The initial nodes to
                propagate messages to. If set to :obj:`None`, will iterate over
                all nodes in the graph. (default: :obj:`None`)
        """
        if self.bipartite:
            produce = self.__produce_bipartite_data_flow__
        else:
            produce = self.__produce_subgraph__

        for n_id in self.__get_batches__(subset):
            yield produce(n_id)