from __future__ import division
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add
def accuracy(pred, target):
    r"""Computes the fraction of predictions that match their targets.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.

    :rtype: float
    """
    # Element-wise comparison, then count the matching entries.
    num_correct = pred.eq(target).sum().item()
    num_total = target.numel()
    return num_correct / num_total
def true_positive(pred, target, num_classes):
    r"""Computes the number of true positive predictions per class.

    For every class :obj:`i` in :obj:`range(num_classes)`, counts the entries
    where both the prediction and the target equal :obj:`i`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [
        ((pred == i) & (target == i)).sum().item()
        for i in range(num_classes)
    ]
    # An explicit dtype keeps the result a LongTensor even when
    # ``num_classes`` is 0 (an empty list would otherwise become float32).
    # NOTE: no device is passed, so the result is created on the default
    # device and does not follow ``pred``.
    return torch.tensor(counts, dtype=torch.long)
def true_negative(pred, target, num_classes):
    r"""Computes the number of true negative predictions per class.

    For every class :obj:`i` in :obj:`range(num_classes)`, counts the entries
    where neither the prediction nor the target equals :obj:`i`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [
        ((pred != i) & (target != i)).sum().item()
        for i in range(num_classes)
    ]
    # An explicit dtype keeps the result a LongTensor even when
    # ``num_classes`` is 0 (an empty list would otherwise become float32).
    # NOTE: no device is passed, so the result is created on the default
    # device and does not follow ``pred``.
    return torch.tensor(counts, dtype=torch.long)
def false_positive(pred, target, num_classes):
    r"""Computes the number of false positive predictions per class.

    For every class :obj:`i` in :obj:`range(num_classes)`, counts the entries
    where the prediction equals :obj:`i` but the target does not.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [
        ((pred == i) & (target != i)).sum().item()
        for i in range(num_classes)
    ]
    # An explicit dtype keeps the result a LongTensor even when
    # ``num_classes`` is 0 (an empty list would otherwise become float32).
    # NOTE: no device is passed, so the result is created on the default
    # device and does not follow ``pred``.
    return torch.tensor(counts, dtype=torch.long)
def false_negative(pred, target, num_classes):
    r"""Computes the number of false negative predictions per class.

    For every class :obj:`i` in :obj:`range(num_classes)`, counts the entries
    where the target equals :obj:`i` but the prediction does not.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [
        ((pred != i) & (target == i)).sum().item()
        for i in range(num_classes)
    ]
    # An explicit dtype keeps the result a LongTensor even when
    # ``num_classes`` is 0 (an empty list would otherwise become float32).
    # NOTE: no device is passed, so the result is created on the default
    # device and does not follow ``pred``.
    return torch.tensor(counts, dtype=torch.long)
def precision(pred, target, num_classes):
    r"""Computes the per-class precision
    :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}` of predictions.

    Classes for which :math:`\mathrm{TP}+\mathrm{FP}` is zero are assigned a
    precision of :obj:`0`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    false_alarms = false_positive(pred, target, num_classes).to(torch.float)
    predicted = hits + false_alarms
    ratio = hits / predicted
    # 0 / 0 produces NaN for classes that were never predicted; report 0.
    undefined = torch.isnan(ratio)
    ratio[undefined] = 0
    return ratio
def recall(pred, target, num_classes):
    r"""Computes the per-class recall
    :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}` of predictions.

    Classes for which :math:`\mathrm{TP}+\mathrm{FN}` is zero are assigned a
    recall of :obj:`0`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    misses = false_negative(pred, target, num_classes).to(torch.float)
    actual = hits + misses
    ratio = hits / actual
    # 0 / 0 produces NaN for classes absent from the targets; report 0.
    undefined = torch.isnan(ratio)
    ratio[undefined] = 0
    return ratio
def f1_score(pred, target, num_classes):
    r"""Computes the per-class :math:`F_1` score
    :math:`2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}
    {\mathrm{precision}+\mathrm{recall}}` of predictions.

    Classes whose precision and recall are both zero are assigned a score
    of :obj:`0`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    p = precision(pred, target, num_classes)
    r = recall(pred, target, num_classes)
    numerator = 2 * (p * r)
    denominator = p + r
    f1 = numerator / denominator
    # 0 / 0 produces NaN when precision and recall are both 0; report 0.
    undefined = torch.isnan(f1)
    f1[undefined] = 0
    return f1
def intersection_and_union(pred, target, num_classes, batch=None):
    r"""Computes the per-class intersection and union of predictions and
    targets.

    Both inputs are one-hot encoded; the intersection is their element-wise
    AND and the union their element-wise OR. Without :obj:`batch`, both are
    summed over the first dimension. With :obj:`batch`, they are reduced
    along the first dimension with :obj:`scatter_add`, using :obj:`batch`
    as the index.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example.

    :rtype: (:class:`LongTensor`, :class:`LongTensor`)
    """
    pred_one_hot = F.one_hot(pred, num_classes)
    target_one_hot = F.one_hot(target, num_classes)

    overlap = pred_one_hot & target_one_hot
    either = pred_one_hot | target_one_hot

    if batch is not None:
        return (scatter_add(overlap, batch, dim=0),
                scatter_add(either, batch, dim=0))

    return overlap.sum(dim=0), either.sum(dim=0)
def mean_iou(pred, target, num_classes, batch=None):
    r"""Computes the mean intersection over union score of predictions.

    The per-class intersection is divided by the per-class union and the
    result is averaged over the last (class) dimension. Classes whose
    intersection and union are both zero count as an IoU of :obj:`1`.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example.

    :rtype: :class:`Tensor`
    """
    inter, union = intersection_and_union(pred, target, num_classes, batch)
    ratio = inter.to(torch.float) / union.to(torch.float)
    # 0 / 0 produces NaN for classes absent from both pred and target;
    # those are scored as a perfect match.
    absent = torch.isnan(ratio)
    ratio[absent] = 1
    return ratio.mean(dim=-1)