Skip to content

Commit

Permalink
Tweak AUC test to avoid sharing state across calls in np.apply_along_…
Browse files Browse the repository at this point in the history
…axis.

This was causing issues when porting over to the new test decorators. The root
cause is unclear. It's possible that m.reset_states() isn't sufficiently
thorough. In any case, avoiding sharing the state altogether solves the issue.
Since np.apply_along_axis doesn't seem to gain anything by reusing the
`tf.keras.metrics.AUC` instance across calls, it seems safest to avoid reusing it.

PiperOrigin-RevId: 274245423
  • Loading branch information
csuter authored and tensorflower-gardener committed Oct 11, 2019
1 parent 0d80325 commit 2ee0219
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tensorflow_probability/python/stats/ranking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,11 @@ def testAurocAuprc(self, curve, positive_means, negative_means):
y_true = np.array([1] * num_positive_trials + [0] * num_negative_trials)
y_pred = np.concatenate([positive_trials_, negative_trials_])

m = tf.keras.metrics.AUC(num_thresholds=total_trials, curve=curve)
self.evaluate([v.initializer for v in m.variables])

def auc_fn(y_pred):
m = tf.keras.metrics.AUC(num_thresholds=total_trials, curve=curve)
self.evaluate([v.initializer for v in m.variables])
self.evaluate(m.update_state(y_true, y_pred))
out = self.evaluate(m.result())
m.reset_states()
return out

batch_shape = np.array(positive_means).shape
Expand Down

0 comments on commit 2ee0219

Please sign in to comment.