forked from serizard/ViTPose_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmse_loss.py
151 lines (126 loc) · 5.77 KB
/
mse_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
# Public names exported by a star-import of this module.
# NOTE(review): CombinedTargetMSELoss is defined in this file but is not
# listed here, so a star-import will not pick it up -- confirm that the
# omission is intentional.
__all__ = ['JointsMSELoss', 'JointsOHKMMSELoss',]
class JointsMSELoss(nn.Module):
    """MSE loss for heatmaps.

    The loss is the mean squared error between predicted and ground-truth
    heatmaps, averaged over every sample, joint and pixel, then scaled by
    ``loss_weight``.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps. Dimension 0 is the
                batch, dimension 1 the joints; all remaining dimensions
                are flattened into one pixel axis.
            target (torch.Tensor): Ground-truth heatmaps with the same
                number of elements per (sample, joint) as ``output``.
            target_weight (torch.Tensor | None): Per-joint weights of
                shape ``(N, K)`` or ``(N, K, 1)``. Only read when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss.

        Raises:
            ValueError: If ``use_target_weight`` is True but no
                ``target_weight`` was given.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)

        if self.use_target_weight:
            if target_weight is None:
                raise ValueError(
                    'target_weight is required when use_target_weight=True')
            weight = target_weight[:, :num_joints]
            # A 2-D (N, K) weight needs a trailing axis; otherwise it would
            # broadcast against the pixel axis instead of the batch axis.
            if weight.dim() == 2:
                weight = weight.unsqueeze(-1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Every joint contributes the same number of elements, so a single
        # mean over the whole tensor equals the average of per-joint means.
        return self.criterion(heatmaps_pred, heatmaps_gt) * self.loss_weight
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Channels are read in groups of three per joint:
    ``[heatmap, offset_x, offset_y]``.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Default: False.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predictions. Dimension 0 is the batch,
                dimension 1 holds ``3 * num_joints`` channels; remaining
                dimensions are flattened into one pixel axis.
            target (torch.Tensor): Ground truth laid out like ``output``.
            target_weight (torch.Tensor | None): Per-joint weights of
                shape ``(N, num_joints)`` or ``(N, num_joints, 1)``. Only
                read when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss.

        Raises:
            ValueError: If there are fewer than three channels, or if
                ``use_target_weight`` is True but no ``target_weight``
                was given.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)
        num_joints = num_channels // 3
        if num_joints == 0:
            raise ValueError(
                f'expected at least 3 channels, got {num_channels}')

        # Channels beyond the last complete triple are ignored.
        used_channels = num_joints * 3
        pred = output.reshape(batch_size, num_channels, -1)
        pred = pred[:, :used_channels].reshape(batch_size, num_joints, 3, -1)
        gt = target.reshape(batch_size, num_channels, -1)
        gt = gt[:, :used_channels].reshape(batch_size, num_joints, 3, -1)

        heatmap_pred, offset_x_pred, offset_y_pred = pred.unbind(dim=2)
        heatmap_gt, offset_x_gt, offset_y_gt = gt.unbind(dim=2)

        if self.use_target_weight:
            if target_weight is None:
                raise ValueError(
                    'target_weight is required when use_target_weight=True')
            weight = target_weight[:, :num_joints]
            # A 2-D (N, K) weight needs a trailing axis; otherwise it would
            # broadcast against the pixel axis instead of the batch axis.
            if weight.dim() == 2:
                weight = weight.unsqueeze(-1)
            heatmap_pred = heatmap_pred * weight
            heatmap_gt = heatmap_gt * weight

        # classification loss
        loss = 0.5 * self.criterion(heatmap_pred, heatmap_gt)
        # regression loss: offsets are masked by the (possibly weighted)
        # ground-truth heatmap so only pixels with response contribute.
        loss = loss + 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                           heatmap_gt * offset_x_gt)
        loss = loss + 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                           heatmap_gt * offset_y_gt)
        # Every joint contributes the same number of elements, so means
        # over the whole tensor equal the average of per-joint means.
        return loss * self.loss_weight
class JointsOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    For each sample only the ``topk`` joints with the largest per-joint
    MSE contribute to the loss.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        topk (int): Only top k joint losses are kept.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, topk=8, loss_weight=1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, loss):
        """Online hard keypoint mining.

        Args:
            loss (torch.Tensor): Per-joint losses of shape ``(N, K)``.

        Returns:
            torch.Tensor: Scalar; the mean over samples of the average of
            each sample's ``topk`` largest joint losses.
        """
        # topk already returns the selected values, so no gather is needed.
        hardest, _ = torch.topk(loss, k=self.topk, dim=1, sorted=False)
        return (hardest.sum(dim=1) / self.topk).mean()

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps. Dimension 0 is the
                batch, dimension 1 the joints; all remaining dimensions
                are flattened into one pixel axis.
            target (torch.Tensor): Ground-truth heatmaps laid out like
                ``output``.
            target_weight (torch.Tensor | None): Per-joint weights of
                shape ``(N, K)`` or ``(N, K, 1)``. Only read when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss.

        Raises:
            ValueError: If there are fewer joints than ``topk``, or if
                ``use_target_weight`` is True but no ``target_weight``
                was given.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        if num_joints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not '
                             f'larger than num_joints ({num_joints}).')

        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)

        if self.use_target_weight:
            if target_weight is None:
                raise ValueError(
                    'target_weight is required when use_target_weight=True')
            weight = target_weight[:, :num_joints]
            # A 2-D (N, K) weight needs a trailing axis; otherwise it would
            # broadcast against the pixel axis instead of the batch axis.
            if weight.dim() == 2:
                weight = weight.unsqueeze(-1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Element-wise squared error, averaged over pixels -> (N, K).
        losses = self.criterion(heatmaps_pred, heatmaps_gt).mean(dim=2)
        return self._ohkm(losses) * self.loss_weight