diff --git a/fd_shifts/analysis/metrics.py b/fd_shifts/analysis/metrics.py
index 8880fde..0641c21 100644
--- a/fd_shifts/analysis/metrics.py
+++ b/fd_shifts/analysis/metrics.py
@@ -56,7 +56,6 @@ def __init__(
         correct,
         n_bins,
         labels=None,
-        prevalence_ratios=None,
         legacy=False,
     ) -> None:
         super().__init__()
@@ -64,7 +63,6 @@ def __init__(
         self.correct: npt.NDArray[Any] = correct
         self.n_bins: int = n_bins
         self.labels = labels
-        self.prevalence_ratios = prevalence_ratios
         self.legacy = legacy
 
     @cached_property
diff --git a/fd_shifts/tests/test_analysis.py b/fd_shifts/tests/test_analysis.py
index 4c10f5c..8b23989 100644
--- a/fd_shifts/tests/test_analysis.py
+++ b/fd_shifts/tests/test_analysis.py
@@ -719,10 +719,6 @@ def test_class_aware_metric_values(stats_cache: SC_test, expected: dict):
     # Now, compare to explicit result values
     np.testing.assert_almost_equal(stats_cache.aurc_ba, expected["aurc_ba"])
    np.testing.assert_almost_equal(stats_cache.augrc_ba, expected["augrc_ba"])
-    np.testing.assert_almost_equal(
-        stats_cache.get_working_point(risk="generalized-risk-ba", target_cov=0.95),
-        expected["generalized-risk-ba@95cov"],
-    )
 
 
 def test_achievable_rc():
diff --git a/fd_shifts/tests/utils.py b/fd_shifts/tests/utils.py
index 4ce138e..2143148 100644
--- a/fd_shifts/tests/utils.py
+++ b/fd_shifts/tests/utils.py
@@ -21,8 +21,8 @@ class SC_test(metrics.StatsCache):
 
     AUC_DISPLAY_SCALE = 1
 
-    def __init__(self, confids, correct):
-        super().__init__(confids, correct, n_bins=20, legacy=False)
+    def __init__(self, confids, correct, **kwargs):
+        super().__init__(confids, correct, n_bins=20, legacy=False, **kwargs)
 
 
 class SC_scale1000_test(metrics.StatsCache):
@@ -30,8 +30,8 @@ class SC_scale1000_test(metrics.StatsCache):
 
     AUC_DISPLAY_SCALE = 1000
 
-    def __init__(self, confids, correct):
-        super().__init__(confids, correct, n_bins=20, legacy=False)
+    def __init__(self, confids, correct, **kwargs):
+        super().__init__(confids, correct, n_bins=20, legacy=False, **kwargs)
 
 
 N_SAMPLES = 100
@@ -432,4 +432,14 @@ def __init__(self, confids, correct):
 
 
 # Testing metrics that explicitly depend on GT labels
-RC_STATS_CLASS_AWARE_TEST_CASES = {}
+RC_STATS_CLASS_AWARE_TEST_CASES = {
+    SC_test(
+        confids=np.array([0, 1, 2, 3]),
+        correct=np.array([1, 0, 1, 0]),
+        labels=np.array([0, 0, 1, 1]),
+    ): {
+        "ID": "toy-case",
+        "aurc_ba": 0.5,
+        "augrc_ba": 0.3125,
+    }
+}