Skip to content

Commit

Permalink
Add BA toy test
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiastraub committed Jun 10, 2024
1 parent afa1fdf commit 56b11fc
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
2 changes: 0 additions & 2 deletions fd_shifts/analysis/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,13 @@ def __init__(
correct,
n_bins,
labels=None,
prevalence_ratios=None,
legacy=False,
) -> None:
super().__init__()
self.confids: npt.NDArray[Any] = confids
self.correct: npt.NDArray[Any] = correct
self.n_bins: int = n_bins
self.labels = labels
self.prevalence_ratios = prevalence_ratios
self.legacy = legacy

@cached_property
Expand Down
4 changes: 0 additions & 4 deletions fd_shifts/tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,6 @@ def test_class_aware_metric_values(stats_cache: SC_test, expected: dict):
# Now, compare to explicit result values
np.testing.assert_almost_equal(stats_cache.aurc_ba, expected["aurc_ba"])
np.testing.assert_almost_equal(stats_cache.augrc_ba, expected["augrc_ba"])
np.testing.assert_almost_equal(
stats_cache.get_working_point(risk="generalized-risk-ba", target_cov=0.95),
expected["generalized-risk-ba@95cov"],
)


def test_achievable_rc():
Expand Down
20 changes: 15 additions & 5 deletions fd_shifts/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ class SC_test(metrics.StatsCache):

AUC_DISPLAY_SCALE = 1

def __init__(self, confids, correct):
super().__init__(confids, correct, n_bins=20, legacy=False)
def __init__(self, confids, correct, **kwargs):
super().__init__(confids, correct, n_bins=20, legacy=False, **kwargs)


class SC_scale1000_test(metrics.StatsCache):
"""Using AURC_DISPLAY_SCALE=1000 and n_bins=20."""

AUC_DISPLAY_SCALE = 1000

def __init__(self, confids, correct):
super().__init__(confids, correct, n_bins=20, legacy=False)
def __init__(self, confids, correct, **kwargs):
super().__init__(confids, correct, n_bins=20, legacy=False, **kwargs)


N_SAMPLES = 100
Expand Down Expand Up @@ -432,4 +432,14 @@ def __init__(self, confids, correct):


# Testing metrics that explicitly depend on GT labels
RC_STATS_CLASS_AWARE_TEST_CASES = {}
RC_STATS_CLASS_AWARE_TEST_CASES = {
SC_test(
confids=np.array([0, 1, 2, 3]),
correct=np.array([1, 0, 1, 0]),
labels=np.array([0, 0, 1, 1]),
): {
"ID": "toy-case",
"aurc_ba": 0.5,
"augrc_ba": 0.3125,
}
}

0 comments on commit 56b11fc

Please sign in to comment.