Source code for torch_geometric.utils.loop

import torch

from .num_nodes import maybe_num_nodes


def contains_self_loops(edge_index):
    r"""Checks whether the graph described by :attr:`edge_index` has at least
    one self-loop, *i.e.* an edge whose source and target node coincide.

    Args:
        edge_index (LongTensor): The edge indices; unpacked into a source
            row and a target row.

    :rtype: bool
    """
    source, target = edge_index
    # A single matching (source, target) pair is enough to answer "yes".
    return bool((source == target).any())
def remove_self_loops(edge_index, edge_attr=None):
    r"""Drops all self-loops from the graph given by :attr:`edge_index`, such
    that :math:`(i,i) \not\in \mathcal{E}` holds for every
    :math:`i \in \mathcal{V}` afterwards.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features, indexed along their first dimension by edge.
            (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    source, target = edge_index
    keep = source != target

    # Attributes are filtered with the same mask so they stay aligned with
    # the surviving edges; a missing attribute tensor is passed through.
    if edge_attr is not None:
        edge_attr = edge_attr[keep]

    return edge_index[:, keep], edge_attr
def segregate_self_loops(edge_index, edge_attr=None):
    r"""Splits the graph into its regular edges and its self-loops.

    The regular edges (with their attributes) are returned first, followed
    by the self-loop edges (with their attributes).

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features, indexed along their first dimension by edge.
            (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`, :class:`LongTensor`,
        :class:`Tensor`)
    """
    is_loop = edge_index[0] == edge_index[1]
    not_loop = ~is_loop

    # Without attributes both attribute outputs stay `None`.
    plain_attr, loop_attr = None, None
    if edge_attr is not None:
        plain_attr = edge_attr[not_loop]
        loop_attr = edge_attr[is_loop]

    plain_index = edge_index[:, not_loop]
    loop_index = edge_index[:, is_loop]
    return plain_index, plain_attr, loop_index, loop_attr
def add_self_loops(edge_index, edge_weight=None, fill_value=1,
                   num_nodes=None):
    r"""Appends one self-loop :math:`(i,i) \in \mathcal{E}` per node
    :math:`i \in \mathcal{V}` to the graph given by :attr:`edge_index`.
    For weighted graphs, each appended self-loop receives the edge weight
    :obj:`fill_value`.

    Self-loops are appended unconditionally, after the existing edges;
    edges already present in :attr:`edge_index` are left untouched.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_weight (Tensor, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        fill_value (int, optional): If :obj:`edge_weight` is not
            :obj:`None`, will add self-loops with edge weights of
            :obj:`fill_value` to the graph. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    node_count = maybe_num_nodes(edge_index, num_nodes)

    # Build the (2, node_count) block [[0, 1, ...], [0, 1, ...]].
    node_ids = torch.arange(0, node_count, dtype=torch.long,
                            device=edge_index.device)
    self_index = node_ids.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        # Checked against the edge count *before* the loops are appended.
        assert edge_weight.numel() == edge_index.size(1)
        # `new_full` inherits dtype and device from the given weights.
        self_weight = edge_weight.new_full((node_count, ), fill_value)
        edge_weight = torch.cat([edge_weight, self_weight], dim=0)

    return torch.cat([edge_index, self_index], dim=1), edge_weight
def add_remaining_self_loops(edge_index, edge_weight=None, fill_value=1,
                             num_nodes=None):
    r"""Adds remaining self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted and already contains a few self-loops, only
    non-existent self-loops will be added with edge weights denoted by
    :obj:`fill_value`.

    The result lists all non-loop edges first (in their original order),
    followed by exactly one self-loop per node in ascending node order.
    Existing self-loops are therefore moved to the end, and several
    self-loops on the same node are collapsed into one.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_weight (Tensor, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        fill_value (int, optional): If :obj:`edge_weight` is not
            :obj:`None`, will add self-loops with edge weights of
            :obj:`fill_value` to the graph. (default: :obj:`1`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index

    mask = row != col

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        inv_mask = ~mask

        # Start with `fill_value` for every node. `new_full` takes dtype and
        # device from `edge_weight`, matching `add_self_loops`.
        loop_weight = edge_weight.new_full((num_nodes, ), fill_value)

        # Nodes that already have a self-loop keep their existing weight.
        # NOTE(review): if a node has several self-loops, which of their
        # weights survives this indexed assignment is not specified.
        remaining_edge_weight = edge_weight[inv_mask]
        if remaining_edge_weight.numel() > 0:
            loop_weight[row[inv_mask]] = remaining_edge_weight

        edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0)

    # One (i, i) pair per node, in the dtype/device of the given indices.
    loop_index = torch.arange(0, num_nodes, dtype=row.dtype,
                              device=row.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)

    return edge_index, edge_weight