Skip to content

Commit

Permalink
fix import and method
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 314410140
  • Loading branch information
andrewluchen authored and copybara-github committed Jun 2, 2020
1 parent e5a356b commit 02892b0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cnn_quantization/tf_cnn_benchmarks/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ class ImagenetPreprocessor(RecordInputImagePreprocessor):
def preprocess(self, image_buffer, bbox, batch_position):
# pylint: disable=g-import-not-at-top
try:
from tensorflow_models.official.resnet.imagenet_preprocessing import preprocess_image
from tensorflow_models.official.r1.resnet.imagenet_preprocessing import preprocess_image
except ImportError:
tf.logging.fatal('Please include tensorflow/models to the PYTHONPATH.')
raise
Expand Down
14 changes: 8 additions & 6 deletions soft_sort/jax/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,30 @@ def test_sort(self):
s = ops.softsort(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
self.assertEqual(s.shape, self.x.shape)
deltas = np.diff(s, axis=-1) > 0
self.assertAllClose(deltas, np.ones(deltas.shape, dtype=bool), True)
self.assertAllClose(
deltas, np.ones(deltas.shape, dtype=bool), check_dtypes=True)

def test_sort_descending(self):
x = self.x[0][0]
s = ops.softsort(x, axis=-1, direction='DESCENDING',
threshold=1e-3, epsilon=1e-3)
self.assertEqual(s.shape, x.shape)
deltas = np.diff(s, axis=-1) < 0
self.assertAllClose(deltas, np.ones(deltas.shape, dtype=bool), True)
self.assertAllClose(
deltas, np.ones(deltas.shape, dtype=bool), check_dtypes=True)

def test_ranks(self):
ranks = ops.softranks(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
self.assertEqual(ranks.shape, self.x.shape)
true_ranks = np.argsort(np.argsort(self.x, axis=-1), axis=-1)
self.assertAllClose(ranks, true_ranks, False, atol=1e-3)
self.assertAllClose(ranks, true_ranks, check_dtypes=False, atol=1e-3)

def test_ranks_one_based(self):
ranks = ops.softranks(self.x, axis=-1, zero_based=False,
threshold=1e-3, epsilon=1e-3)
self.assertEqual(ranks.shape, self.x.shape)
true_ranks = np.argsort(np.argsort(self.x, axis=-1), axis=-1) + 1
self.assertAllClose(ranks, true_ranks, False, atol=1e-3)
self.assertAllClose(ranks, true_ranks, check_dtypes=False, atol=1e-3)

def test_ranks_descending(self):
ranks = ops.softranks(
Expand All @@ -75,7 +77,7 @@ def test_ranks_descending(self):

max_rank = self.x.shape[-1] - 1
true_ranks = max_rank - np.argsort(np.argsort(self.x, axis=-1), axis=-1)
self.assertAllClose(ranks, true_ranks, False, atol=1e-3)
self.assertAllClose(ranks, true_ranks, check_dtypes=False, atol=1e-3)

@parameterized.named_parameters(
('medians_-1', 0.5, -1),
Expand All @@ -95,7 +97,7 @@ def test_softquantile(self, quantile, axis):
s.pop(axis)
self.assertTupleEqual(qs.shape, tuple(s))
self.assertAllClose(
qs, np.quantile(x, quantile, axis=axis), True, rtol=1e-2)
qs, np.quantile(x, quantile, axis=axis), check_dtypes=True, rtol=1e-2)


if __name__ == '__main__':
Expand Down
5 changes: 3 additions & 2 deletions soft_sort/jax/soft_quantilizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def setUp(self):
def test_sort(self):
q = soft_quantilizer.SoftQuantilizer(self.x, threshold=1e-3, epsilon=1e-3)
deltas = np.diff(q.softsort, axis=-1) > 0
self.assertAllClose(deltas, np.ones(deltas.shape, dtype=bool), True)
self.assertAllClose(
deltas, np.ones(deltas.shape, dtype=bool), check_dtypes=True)

def test_target_weights(self):
q = soft_quantilizer.SoftQuantilizer(
Expand All @@ -57,7 +58,7 @@ def test_ranks(self):
q = soft_quantilizer.SoftQuantilizer(self.x, threshold=1e-3, epsilon=1e-3)
soft_ranks = q._n * q.softcdf
true_ranks = np.argsort(np.argsort(q.x, axis=-1), axis=-1) + 1
self.assertAllClose(soft_ranks, true_ranks, False, atol=1e-3)
self.assertAllClose(soft_ranks, true_ranks, check_dtypes=False, atol=1e-3)


if __name__ == '__main__':
Expand Down

0 comments on commit 02892b0

Please sign in to comment.