forked from Adamdad/CT-Lung-Segmentation
Loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.

    Args:
        input: A tensor of shape [N, 1, *]
        num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    # scatter_ requires a LongTensor index; the one-hot tensor is built on the CPU
    result = result.scatter_(1, input.cpu(), 1)
    return result


class BinaryDiceLoss(nn.Module):
    r"""Dice loss of binary class.

    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of the same shape as predict
        reduction: Reduction method to apply: return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """

    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Flatten all dimensions except the batch dimension
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        # Dice coefficient = 2*|X ∩ Y| / (|X| + |Y|); the factor of 2 was missing
        # in the original, which kept the loss from reaching 0 on a perfect match
        num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Multi-class Dice loss; one-hot encoding of the target is applied internally.

    Args:
        weight: An array of shape [num_classes,]
        ignore_index: class index to ignore
        predict: A tensor of shape [N, C, *] (raw logits; softmax is applied inside)
        target: A tensor of shape [N, *] holding class indices
        other args pass to BinaryDiceLoss
    Return:
        same as BinaryDiceLoss
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.num_class = 2

    def forward(self, predict, target):
        # One-hot encode the class-index target and move it to predict's device
        # (the original hard-coded .cuda(), which fails on CPU-only machines)
        target = make_one_hot(target.unsqueeze(1), self.num_class).to(predict.device)
        assert predict.shape == target.shape, 'predict & target shape do not match'
        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)

        for i in range(target.shape[1]):
            if i != self.ignore_index:
                dice_loss = dice(predict[:, i], target[:, i])
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], get [{}]'.format(target.shape[1], self.weight.shape[0])
                    dice_loss *= self.weight[i]  # was self.weights[i], an AttributeError
                total_loss += dice_loss

        # Note: this averages over all classes, including any ignored index
        return total_loss / target.shape[1]


class DiceCELoss(nn.Module):
    """Weighted sum of cross-entropy loss and Dice loss.

    Args:
        predict: A tensor of shape [N, C, *] (raw logits)
        target: A tensor of shape [N, *] holding class indices
    Return:
        Scalar loss: weight[0] * CE + weight[1] * Dice
    """

    def __init__(self):
        super(DiceCELoss, self).__init__()
        self.weight = [1, 1]  # [CE weight, Dice weight]
        self.ce_loss = nn.CrossEntropyLoss()
        self.dice_loss = DiceLoss()

    def forward(self, predict, target):
        return self.weight[0] * self.ce_loss(predict, target) \
            + self.weight[1] * self.dice_loss(predict, target)
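

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal CPU smoke test of
# DiceCELoss on random tensors. The shapes are assumptions for a binary
# segmentation setup: logits of shape [N, 2, H, W] from a network, and
# integer masks of shape [N, H, W]; torch.randint yields the long dtype that
# both CrossEntropyLoss and make_one_hot's scatter_ index require.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    criterion = DiceCELoss()
    logits = torch.randn(4, 2, 64, 64, requires_grad=True)  # raw network output
    masks = torch.randint(0, 2, (4, 64, 64))                # class-index ground truth
    loss = criterion(logits, masks)
    loss.backward()  # gradients flow through both the CE and Dice terms
    print('Dice + CE loss:', loss.item())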