import os
import os.path as osp
import glob
import pickle
import torch
import torch.nn.functional as F
import networkx as nx
from torch_geometric.data import (InMemoryDataset, Data, download_url,
extract_zip, extract_tar)
from torch_geometric.utils import to_undirected
class GEDDataset(InMemoryDataset):
    r"""The GED datasets from the `"Graph Edit Distance Computation via Graph
    Neural Networks" <https://arxiv.org/abs/1808.05689>`_ paper.
    GEDs can be accessed via the global attributes :obj:`ged` and
    :obj:`norm_ged` and provide the exact GEDs for all train/train graph pairs
    and all train/test graph pairs.

    Every graph stores an attribute :obj:`i`, which is its row/column index
    into :obj:`ged` and :obj:`norm_ged`. Training graphs are indexed first,
    followed by the test graphs. Pairs without a stored distance hold
    :obj:`inf`. :obj:`norm_ged` is :obj:`ged` divided by the mean number of
    nodes of the two graphs of a pair.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (one of :obj:`"AIDS700nef"`,
            :obj:`"LINUX"`, :obj:`"ALKANE"`, :obj:`"IMDBMulti"`).
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    # Download URL template; formatted with one of the file IDs below.
    url = 'https://drive.google.com/uc?export=download&id={}'

    # Per dataset: ID of the graph archive, the function used to unpack it,
    # and ID of the pickle file holding the pairwise GEDs.
    datasets = {
        'AIDS700nef': {
            'id': '10czBPJDEzEDI2tq7Z7mkBjLhj55F-a2z',
            'extract': extract_zip,
            'pickle': '1OpV4bCHjBkdpqI6H5Mg0-BqlA2ee2eBW',
        },
        'LINUX': {
            'id': '1nw0RRVgyLpit4V4XFQyDy0pI6wUEXSOI',
            'extract': extract_tar,
            'pickle': '14FDm3NSnrBvB7eNpLeGy5Bz6FjuCSF5v',
        },
        'ALKANE': {
            'id': '1-LmxaWW3KulLh00YqscVEflbqr0g4cXt',
            'extract': extract_tar,
            'pickle': '15BpvMuHx77-yUGYgM27_sQett02HQNYu',
        },
        'IMDBMulti': {
            'id': '12QxZ7EhYA7pJiF4cO-HuE8szhSOWcfST',
            'extract': extract_zip,
            'pickle': '1wy9VbZvZodkixxVIOuRllC-Lp-0zdoYZ',
        },
    }

    # Node type vocabulary; the position of a symbol in this list is the
    # class index used for the one-hot node features of "AIDS700nef".
    types = [
        'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
        'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
        'Sb', 'Se', 'Ni', 'Te'
    ]

    def __init__(self,
                 root,
                 name,
                 train=True,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        # `name` is assigned before the base constructor runs because the
        # `raw_file_names` / `processed_file_names` properties read it.
        # NOTE(review): presumably the base constructor is what triggers
        # `download()` / `process()` -- confirm against InMemoryDataset.
        self.name = name
        assert self.name in self.datasets, (
            'Unknown dataset "{}" (expected one of {})'.format(
                self.name, sorted(self.datasets)))
        super(GEDDataset, self).__init__(root, transform, pre_transform,
                                         pre_filter)

        # processed_paths[0] holds the training split, [1] the test split.
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

        # The GED matrices cover both splits and are loaded regardless of
        # the value of `train`.
        path = osp.join(self.processed_dir, '{}_ged.pt'.format(self.name))
        self.ged = torch.load(path)
        path = osp.join(self.processed_dir, '{}_norm_ged.pt'.format(self.name))
        self.norm_ged = torch.load(path)

    @property
    def raw_file_names(self):
        """Raw entries of the dataset: its train and test sub-directories."""
        return [osp.join(self.name, s) for s in ['train', 'test']]

    @property
    def processed_file_names(self):
        """File names of the processed training and test splits."""
        return ['{}_{}.pt'.format(self.name, s) for s in ['training', 'test']]

    def download(self):
        """Fetches and unpacks the graph archive, then fetches the GED pickle.

        The archive is deleted after extraction. The pickle is moved to
        ``<raw_dir>/<name>/ged.pickle``.
        """
        name = self.datasets[self.name]['id']
        path = download_url(self.url.format(name), self.raw_dir)
        self.datasets[self.name]['extract'](path, self.raw_dir)
        os.unlink(path)

        name = self.datasets[self.name]['pickle']
        path = download_url(self.url.format(name), self.raw_dir)
        os.rename(path, osp.join(self.raw_dir, self.name, 'ged.pickle'))

    def process(self):
        """Converts the raw ``*.gexf`` graphs and the GED pickle to tensors.

        Writes one collated file per split plus the dense matrices
        ``<name>_ged.pt`` and ``<name>_norm_ged.pt`` to the processed
        directory.
        """
        # ids[k]: sorted graph IDs of split k (0 = train, 1 = test).
        # Ns: node count of every graph, in global index order.
        ids, Ns = [], []
        for r_path, p_path in zip(self.raw_paths, self.processed_paths):
            names = glob.glob(osp.join(r_path, '*.gexf'))
            # Files are named "<integer id>.gexf"; sorting the IDs makes the
            # assignment of global indices deterministic.
            ids.append(sorted(
                int(osp.splitext(osp.basename(name))[0]) for name in names))

            # Global indices of the second split continue after the first.
            offset = 0 if len(ids) == 1 else len(ids[0])

            data_list = []
            for i, idx in enumerate(ids[-1], start=offset):
                G = nx.read_gexf(osp.join(r_path, '{}.gexf'.format(idx)))
                # Relabel the nodes to consecutive integers 0..N-1.
                mapping = {name: j for j, name in enumerate(G.nodes())}
                G = nx.relabel_nodes(G, mapping)
                # Recorded before filtering so that Ns stays aligned with the
                # global index `i`, even for graphs that get filtered out.
                Ns.append(G.number_of_nodes())

                edge_index = torch.tensor(list(G.edges)).t().contiguous()
                if edge_index.numel() == 0:
                    # Graph without edges: use a well-formed empty index.
                    edge_index = torch.empty((2, 0), dtype=torch.long)
                edge_index = to_undirected(edge_index, num_nodes=Ns[-1])

                data = Data(edge_index=edge_index, i=i)
                data.num_nodes = Ns[-1]

                if self.name == 'AIDS700nef':
                    # One-hot encode the "type" attribute of every node.
                    x = torch.zeros(data.num_nodes, dtype=torch.long)
                    for node, info in G.nodes(data=True):
                        x[int(node)] = self.types.index(info['type'])
                    data.x = F.one_hot(
                        x, num_classes=len(self.types)).to(torch.float)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            torch.save(self.collate(data_list), p_path)

        # Map raw graph IDs to global indices (train first, then test).
        assoc = {idx: i for i, idx in enumerate(ids[0])}
        assoc.update({idx: i + len(ids[0]) for i, idx in enumerate(ids[1])})

        path = osp.join(self.raw_dir, self.name, 'ged.pickle')
        # Pairs that are missing from the pickle keep the value `inf`.
        mat = torch.full((len(assoc), len(assoc)), float('inf'))
        # NOTE(security): unpickling can execute arbitrary code. The file is
        # the one fetched by `download()`; never point this at untrusted data.
        with open(path, 'rb') as f:
            obj = pickle.load(f)

        xs, ys, gs = [], [], []
        for (x, y), g in obj.items():
            xs.append(assoc[x])
            ys.append(assoc[y])
            gs.append(g)

        # Explicit dtypes: indices must be integral even when the pickle is
        # empty, and the values must match the dtype of `mat` for the indexed
        # assignment regardless of how the distances were stored.
        x = torch.tensor(xs, dtype=torch.long)
        y = torch.tensor(ys, dtype=torch.long)
        g = torch.tensor(gs, dtype=mat.dtype)
        # Fill both triangles so that the matrix is symmetric.
        mat[x, y], mat[y, x] = g, g

        path = osp.join(self.processed_dir, '{}_ged.pt'.format(self.name))
        torch.save(mat, path)

        # Normalize every GED by the mean number of nodes of the graph pair.
        N = torch.tensor(Ns, dtype=torch.float)
        norm_mat = mat / (0.5 * (N.view(-1, 1) + N.view(1, -1)))

        path = osp.join(self.processed_dir, '{}_norm_ged.pt'.format(self.name))
        torch.save(norm_mat, path)

    def __repr__(self):
        return '{}({})'.format(self.name, len(self))