Skip to content

Commit

Permalink
Implemented metric: Hinge loss
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed Mar 21, 2022
1 parent bacbd07 commit 492922c
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 2 deletions.
7 changes: 6 additions & 1 deletion miseval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
from miseval.mcc import *
# Boundary Distances
from miseval.boundary_distance import *
# Hinge loss
from miseval.hinge import *

#-----------------------------------------------------#
# Access Functions to Metric Functions #
Expand Down Expand Up @@ -112,7 +114,10 @@
"MCC_absolute": calc_MCC_Absolute,
"aMCC": calc_MCC_Absolute,
"BoundaryDistance": calc_Boundary_Distance,
"BD": calc_Boundary_Distance
"Distance": calc_Boundary_Distance,
"BD": calc_Boundary_Distance,
"Hinge": calc_Hinge,
"HingeLoss": calc_Hinge,
}

#-----------------------------------------------------#
Expand Down
2 changes: 1 addition & 1 deletion miseval/boundary_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
List of available distances:
Bhattacharyya distance bhattacharyya
Bhattacharyya coefficient bhattacharyya_coefficient
Bhattacharyya coefficient bhattacharyya_coefficient
Canberra distance canberra
Chebyshev distance chebyshev
Chi Square distance chi_square
Expand Down
51 changes: 51 additions & 0 deletions miseval/hinge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External modules
import numpy as np

#-----------------------------------------------------#
# Calculate : asd #
#-----------------------------------------------------#
""" Compute Hinge loss between truth and prediction probabilities.
In machine learning, the hinge loss is a loss function used for training classifiers.
The hinge loss is used for "maximum-margin" classification.
Pooling (how to combine computed Hinge losses to a single value):
Distance Sum sum
Distance Averaging mean
Minimum Distance amin
Maximum Distance amax
"""
def calc_Hinge(truth, pred_prob, c=1, pooling="mean"):
# Obtain binary classification
prob = pred_prob[:,:,c]
gt = np.equal(truth, c).astype(int)
# Convert ground truth 0/1 format to -1/+1 format
gt = np.where(gt==0, -1, gt)
# Compute Hinge
hinge_total = np.maximum(1 - gt * prob, 0)
# Apply pooling function across all pixel classifications
res = getattr(np, pooling)(res_dist)
# Return Hinge
return hinge
2 changes: 2 additions & 0 deletions tests/test_boundary_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def test_calc_BoundaryDistance_General(self):
# Check existance in metric_dict
self.assertTrue("BoundaryDistance" in metric_dict)
self.assertTrue(callable(metric_dict["BoundaryDistance"]))
self.assertTrue("Distance" in metric_dict)
self.assertTrue(callable(metric_dict["Distance"]))
self.assertTrue("BD" in metric_dict)
self.assertTrue(callable(metric_dict["BD"]))

Expand Down
61 changes: 61 additions & 0 deletions tests/test_hinge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External modules
import numpy as np
import unittest
# Internal modules
from miseval import *

#-----------------------------------------------------#
# Unittest: Area under the ROC #
#-----------------------------------------------------#
class TEST_AUC(unittest.TestCase):
@classmethod
def setUpClass(self):
# Create ground truth
np.random.seed(1)
self.gt_bi = np.random.randint(2, size=(32,32))
self.gt_mc = np.random.randint(5, size=(32,32))
# Create prediction mask
np.random.seed(2)
self.pd_bi = np.random.randint(2, size=(32,32))
self.pd_mc = np.random.randint(5, size=(32,32))
# Create prediction probability
self.prob_bi = np.random.rand(32,32,2)
self.prob_mc = np.random.rand(32,32,5)

#-------------------------------------------------#
# Calculate : Hinge #
#-------------------------------------------------#
def test_calc_Hinge(self):
# Check binary score
score_bi = calc_Hinge(self.gt_bi, self.prob_bi, c=1)
self.assertTrue(isinstance(score_bi, np.float64))
# Check multi-class score
for i in range(5):
score_mc = calc_Hinge(self.gt_mc, self.prob_mc, c=i)
self.assertTrue(isinstance(score_mc, np.float64))
# Check existance in metric_dict
self.assertTrue("Hinge" in metric_dict)
self.assertTrue(callable(metric_dict["Hinge"]))
self.assertTrue("HingeLoss" in metric_dict)
self.assertTrue(callable(metric_dict["HingeLoss"]))

0 comments on commit 492922c

Please sign in to comment.