From 5a94c9c0c2771e45e278b7cc9428b55b31a8bf59 Mon Sep 17 00:00:00 2001 From: vessemer Date: Fri, 31 Aug 2018 15:09:13 +0200 Subject: [PATCH] BCEDiceJaccardLoss fixes --- albu/src/pytorch_utils/loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/albu/src/pytorch_utils/loss.py b/albu/src/pytorch_utils/loss.py index d791673..93f37a5 100644 --- a/albu/src/pytorch_utils/loss.py +++ b/albu/src/pytorch_utils/loss.py @@ -99,10 +99,11 @@ def __init__(self, weights, weight=None, size_average=True): def forward(self, input, target): loss = 0 + sigmoid_input = torch.sigmoid(input) for k, v in self.weights.items(): - if not v: + if not v: continue - val = self.mapping[k](input, target) + val = self.mapping[k](input if k == 'bce' else sigmoid_input, target) self.values[k] = val if k != 'bce': loss += self.weights[k] * (1 - val)