Skip to content

Commit

Permalink
fix AUC test (allenai#4795)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh authored Nov 17, 2020
1 parent efde092 commit 3cad5b4
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/training/metrics/auc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_distributed_auc(self):
labels = torch.randint(3, 5, (8,), dtype=torch.long)
# We make sure that the positive label is always present.
labels[0] = 4
labels[4] = 4

false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
labels.cpu().numpy(), predictions.cpu().numpy(), pos_label=4
Expand All @@ -124,6 +125,7 @@ def test_distributed_auc_unequal_batches(self):
labels = torch.randint(3, 5, (8,), dtype=torch.long)
# We make sure that the positive label is always present.
labels[0] = 4
labels[4] = 4

false_positive_rates, true_positive_rates, _ = metrics.roc_curve(
labels.cpu().numpy(), predictions.cpu().numpy(), pos_label=4
Expand Down

0 comments on commit 3cad5b4

Please sign in to comment.