-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathloss.py
32 lines (24 loc) · 899 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.nn as nn
import config
class NBVLoss(nn.Module):
    """Weighted binary cross-entropy loss.

    The element-wise BCE between ``predictions`` and ``target`` is summed
    separately over two groups:

    * elements whose target is exactly 0, weighted by ``lambda_for0``;
    * every other element (including soft labels such as 0.5), weighted by
      ``lambda_for1``.

    The weighted group sums are added to form the scalar loss.
    (NOTE(review): "NBV" presumably stands for next-best-view — confirm.)
    """

    def __init__(self, lambde_for1):
        """Create the loss.

        Args:
            lambde_for1: Weight for the summed BCE over elements whose target
                is non-zero. The misspelled parameter name is kept so that
                existing keyword callers keep working.
        """
        super().__init__()
        # NOTE(review): ``mse``, ``entropy``, ``lambda_l2`` and ``lambda_cnt1``
        # are not used by ``forward``; they are kept in case other code reads
        # them.
        self.mse = nn.MSELoss()
        self.entropy = nn.BCELoss()
        self.lambda_for0 = 1
        self.lambda_for1 = lambde_for1
        self.lambda_l2 = 1
        self.lambda_cnt1 = 1

    def forward(self, predictions, target):
        """Return the weighted, summed BCE between ``predictions`` and ``target``.

        Args:
            predictions: Probabilities in [0, 1]; same shape as ``target``.
                Any tensor rank is accepted.
            target: Labels. Elements equal to 0 form the "zero" group; all
                other elements form the "non-zero" group.

        Returns:
            A 0-d tensor equal to
            ``lambda_for1 * sum(BCE where target != 0)
            + lambda_for0 * sum(BCE where target == 0)``,
            on the same device as the inputs.
        """
        # A single vectorised BCE call over the whole tensor, rather than one
        # call (and one kernel launch) per element.
        per_element = nn.functional.binary_cross_entropy(
            predictions, target, reduction="none"
        )
        is_zero = target == 0
        # masked_fill (rather than boolean indexing) keeps the computation
        # dense, avoids a host sync on CUDA, and puts each element's loss in
        # exactly one group.
        loss_where_0 = per_element.masked_fill(~is_zero, 0.0).sum()
        loss_where_1 = per_element.masked_fill(is_zero, 0.0).sum()
        return (
            self.lambda_for1 * loss_where_1
            + self.lambda_for0 * loss_where_0
        )