from typing import List
import copy
import os.path as osp
import torch
import torch.utils.data
from torch_sparse import SparseTensor, cat
class ClusterData(torch.utils.data.Dataset):
r"""Clusters/partitions a graph data object into multiple subgraphs, as
motivated by the `"Cluster-GCN: An Efficient Algorithm for Training Deep
and Large Graph Convolutional Networks"
<https://arxiv.org/abs/1905.07953>`_ paper.
Args:
data (torch_geometric.data.Data): The graph data object.
num_parts (int): The number of partitions.
recursive (bool, optional): If set to :obj:`True`, will use multilevel
recursive bisection instead of multilevel k-way partitioning.
(default: :obj:`False`)
save_dir (string, optional): If set, will save the partitioned data to
the :obj:`save_dir` directory for faster re-use.
"""
    def __init__(self, data, num_parts, recursive=False, save_dir=None):
        assert data.edge_index is not None

        self.num_parts = num_parts

        recursive_str = '_recursive' if recursive else ''
        filename = f'partition_{num_parts}{recursive_str}.pt'
        path = osp.join(save_dir or '', filename)

        if save_dir is not None and osp.exists(path):
            adj, partptr, perm = torch.load(path)
        else:
            (row, col), edge_attr = data.edge_index, data.edge_attr
            adj = SparseTensor(row=row, col=col, value=edge_attr)
            adj, partptr, perm = adj.partition(num_parts, recursive)
            if save_dir is not None:
                torch.save((adj, partptr, perm), path)

        self.data = self.permute_data(data, perm, adj)
        self.partptr = partptr
        self.perm = perm

    def permute_data(self, data, perm, adj):
        # Reorder all node-level tensors according to the partition
        # permutation and replace `edge_index`/`edge_attr` by the partitioned
        # sparse adjacency matrix.
        data = copy.copy(data)
        num_nodes = data.num_nodes

        for key, item in data:
            if item.size(0) == num_nodes:
                data[key] = item[perm]

        data.edge_index = None
        data.edge_attr = None
        data.adj = adj

        return data

    def __len__(self):
        return self.partptr.numel() - 1

    def __getitem__(self, idx):
        start = int(self.partptr[idx])
        length = int(self.partptr[idx + 1]) - start

        data = copy.copy(self.data)
        num_nodes = data.num_nodes
        # Slice all node-level tensors (including `adj`, row-wise) down to the
        # nodes of partition `idx`, ...
        for key, item in data:
            if item.size(0) == num_nodes:
                data[key] = item.narrow(0, start, length)

        # ... then restrict the adjacency matrix column-wise as well, so that
        # only the intra-cluster edges remain.
        data.adj = data.adj.narrow(1, start, length)

        row, col, value = data.adj.coo()
        data.adj = None
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        return data

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.data}, '
                f'num_parts={self.num_parts})')
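

# A minimal usage sketch for `ClusterData` (illustrative, not part of the
# original module): it partitions a small synthetic `Data` object. The graph,
# the feature size and `num_parts=2` are arbitrary assumptions, and the
# partitioning step requires a METIS-enabled `torch_sparse` installation.
if __name__ == '__main__':
    from torch_geometric.data import Data

    # Undirected 6-node cycle, stored as directed edges in both directions.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 0],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 0, 5]])
    data = Data(x=torch.randn(6, 16), edge_index=edge_index)

    cluster_data = ClusterData(data, num_parts=2, recursive=False)
    print(len(cluster_data))  # Number of partitions (2).
    print(cluster_data[0])    # First partition as a `Data` object.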

class ClusterLoader(torch.utils.data.DataLoader):
r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
for Training Deep and Large Graph Convolutional Networks"
<https://arxiv.org/abs/1905.07953>`_ paper which merges partioned subgraphs
and their between-cluster links from a large-scale graph data object to
form a mini-batch.
Args:
cluster_data (torch_geometric.data.ClusterData): The already
partioned data object.
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
shuffle (bool, optional): If set to :obj:`True`, the data will be
reshuffled at every epoch. (default: :obj:`False`)
"""
    def __init__(self, cluster_data, batch_size=1, shuffle=False, **kwargs):
        class HelperDataset(torch.utils.data.Dataset):
            def __len__(self):
                return len(cluster_data)

            def __getitem__(self, idx):
                start = int(cluster_data.partptr[idx])
                length = int(cluster_data.partptr[idx + 1]) - start

                data = copy.copy(cluster_data.data)
                num_nodes = data.num_nodes
                for key, item in data:
                    if item.size(0) == num_nodes:
                        data[key] = item.narrow(0, start, length)

                return data, idx

        def collate(batch):
            data_list = [data[0] for data in batch]
            parts: List[int] = [data[1] for data in batch]
            partptr = cluster_data.partptr

            # Stack the row-sliced adjacency matrices of the sampled clusters
            # and keep only the columns that belong to those clusters, so that
            # both intra- and inter-cluster edges of the mini-batch survive.
            adj = cat([data.adj for data in data_list], dim=0)
            adj = adj.t()
            adjs = []
            for part in parts:
                start = partptr[part]
                length = partptr[part + 1] - start
                adjs.append(adj.narrow(0, start, length))
            adj = cat(adjs, dim=0).t()
            row, col, value = adj.coo()

            data = cluster_data.data.__class__()
            data.num_nodes = adj.size(0)
            data.edge_index = torch.stack([row, col], dim=0)
            data.edge_attr = value

            # Node-level tensors are concatenated across the sampled clusters;
            # everything else is copied from the first cluster.
            ref = data_list[0]
            keys = ref.keys
            keys.remove('adj')
            for key in keys:
                if ref[key].size(0) != ref.adj.size(0):
                    data[key] = ref[key]
                else:
                    data[key] = torch.cat([d[key] for d in data_list],
                                          dim=ref.__cat_dim__(key, ref[key]))

            return data

        super(ClusterLoader,
              self).__init__(HelperDataset(), batch_size, shuffle,
                             collate_fn=collate, **kwargs)
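

# A minimal end-to-end sketch for `ClusterLoader` (illustrative, not part of
# the original module): it partitions a small synthetic graph into four
# clusters and iterates over mini-batches that merge two clusters each. All
# sizes are arbitrary assumptions; as above, the partitioning step needs a
# METIS-enabled `torch_sparse` installation.
if __name__ == '__main__':
    from torch_geometric.data import Data

    # Undirected 8-node cycle, stored as directed edges in both directions.
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 0],
         [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 0, 7]])
    data = Data(x=torch.randn(8, 16), edge_index=edge_index)

    cluster_data = ClusterData(data, num_parts=4)
    loader = ClusterLoader(cluster_data, batch_size=2, shuffle=True)

    for batch in loader:
        # Each `batch` is a `Data` object holding two merged clusters together
        # with the inter-cluster edges that connect them.
        print(batch)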