Implemented metric: Cross-Entropy
muellerdo committed Mar 21, 2022
1 parent 7928c5d commit 9b8680b
Showing 2 changed files with 119 additions and 0 deletions.
56 changes: 56 additions & 0 deletions miseval/entropy.py
@@ -0,0 +1,56 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External modules
import numpy as np

#-----------------------------------------------------#
# Calculate : Cross-Entropy #
#-----------------------------------------------------#
""" Compute cross-entropy between truth and prediction probabilities.
In information theory, the cross-entropy between two probability distributions p and q
over the same underlying set of events measures the average number of bits needed to
identify an event drawn from the set if a coding scheme used for the set is optimized
for an estimated probability distribution q, rather than the true distribution p.
Source: https://en.wikipedia.org/wiki/Cross_entropy
Pooling (how the computed pixel-wise cross-entropy values are combined into a single value):
    Distance Sum         -> sum
    Distance Averaging   -> mean
    Minimum Distance     -> amin
    Maximum Distance     -> amax
"""
def calc_CrossEntropy(truth, pred_prob, c=1, pooling="mean", provided_prob=True,
                      **kwargs):
    # Obtain the predicted probability (or binary mask) and ground truth for class c
    if provided_prob : prob = np.take(pred_prob, c, axis=-1)
    else : prob = np.equal(pred_prob, c)
    gt = np.equal(truth, c).astype(int)
    # Add epsilon to the probability to avoid log(0)
    prob = prob + np.finfo(np.float32).eps
    # Compute pixel-wise cross-entropy
    cross_entropy = - gt * np.log(prob)
    # Apply pooling function across all pixel classifications
    res = getattr(np, pooling)(cross_entropy)
    # Return Cross-Entropy
    return res
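For orientation, here is a minimal usage sketch that calls the function added in this commit directly from its module. The ground-truth mask and probability map are made-up illustrative data (not part of the commit), and the pooling argument simply names the NumPy reduction to apply (sum, mean, amin, amax).

# Minimal usage sketch with illustrative data (not part of the commit)
import numpy as np
from miseval.entropy import calc_CrossEntropy

truth = np.random.randint(2, size=(32, 32))                      # binary ground-truth mask
pred_prob = np.random.rand(32, 32, 2)                            # raw per-class scores
pred_prob = pred_prob / pred_prob.sum(axis=-1, keepdims=True)    # normalize to probabilities

ce_mean = calc_CrossEntropy(truth, pred_prob, c=1, pooling="mean")   # average pixel-wise CE
ce_sum = calc_CrossEntropy(truth, pred_prob, c=1, pooling="sum")     # summed pixel-wise CE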
63 changes: 63 additions & 0 deletions tests/test_entropy.py
@@ -0,0 +1,63 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External modules
import numpy as np
import unittest
# Internal modules
from miseval import *

#-----------------------------------------------------#
# Unittest: Entropy-based metrics #
#-----------------------------------------------------#
class TEST_Entropy(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        # Create ground truth
        np.random.seed(1)
        self.gt_bi = np.random.randint(2, size=(32,32))
        self.gt_mc = np.random.randint(5, size=(32,32))
        # Create prediction mask
        np.random.seed(2)
        self.pd_bi = np.random.randint(2, size=(32,32))
        self.pd_mc = np.random.randint(5, size=(32,32))
        # Create prediction probability
        self.prob_bi = np.random.rand(32,32,2)
        self.prob_mc = np.random.rand(32,32,5)

    #-------------------------------------------------#
    #           Calculate : Cross-Entropy             #
    #-------------------------------------------------#
    def test_calc_CrossEntropy(self):
        # Check binary score
        score_bi = calc_CrossEntropy(self.gt_bi, self.prob_bi, c=1,
                                     provided_prob=True)
        self.assertTrue(isinstance(score_bi, np.float64))
        # Check multi-class score
        for i in range(5):
            score_mc = calc_CrossEntropy(self.gt_mc, self.prob_mc, c=i,
                                         provided_prob=True)
            self.assertTrue(isinstance(score_mc, np.float64))
        # Check existence in metric_dict
        self.assertTrue("CE" in metric_dict)
        self.assertTrue(callable(metric_dict["CE"]))
        self.assertTrue("CrossEntropy" in metric_dict)
        self.assertTrue(callable(metric_dict["CrossEntropy"]))
