import os
import os.path as osp
import glob
import torch
import torch.nn.functional as F
import numpy as np
from torch_sparse import coalesce
from torch_geometric.io import read_txt_array
from torch_geometric.utils import remove_self_loops
from torch_geometric.data import Data
# Suffixes of the text files making up one dataset; each one is stored on
# disk as ``<prefix>_<name>.txt``.  Every entry needs its own trailing comma:
# two adjacent string literals are silently concatenated by Python.
names = [
    'A',
    'graph_indicator',
    'node_labels',
    'node_attributes',
    'edge_labels',
    'edge_attributes',
    'graph_labels',
    'graph_attributes',
]
def _one_hot_labels(labels):
    """One-hot encode every column of an integer label tensor.

    A 1-D input is treated as a single column.  Each column is shifted so
    that its minimum becomes zero, one-hot encoded on its own, and the
    per-column encodings are concatenated along the last dimension and
    returned as a float tensor.
    """
    if labels.dim() == 1:
        labels = labels.unsqueeze(-1)
    # Shift per column so the smallest label maps to class index 0.
    labels = labels - labels.min(dim=0)[0]
    encoded = [F.one_hot(col, num_classes=-1) for col in labels.unbind(dim=-1)]
    return torch.cat(encoded, dim=-1).to(torch.float)


def read_tu_data(folder, prefix):
    """Read the ``<prefix>_*.txt`` files in ``folder`` into one data object.

    ``<prefix>_A.txt`` and ``<prefix>_graph_indicator.txt`` are always read;
    both hold 1-based indices and are shifted to 0-based.  The node, edge
    and graph attribute/label files are read only if they exist.  Label
    files are one-hot encoded and concatenated after the corresponding
    attribute columns.  Self loops are removed and the edges are coalesced
    before the result is split per graph.

    Args:
        folder: Directory containing the text files.
        prefix: Common file-name prefix of the dataset.

    Returns:
        The ``(data, slices)`` tuple produced by ``split``.
    """
    files = glob.glob(osp.join(folder, '{}_*.txt'.format(prefix)))
    # Strip the leading "<prefix>_" and the trailing ".txt" from each name.
    available = {osp.basename(f)[len(prefix) + 1:-4] for f in files}

    edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
    batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1

    node_attributes = node_labels = None
    if 'node_attributes' in available:
        node_attributes = read_file(folder, prefix, 'node_attributes')
    if 'node_labels' in available:
        node_labels = read_file(folder, prefix, 'node_labels', torch.long)
        node_labels = _one_hot_labels(node_labels)
    x = cat([node_attributes, node_labels])

    edge_attributes = edge_labels = None
    if 'edge_attributes' in available:
        edge_attributes = read_file(folder, prefix, 'edge_attributes')
    if 'edge_labels' in available:
        edge_labels = read_file(folder, prefix, 'edge_labels', torch.long)
        edge_labels = _one_hot_labels(edge_labels)
    edge_attr = cat([edge_attributes, edge_labels])

    y = None
    if 'graph_attributes' in available:  # Regression problem.
        y = read_file(folder, prefix, 'graph_attributes')
    elif 'graph_labels' in available:  # Classification problem.
        y = read_file(folder, prefix, 'graph_labels', torch.long)
        # Replace each label by its index among the sorted unique labels.
        _, y = y.unique(sorted=True, return_inverse=True)

    num_nodes = edge_index.max().item() + 1 if x is None else x.size(0)
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
                                     num_nodes)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    data, slices = split(data, batch)

    return data, slices
def read_file(folder, prefix, name, dtype=None):
    """Read ``<folder>/<prefix>_<name>.txt`` as a comma-separated array.

    Args:
        folder: Directory containing the file.
        prefix: File-name prefix.
        name: File-name suffix (without the ``.txt`` extension).
        dtype: Passed through unchanged to ``read_txt_array``.

    Returns:
        Whatever ``read_txt_array`` returns for that path.
    """
    filename = '{}_{}.txt'.format(prefix, name)
    return read_txt_array(osp.join(folder, filename), sep=',', dtype=dtype)
def cat(seq):
    """Concatenate the non-``None`` tensors of ``seq`` along the last dim.

    ``None`` entries are skipped and 1-D tensors are first given a trailing
    dimension of size one.  Returns ``None`` when no tensor remains.
    """
    columns = []
    for tensor in seq:
        if tensor is None:
            continue
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(-1)
        columns.append(tensor)

    if not columns:
        return None
    return torch.cat(columns, dim=-1)
def split(data, batch):
    """Compute per-graph slice boundaries for a concatenated data object.

    ``batch[i]`` is the graph id of node ``i``.  The slicing below relies on
    the nodes (and edges) of each graph being stored contiguously, in
    increasing graph-id order.

    Side effects on ``data``: ``edge_index`` is shifted in place so node
    indices restart at zero within every graph, and ``__num_nodes__`` is set
    to the list of per-graph node counts.

    Args:
        data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and ``y``.
        batch: 1-D integer tensor assigning every node to a graph.

    Returns:
        ``(data, slices)`` where ``slices`` maps each present attribute name
        to a tensor of ``num_graphs + 1`` cumulative boundaries.
    """
    num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0

    node_counts = torch.bincount(batch, minlength=num_graphs)
    node_slice = torch.cat([torch.tensor([0]), torch.cumsum(node_counts, 0)])

    row, _ = data.edge_index
    # ``minlength`` keeps one entry per graph even when the trailing graphs
    # have no edges; without it the boundaries for those graphs are missing.
    edge_counts = torch.bincount(batch[row], minlength=num_graphs)
    edge_slice = torch.cat([torch.tensor([0]), torch.cumsum(edge_counts, 0)])

    # Edge indices should start at zero for every graph.
    data.edge_index -= node_slice[batch[row]].unsqueeze(0)
    data.__num_nodes__ = node_counts.tolist()

    slices = {'edge_index': edge_slice}
    if data.x is not None:
        slices['x'] = node_slice
    if data.edge_attr is not None:
        slices['edge_attr'] = edge_slice
    if data.y is not None:
        if data.y.size(0) == batch.size(0):
            # One target per node: slice like the node features.
            slices['y'] = node_slice
        else:
            # Otherwise one target row per graph.
            slices['y'] = torch.arange(0, num_graphs + 1, dtype=torch.long)

    return data, slices