import torch
from torch_sparse import coalesce
from .num_nodes import maybe_num_nodes
def filter_adj(row, col, edge_attr, mask):
    """Keep only the edges selected by ``mask``.

    ``row`` and ``col`` are both indexed with ``mask``. ``edge_attr`` is
    indexed the same way, unless it is ``None``, in which case ``None`` is
    passed through unchanged.

    Returns:
        A ``(row, col, edge_attr)`` tuple holding the selected entries.
    """
    selected_row = row[mask]
    selected_col = col[mask]
    if edge_attr is None:
        return selected_row, selected_col, None
    return selected_row, selected_col, edge_attr[mask]
def dropout_adj(edge_index, edge_attr=None, p=0.5, force_undirected=False,
                num_nodes=None, training=True):
    r"""Randomly drops edges from the adjacency matrix
    :obj:`(edge_index, edge_attr)` with probability :obj:`p` using samples from
    a Bernoulli distribution.

    Every candidate edge is kept independently with probability
    :obj:`1 - p`.

    .. note::
        If :obj:`force_undirected` is set to :obj:`True`, only edges with
        :obj:`row < col` are considered as candidates, and every surviving
        candidate is emitted in both directions. As a consequence,
        self-loops (:obj:`row == col`) and edges with :obj:`row > col` are
        never taken as candidates; an edge with :obj:`row > col` can only
        show up in the output as the mirror of a kept candidate.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        p (float, optional): Dropout probability. (default: :obj:`0.5`)
        force_undirected (bool, optional): If set to :obj:`True`, will either
            drop or keep both edges of an undirected edge.
            (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        training (bool, optional): If set to :obj:`False`, this operation is a
            no-op. (default: :obj:`True`)

    Returns:
        A tuple :obj:`(edge_index, edge_attr)` with the remaining edges.
        :obj:`edge_attr` is :obj:`None` if no edge attributes were given.
        If :obj:`training` is :obj:`False`, the inputs are returned as-is.

    Raises:
        ValueError: If :obj:`p` is not a number in :obj:`[0, 1]`
            (this includes :obj:`NaN`). The check runs even when
            :obj:`training` is :obj:`False`.
    """
    # The chained comparison is False for NaN, so NaN is rejected as well;
    # `p < 0. or p > 1.` would let it through to the sampler.
    if not 0. <= p <= 1.:
        raise ValueError('Dropout probability has to be between 0 and 1, '
                         'but got {}'.format(p))

    if not training:
        return edge_index, edge_attr

    row, col = edge_index

    if force_undirected:
        # The node count is only needed by `coalesce` below. Compute it here
        # from the *original* edge_index, before `edge_index` is rebound.
        N = maybe_num_nodes(edge_index, num_nodes)
        # Keep a single representative (row < col) per undirected edge so
        # that both directions share one dropout decision.
        row, col, edge_attr = filter_adj(row, col, edge_attr, row < col)

    # One Bernoulli(1 - p) draw per candidate edge; True means "keep".
    keep_prob = edge_index.new_full((row.size(0), ), 1 - p, dtype=torch.float)
    mask = torch.bernoulli(keep_prob).to(torch.bool)
    row, col, edge_attr = filter_adj(row, col, edge_attr, mask)

    if not force_undirected:
        return torch.stack([row, col], dim=0), edge_attr

    # Mirror every surviving edge so that both directions are present, and
    # duplicate the attributes in the same order.
    edge_index = torch.stack(
        [torch.cat([row, col], dim=0),
         torch.cat([col, row], dim=0)], dim=0)
    if edge_attr is not None:
        edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
    # Canonicalise the mirrored edge list through torch_sparse's `coalesce`.
    edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)

    return edge_index, edge_attr