Source code for torch_geometric.datasets.karate

import torch
import numpy as np
import networkx as nx
from torch_geometric.data import InMemoryDataset, Data


class KarateClub(InMemoryDataset):
    r"""Zachary's karate club network from the `"An Information Flow Model for
    Conflict and Fission in Small Groups"
    <http://www1.ind.ku.dk/complexLearning/zachary1977.pdf>`_ paper, containing
    34 nodes, connected by 156 (undirected and unweighted) edges, i.e. 78
    undirected edges stored in both directions.
    Every node is labeled by one of two classes.

    Node features are one-hot (identity) vectors. Node labels are :obj:`0` for
    members of Mr. Hi's club and :obj:`1` for everyone else.

    Args:
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
    """

    def __init__(self, transform=None):
        super(KarateClub, self).__init__('.', transform, None, None)

        G = nx.karate_club_graph()

        # NetworkX 3.0 removed ``to_scipy_sparse_matrix``; its replacement
        # ``to_scipy_sparse_array`` has existed since NetworkX 2.7. Prefer the
        # new name and fall back only on older installations.
        to_sparse = getattr(nx, 'to_scipy_sparse_array', None)
        if to_sparse is None:
            to_sparse = nx.to_scipy_sparse_matrix
        adj = to_sparse(G).tocoo()

        # The graph is undirected, so its adjacency matrix is symmetric and
        # every edge appears once per direction in (row, col).
        # Only the sparsity pattern is used; any edge weights are discarded.
        row = torch.from_numpy(adj.row.astype(np.int64))
        col = torch.from_numpy(adj.col.astype(np.int64))
        edge_index = torch.stack([row, col], dim=0)

        data = Data(edge_index=edge_index)
        # Count nodes from the graph itself rather than inferring the count
        # from the largest index in ``edge_index``.
        data.num_nodes = G.number_of_nodes()
        # The graph has no input features, so use one-hot identity features.
        data.x = torch.eye(data.num_nodes, dtype=torch.float)

        # Class 0: members of Mr. Hi's club; class 1: all other members.
        y = [0 if G.nodes[i]['club'] == 'Mr. Hi' else 1 for i in G.nodes]
        data.y = torch.tensor(y)

        self.data, self.slices = self.collate([data])

    def _download(self):
        # Overridden as a no-op: the graph is built in ``__init__`` from
        # NetworkX, so there are no raw files to fetch.
        return

    def _process(self):
        # Overridden as a no-op: processing already happens in ``__init__``.
        return

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)