loss.py
# Focal Loss implementation, from https://github.com/AdeelH/pytorch-multi-class-focal-loss
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class FocalLoss(nn.Module):
    """Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.

    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(self,
                 alpha: Optional[Tensor] = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f'{k}={v!r}' for k, v in zip(arg_keys, arg_vals)]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            y = y.view(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
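            # every target was ignore_index; nothing left to compute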
            return torch.tensor(0.)
        x = x[unignored_mask]

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # get true class column from each row
        all_rows = torch.arange(len(x))
        log_pt = log_p[all_rows, y]

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma
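        # (with gamma == 0 the focal term is 1, so the loss reduces to
        # plain alpha-weighted cross entropy)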

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss
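

# Usage sketch: exercises FocalLoss on random inputs for both supported
# shapes. The sizes below (batch of 8, 5 classes, 4x4 spatial grid) and
# gamma = 2 are illustrative values, not ones prescribed by the paper
# or this repo.
if __name__ == '__main__':
    torch.manual_seed(0)
    num_classes = 5

    loss_fn = FocalLoss(gamma=2.)
    print(loss_fn)

    # plain classification: (batch_size, C) logits, (batch_size,) labels
    logits = torch.randn(8, num_classes)
    targets = torch.randint(0, num_classes, (8,))
    print(loss_fn(logits, targets))

    # dense prediction: (batch_size, C, H, W) logits, (batch_size, H, W) labels
    logits = torch.randn(8, num_classes, 4, 4)
    targets = torch.randint(0, num_classes, (8, 4, 4))
    print(loss_fn(logits, targets))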