Advanced Mini-Batching¶
The creation of mini-batches is crucial for letting the training of a deep learning model scale to huge amounts of data. Instead of processing examples one-by-one, a mini-batch groups a set of examples into a unified representation where it can efficiently be processed in parallel. In the image or language domain, this procedure is typically achieved by rescaling or padding each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension. The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the batch_size.
Since graphs are one of the most general data structures that can hold any number of nodes or edges, the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. In PyTorch Geometric, we opt for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension, i.e.

\[\mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad \mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad \mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}\]
This procedure has some crucial advantages over other batching procedures:

- GNN operators that rely on a message passing scheme do not need to be modified, since messages are not exchanged between two nodes that belong to different graphs.
- There is no computational or memory overhead. For example, this batching procedure works completely without any padding of node or edge features. Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, i.e., the edges.
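Conceptually, the diagonal stacking of adjacency matrices boils down to offsetting node indices. The following is a minimal plain-Python sketch of that idea (no PyTorch Geometric involved; all names are illustrative):

```python
# Two small graphs given as edge lists of shape [2, num_edges].
edge_index_1 = [[0, 1], [1, 0]]        # graph 1: 2 nodes, edge 0 <-> 1
edge_index_2 = [[0, 1, 2], [1, 2, 0]]  # graph 2: 3 nodes, a triangle
num_nodes_1 = 2

# Batching: offset graph 2's node indices by the number of nodes
# collated before it, then concatenate along the edge dimension.
offset = num_nodes_1
batched = [
    row1 + [i + offset for i in row2]
    for row1, row2 in zip(edge_index_1, edge_index_2)
]
print(batched)  # [[0, 1, 2, 3, 4], [1, 0, 3, 4, 2]]
```

The batched edge list describes one graph with 5 nodes whose two components never share an edge, which is exactly why message passing operators need no modification.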
PyTorch Geometric automatically takes care of batching multiple graphs into a single giant graph with the help of the torch_geometric.data.DataLoader class. Internally, torch_geometric.data.DataLoader is just a regular PyTorch DataLoader that overwrites its collate() functionality, i.e., the definition of how a list of examples should be grouped together. Therefore, all arguments that can be passed to a PyTorch DataLoader can also be passed to a PyTorch Geometric DataLoader, e.g., the number of workers num_workers.
In its most general form, the PyTorch Geometric DataLoader will automatically increment the edge_index tensor by the cumulated number of nodes of all graphs that got collated before the currently processed graph, and will concatenate edge_index tensors (that are of shape [2, num_edges]) in the second dimension. The same is true for face tensors. All other tensors will just get concatenated in the first dimension without any further increase in their values.
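This cumulative increment can be sketched without PyTorch Geometric; here each graph is a hypothetical (edge_index, num_nodes) pair given as nested lists, and the offset grows as graphs get collated:

```python
# Each entry is an illustrative (edge_index, num_nodes) pair.
graphs = [
    ([[0, 1], [1, 0]], 2),  # 2 nodes
    ([[0], [1]], 2),        # 2 nodes
    ([[0, 2], [1, 0]], 3),  # 3 nodes
]

batched_edge_index = [[], []]
cum_nodes = 0
for edge_index, num_nodes in graphs:
    for dim in range(2):
        # Shift node indices by the number of nodes collated so far.
        batched_edge_index[dim] += [i + cum_nodes for i in edge_index[dim]]
    cum_nodes += num_nodes  # increment applied to the next graph

print(batched_edge_index)  # [[0, 1, 2, 4, 6], [1, 0, 3, 5, 4]]
```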
However, there are a few special use-cases (as outlined below) where the user actively wants to modify this behaviour to their own needs.
PyTorch Geometric allows modification to the underlying batching procedure by overwriting the torch_geometric.data.Data.__inc__() and torch_geometric.data.Data.__cat_dim__() functionalities. Without any modifications, these are defined as follows in the torch_geometric.data.Data class:
def __inc__(self, key, value):
    if 'index' in key or 'face' in key:
        return self.num_nodes
    else:
        return 0

def __cat_dim__(self, key, value):
    if 'index' in key or 'face' in key:
        return 1
    else:
        return 0
We can see that __inc__() defines the incremental count between two consecutive graph attributes, whereas __cat_dim__() defines in which dimension graph tensors of the same attribute should be concatenated together. Both functions are called for each attribute stored in the torch_geometric.data.Data class, and get passed their specific key and value item as arguments.
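To make the role of these two hooks concrete, here is a toy, PyG-free collate sketch that hard-codes the default rules for just two attributes: 'x' (concatenated in dimension 0, no increment) and 'edge_index' (concatenated in dimension 1, shifted by the node count). All names are illustrative and nested lists stand in for tensors:

```python
def collate(data_list):
    """Toy collate mimicking the default __inc__/__cat_dim__ rules
    for 'x' (cat dim 0, inc 0) and 'edge_index' (cat dim 1, inc num_nodes)."""
    x, edge_index, cum_nodes = [], [[], []], 0
    for data in data_list:
        x += data['x']  # __cat_dim__ == 0: plain concatenation
        for dim in range(2):  # __cat_dim__ == 1, shifted by __inc__
            edge_index[dim] += [i + cum_nodes for i in data['edge_index'][dim]]
        cum_nodes += len(data['x'])
    return {'x': x, 'edge_index': edge_index}

data = {'x': [[1.0], [2.0]], 'edge_index': [[0], [1]]}
batch = collate([data, data])
print(batch['edge_index'])  # [[0, 2], [1, 3]]
```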
In what follows, we present a few use-cases where the modification of __inc__() and __cat_dim__() might be absolutely necessary.
Pairs of Graphs¶
In case you want to store multiple graphs in a single torch_geometric.data.Data object, e.g., for applications such as graph matching, you need to ensure correct batching behaviour across all those graphs. For example, consider storing two graphs, a source graph \(\mathcal{G}_s\) and a target graph \(\mathcal{G}_t\), in a single torch_geometric.data.Data object, e.g.:
class PairData(Data):
    def __init__(self, edge_index_s, x_s, edge_index_t, x_t):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.x_s = x_s
        self.edge_index_t = edge_index_t
        self.x_t = x_t
In this case, edge_index_s should be increased by the number of nodes in the source graph \(\mathcal{G}_s\), e.g., x_s.size(0), and edge_index_t should be increased by the number of nodes in the target graph \(\mathcal{G}_t\), e.g., x_t.size(0):
def __inc__(self, key, value):
    if key == 'edge_index_s':
        return self.x_s.size(0)
    elif key == 'edge_index_t':
        return self.x_t.size(0)
    else:
        return super(PairData, self).__inc__(key, value)
We can test our PairData batching behaviour by setting up a simple test script:
import torch
from torch_geometric.data import DataLoader

edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
>>> Batch(edge_index_s=[2, 8], x_s=[10, 16],
          edge_index_t=[2, 6], x_t=[8, 16])

print(batch.edge_index_s)
>>> tensor([[0, 0, 0, 0, 5, 5, 5, 5],
            [1, 2, 3, 4, 6, 7, 8, 9]])

print(batch.edge_index_t)
>>> tensor([[0, 0, 0, 4, 4, 4],
            [1, 2, 3, 5, 6, 7]])
Everything looks good so far! edge_index_s and edge_index_t get correctly batched together, even when using a different number of nodes for \(\mathcal{G}_s\) and \(\mathcal{G}_t\). However, the batch attribute (that maps each node to its respective graph) is missing since PyTorch Geometric fails to identify the actual graph in the PairData object. That's where the follow_batch argument of the torch_geometric.data.DataLoader comes into play. Here, we can specify for which attributes we want to maintain the batch information:
loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))

print(batch)
>>> Batch(edge_index_s=[2, 8], x_s=[10, 16], x_s_batch=[10],
          edge_index_t=[2, 6], x_t=[8, 16], x_t_batch=[8])

print(batch.x_s_batch)
>>> tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

print(batch.x_t_batch)
>>> tensor([0, 0, 0, 0, 1, 1, 1, 1])
As one can see, follow_batch=['x_s', 'x_t'] now successfully creates assignment vectors called x_s_batch and x_t_batch for the node features x_s and x_t, respectively. That information can now be used to perform reduce operations, e.g., global pooling, on multiple graphs in a single Batch object.
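Such an assignment vector can drive reduce operations by hand. The following plain-Python sketch computes a per-graph mean over hypothetical scalar node values (node_vals is made up for illustration), mimicking what global mean pooling does with x_s_batch:

```python
# Assignment vector from the example above: node -> graph id.
x_s_batch = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
# Hypothetical scalar feature per node, for illustration only.
node_vals = [1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0]

num_graphs = max(x_s_batch) + 1
sums = [0.0] * num_graphs
counts = [0] * num_graphs
for graph_id, value in zip(x_s_batch, node_vals):
    # Scatter each node's value into its graph's accumulator.
    sums[graph_id] += value
    counts[graph_id] += 1
means = [s / c for s, c in zip(sums, counts)]
print(means)  # [3.0, 30.0]
```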
Bipartite Graphs¶
The adjacency matrix of a bipartite graph defines the relationship between nodes of two different node types. In general, the number of nodes of each type does not need to match, resulting in a non-square adjacency matrix \(\mathbf{A} \in \{ 0, 1 \}^{N \times M}\) with potentially \(N \neq M\). In a mini-batching procedure of bipartite graphs, the source nodes of edges in edge_index should get increased differently than the target nodes of edges in edge_index. To achieve this, consider a bipartite graph between two node types with corresponding node features x_s and x_t, respectively:
class BipartiteData(Data):
    def __init__(self, edge_index, x_s, x_t):
        super(BipartiteData, self).__init__()
        self.edge_index = edge_index
        self.x_s = x_s
        self.x_t = x_t
For a correct mini-batching procedure in bipartite graphs, we need to tell PyTorch Geometric that it should increment the source and target nodes of edges in edge_index independently of each other:

def __inc__(self, key, value):
    if key == 'edge_index':
        return torch.tensor([self.x_s.size(0), self.x_t.size(0)])
    else:
        return super(BipartiteData, self).__inc__(key, value)
Here, edge_index[0] (the source nodes of edges) gets incremented by x_s.size(0), while edge_index[1] (the target nodes of edges) gets incremented by x_t.size(0). We can again test our implementation by running a simple test script:
edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])
x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.

data = BipartiteData(edge_index, x_s, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
>>> Batch(edge_index=[2, 8], x_s=[4, 16], x_t=[6, 16])

print(batch.edge_index)
>>> tensor([[0, 0, 1, 1, 2, 2, 3, 3],
            [0, 1, 1, 2, 3, 4, 4, 5]])
Again, this is exactly the behaviour we aimed for!