Skip to content

Commit

Permalink
Enable dtype_util.size(...) to work in the numpy / jax backend (allow…
Browse files Browse the repository at this point in the history
…ing Categorical with float64 parameters).

PiperOrigin-RevId: 275943193
  • Loading branch information
srvasude authored and tensorflower-gardener committed Oct 21, 2019
1 parent 748e66b commit fe80271
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tensorflow_probability/python/internal/dtype_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def name(dtype):
def size(dtype):
"""Returns the number of bytes to represent this `dtype`."""
dtype = tf.as_dtype(dtype)
if hasattr(dtype, 'size'):
if hasattr(dtype, 'size') and hasattr(dtype, 'as_numpy_dtype'):
return dtype.size
return np.dtype(dtype).itemsize

Expand Down
12 changes: 11 additions & 1 deletion tensorflow_probability/python/internal/dtype_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,17 @@ def test_assert_same_float_dtype(self):
self.assertRaises(ValueError, dtype_util.assert_same_float_dtype,
[const_int])

def test_size(self):
self.assertEqual(dtype_util.size(tf.int32), 4)
self.assertEqual(dtype_util.size(tf.int64), 8)
self.assertEqual(dtype_util.size(tf.float32), 4)
self.assertEqual(dtype_util.size(tf.float64), 8)

self.assertEqual(dtype_util.size(np.int32), 4)
self.assertEqual(dtype_util.size(np.int64), 8)
self.assertEqual(dtype_util.size(np.float32), 4)
self.assertEqual(dtype_util.size(np.float64), 8)


if __name__ == '__main__':
tf.test.main()

0 comments on commit fe80271

Please sign in to comment.