Skip to content

Commit

Permalink
Rewrite nptf.convert_to_tensor to handle more cases and use a registry.
Browse files Browse the repository at this point in the history
Main changes:
- Add a registry system for tensor conversion. This enables converting objects
  like TensorShapes and Dimensions without making the convert_to_tensor code
  bloated and confusing.
- Redo the dtype logic, closely following Tensorflow's dtype logic found here:
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/lib/core/py_seq_tensor.cc
  The logic mostly lives in a _default_convert_to_tensor function that handles
  most Python types, like bool, int, float, complex, and list, tuple.
- Redo the tests to be parameterized and add additional tests that cover the
  various cases.
- Register TensorShape and Dimension with convert_to_tensor

An important consideration is how to handle convert_to_tensor when a "tensor" is
passed in itself. In TF land, when we convert a tf.Tensor to a Tensor, the
dtype must remain compatible. This restriction is extended to JAX, where if we
try converting a DeviceArray to tensor, the dtype must remain compatible.
On the other hand, for the NumPy backend, TFP code relies on more flexible
dtype conversion for NumPy code, so we relax that restriction.

Minor changes:
- Copy dtype conversion logic from tf.range to nptf.range.
- Adjust rewrite system to keep `import numpy as np` as NumPy in rewritten TFP
  code. This means that logically in both TF and JAX, np.ndarrays are the same,
  which is necessary for the dtype conversion logic to be consistent across
  backends.
- Properly import Dimension for nptf.compat.v1.

PiperOrigin-RevId: 312175371
  • Loading branch information
sharadmv authored and tensorflower-gardener committed May 18, 2020
1 parent 640f6e8 commit 1b3e111
Show file tree
Hide file tree
Showing 5 changed files with 472 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,19 +225,6 @@ def main(argv):
contents = contents.replace('SKIP_DTYPE_CHECKS = True',
'SKIP_DTYPE_CHECKS = False')
is_test = lambda x: x.endswith('_test.py') or x.endswith('_test_util.py')
if not is_test(argv[1]): # We leave tests with original np.
contents = contents.replace(
'\nimport numpy as np',
'\nimport numpy as onp\nimport jax.numpy as np')
contents = contents.replace('np.bool', 'onp.bool')
contents = contents.replace('np.dtype', 'onp.dtype')
contents = contents.replace('np.euler_gamma', 'onp.euler_gamma')
contents = contents.replace('np.generic', 'onp.generic')
contents = contents.replace('np.nextafter', 'onp.nextafter')
contents = contents.replace('np.object', 'onp.object')
contents = contents.replace('np.unique', 'onp.unique')

contents = contents.replace('np.polynomial', 'onp.polynomial')
if is_test(argv[1]): # Test-only rewrites.
contents = contents.replace(
'tf.test.main()',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,20 @@ def _pad( # pylint: disable=unused-argument


def _range(start, limit=None, delta=1, dtype=None, name='range'): # pylint: disable=unused-argument
dtype = utils.numpy_dtype(dtype or utils.common_dtype([start], np.int32))
"""Emulates tf.range."""
# Emulating dtype inference logic from tf.range
dtype = utils.numpy_dtype(dtype)
start = ops.convert_to_tensor(start, dtype=dtype)
limit = None if limit is None else ops.convert_to_tensor(limit, dtype=dtype)
delta = ops.convert_to_tensor(delta, dtype=dtype)
return np.arange(start, limit, delta).astype(dtype)
if dtype is None:
dtype_hierarchy = [np.int32, np.int64, np.float32, np.float64]
inferred_dtype = max([arg.dtype for arg in [start, limit, delta]
if arg is not None],
key=dtype_hierarchy.index)
else:
inferred_dtype = dtype
return np.arange(start, limit, delta).astype(inferred_dtype)


def _reverse(tensor, axis, name=None): # pylint: disable=unused-argument
Expand Down
310 changes: 239 additions & 71 deletions tensorflow_probability/python/internal/backend/numpy/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,83 +991,251 @@ def _maybe_convert_to_tensors(args):
args)


CONVERT_TO_TENSOR_TESTS = [
# bool tests
dict(testcase_name='bool',
value=True, out_dtype=nptf.bool),
dict(testcase_name='bool_with_int32_dtype',
value=True, out_dtype=nptf.int32, dtype=nptf.int32),
dict(testcase_name='bool_with_int64_dtype',
value=True, out_dtype=nptf.int64, dtype=nptf.int64),
dict(testcase_name='bool_with_float32_dtype',
value=True, out_dtype=nptf.float32, dtype=nptf.float32),
dict(testcase_name='bool_with_float64_dtype',
value=True, out_dtype=nptf.float64, dtype=nptf.float64),
dict(testcase_name='bool_with_complex64_dtype_should_error',
value=True, dtype=nptf.complex64, error=TypeError),
dict(testcase_name='bool_with_complex64_hint',
value=True, out_dtype=nptf.bool, dtype_hint=nptf.complex64),
# int tests
dict(testcase_name='int',
value=1, out_dtype=nptf.int32),
dict(testcase_name='int_with_float32_dtype',
value=1, out_dtype=nptf.float32, dtype=nptf.float32),
# int can be cast into other types
dict(testcase_name='int_with_float32_hint',
value=1, out_dtype=nptf.float32, dtype_hint=nptf.float32),
dict(testcase_name='int64',
value=2 ** 63 - 1, out_dtype=nptf.int64),
dict(testcase_name='int64_to_int32_should_underflow',
value=2 ** 63 - 1, dtype=np.int32, out_dtype=nptf.int32, out_value=-1),
dict(testcase_name='int_with_complex64_dtype',
value=1, out_dtype=nptf.complex64, dtype=nptf.complex64),
dict(testcase_name='int_with_complex64_hint',
value=1, out_dtype=nptf.complex64, dtype_hint=nptf.complex64),
# float tests
dict(testcase_name='float',
value=1., out_dtype=nptf.float32),
dict(testcase_name='float_with_float64_dtype',
value=1., out_dtype=nptf.float64, dtype=nptf.float64),
# float can be cast into complex types but not int types
dict(testcase_name='float_with_complex64_dtype',
value=1., out_dtype=nptf.complex64, dtype=nptf.complex64),
dict(testcase_name='float_with_complex64_dtype_hint',
value=1., out_dtype=nptf.complex64, dtype_hint=nptf.complex64),
dict(testcase_name='float_with_complex128_dtype',
value=1., out_dtype=nptf.complex128, dtype=nptf.complex128),
dict(testcase_name='float_to_bool_dtype_should_error',
value=1., dtype=nptf.bool, error=TypeError),
dict(testcase_name='float_to_int32_dtype_should_error',
value=1., dtype=nptf.int32, error=TypeError),
dict(testcase_name='float_to_int32_dtype_hint',
value=1., out_dtype=nptf.float32, dtype_hint=nptf.int32),
dict(testcase_name='float_to_int64_dtype_should_error',
value=1., dtype=nptf.int32, error=TypeError),
dict(testcase_name='float_with_int32_hint',
value=1., out_dtype=nptf.float32, dtype_hint=nptf.int32),
# complex can be cast into complex types but not other types
dict(testcase_name='complex',
value=1 + 0j, out_dtype=nptf.complex128),
dict(testcase_name='complex_with_complex64_dtype',
value=1 + 0j, out_dtype=nptf.complex64, dtype=nptf.complex64),
dict(testcase_name='complex_with_bool_dtype_should_error',
value=1 + 0j, dtype=nptf.bool, error=TypeError),
dict(testcase_name='complex_with_bool_hint_should_error',
value=1 + 0j, out_dtype=nptf.complex128, dtype_hint=nptf.bool),
dict(testcase_name='complex_with_float32_dtype_should_error',
value=1 + 0j, dtype=nptf.float32, error=TypeError),
dict(testcase_name='complex_with_float32',
value=1 + 0j, out_dtype=nptf.complex128, dtype_hint=nptf.float32),
dict(testcase_name='complex_with_int32_dtype_should_error',
value=1 + 0j, dtype=nptf.int32, error=TypeError),
dict(testcase_name='complex_with_int32_hint',
value=1 + 0j, out_dtype=nptf.complex128, dtype_hint=nptf.int32),
# Empty iterables should be float32 by default
dict(testcase_name='empty_list',
value=[], out_dtype=nptf.float32),
dict(testcase_name='empty_list_with_float64_dtype',
value=[], out_dtype=nptf.float64, dtype=nptf.float64),
dict(testcase_name='empty_list_with_int32_hint',
value=[], out_dtype=nptf.int32, dtype_hint=nptf.int32),
dict(testcase_name='empty_tuple',
value=(), out_dtype=nptf.float32),
dict(testcase_name='empty_tuple_with_float64_dtype',
value=(), out_dtype=nptf.float64, dtype=nptf.float64),
dict(testcase_name='empty_tuple_with_int32_hint',
value=(), out_dtype=nptf.int32, dtype_hint=nptf.int32),
# Iterables with contents should use dtypes of contents
dict(testcase_name='list_of_ints',
value=[1], out_dtype=nptf.int32),
dict(testcase_name='nested_list_of_ints',
value=[[1]], out_dtype=nptf.int32),
dict(testcase_name='nested_list_of_bools',
value=[[True]], out_dtype=nptf.bool),
dict(testcase_name='nested_list_of_floats',
value=[[1.]], out_dtype=nptf.float32),
dict(testcase_name='list_of_ints_with_int32_dtype',
value=[1], out_dtype=nptf.int32, dtype=nptf.int32),
dict(testcase_name='list_of_ints_with_int32_hint',
value=[1], out_dtype=nptf.int32, dtype_hint=nptf.int32),
dict(testcase_name='list_of_ints_with_float32_dtype',
value=[1], out_dtype=nptf.float32, dtype=nptf.float32),
dict(testcase_name='list_of_ints_with_float32_hint',
value=[1], out_dtype=nptf.float32, dtype_hint=nptf.float32),
dict(testcase_name='list_of_ints_with_complex128_dtype',
value=[1], out_dtype=nptf.complex128, dtype=nptf.complex128),
dict(testcase_name='list_of_ints_with_complex128_hint',
value=[1], out_dtype=nptf.complex128, dtype_hint=nptf.complex128),
dict(testcase_name='list_of_floats',
value=[1.], out_dtype=nptf.float32),
dict(testcase_name='list_of_floats_with_int32_dtype_should_error',
value=[1.], dtype=nptf.int32, error=TypeError),
dict(testcase_name='list_of_floats_with_int32_hint',
value=[1.], out_dtype=nptf.float32, dtype_hint=nptf.int32),
dict(testcase_name='list_of_int_bool',
value=[1, True], out_dtype=nptf.int32),
dict(testcase_name='list_of_bool_int_should_error',
value=[True, 1], error=ValueError),
dict(testcase_name='list_of_int_bool_with_int32_dtype',
value=[1, True], dtype=nptf.int32, out_dtype=nptf.int32),
dict(testcase_name='list_of_int_bool_with_bool_dtype_should_error',
value=[1, True], dtype=nptf.bool, error=TypeError),
dict(testcase_name='list_of_int_float',
value=[1, 2.], out_dtype=nptf.float32),
dict(testcase_name='list_of_int_float_with_int32_dtype_should_error',
value=[1, 2.], dtype=nptf.int32, error=TypeError),
dict(testcase_name='list_of_int_float_with_int32_hint',
value=[1, 2.], out_dtype=nptf.float32, dtype_hint=nptf.int32),
dict(testcase_name='list_of_float_int_with_int32_dtype_should_error',
value=[1., 2], dtype=nptf.int32, error=TypeError),
dict(testcase_name='list_of_float_int_with_int32_hint',
value=[1., 2], out_dtype=nptf.float32, dtype_hint=nptf.int32),
# List of complex is more strict than list float and int
dict(testcase_name='list_of_complex_and_bool_should_error',
value=[1 + 2j, True], error=ValueError),
dict(testcase_name='list_of_bool_and_complex_should_error',
value=[True, 1 + 2j], error=ValueError),
dict(testcase_name='list_of_complex_and_float_should_error',
value=[1 + 2j, 1.], error=ValueError),
dict(testcase_name='list_of_float_and_complex_should_error',
value=[1., 1 + 2j], error=ValueError),
dict(testcase_name='list_of_complex_and_int_should_error',
value=[1 + 2j, 1], error=ValueError),
dict(testcase_name='list_of_int_and_complex_should_error',
value=[1, 1 + 2j], error=ValueError),
# Convert tensors to tensors
dict(testcase_name='int32_tensor',
value=1, in_dtype=nptf.int32, out_dtype=nptf.int32),
dict(testcase_name='int32_tensor_with_int32_dtype',
value=1, in_dtype=nptf.int32, dtype=nptf.int32, out_dtype=nptf.int32),
dict(testcase_name='int32_tensor_with_int64_hint',
value=1, in_dtype=nptf.int32, dtype_hint=nptf.int32,
out_dtype=nptf.int32),
dict(testcase_name='int32_tensor_with_float64_hint',
value=1, in_dtype=nptf.int32, dtype_hint=nptf.int32,
out_dtype=nptf.int32),
# Convert registered objects
dict(testcase_name='dimension',
value=nptf.compat.v1.Dimension(1), out_dtype=nptf.int32),
dict(testcase_name='dimension_with_int64_dtype',
value=nptf.compat.v1.Dimension(1), dtype=nptf.int64,
out_dtype=nptf.int64),
dict(testcase_name='dimension_with_float32_dtype_should_error',
value=nptf.compat.v1.Dimension(1), dtype=nptf.float32,
error=TypeError),
dict(testcase_name='dimension_with_float32_hint',
value=nptf.compat.v1.Dimension(1), dtype_hint=nptf.float32,
out_dtype=nptf.int32),
dict(testcase_name='empty_tensorshape',
value=nptf.TensorShape([]), out_dtype=nptf.int32),
dict(testcase_name='empty_tensorshape_with_float32_dtype_should_error',
value=nptf.TensorShape([]), dtype=nptf.float32, error=TypeError),
dict(testcase_name='tensorshape',
value=nptf.TensorShape((1, 2)), out_dtype=nptf.int32),
dict(testcase_name='tensorshape_with_float32_dtype_should_error',
value=nptf.TensorShape((1, 2)), dtype=nptf.float32, error=TypeError),
dict(testcase_name='tensorshape_with_large_dimension_should_be_int64',
value=nptf.TensorShape([2 ** 31]), out_dtype=nptf.int64),
dict(testcase_name=('tensorshape_with_large_dimension_with_int32'
'_dtype_should_error'),
value=nptf.TensorShape([2 ** 31]), dtype=nptf.int32, error=ValueError)
]

if JAX_MODE:
CONVERT_TO_TENSOR_TESTS += [
# Tests for converting onp arrays to tensors
dict(testcase_name='float32',
value=onp.float32(1.), out_dtype=nptf.float32),
dict(testcase_name='float32_with_int32_dtype',
value=onp.float32(1.), dtype=nptf.int32, out_dtype=nptf.int32),
dict(testcase_name='float32_with_int32_hint',
value=onp.float64(1.), dtype_hint=nptf.int32, out_dtype=nptf.int32),
dict(testcase_name='empty_ndarray',
value=onp.array([]), out_dtype=nptf.float64),
dict(testcase_name='empty_float32_ndarray',
value=onp.array([], dtype=onp.float32), out_dtype=nptf.float32),
dict(testcase_name='empty_float64_ndarray_with_int32_dtype',
value=onp.array([], dtype=onp.float64), out_dtype=nptf.float32,
dtype=nptf.float32),
# NumPy arrays get cast
dict(testcase_name='float64_ndarray_to_int32',
value=onp.array([1], dtype=onp.float64), out_dtype=nptf.int32,
dtype=nptf.int32),
dict(testcase_name='complex64_ndarray_to_int32',
value=onp.array([1], dtype=onp.complex64), out_dtype=nptf.int32,
dtype=nptf.int32),
dict(testcase_name='complex128_ndarray_to_float32',
value=onp.array([1], dtype=onp.complex128), out_dtype=nptf.float32,
dtype=nptf.float32),
# JAX will error when trying to change dtypes of tensors
dict(testcase_name='int32_tensor_with_int64_dtype_should_error',
value=1, in_dtype=nptf.int32, dtype=nptf.int64, error=TypeError),
dict(testcase_name='int32_tensor_with_float64_dtype_should_error',
value=1, in_dtype=nptf.int32, dtype=nptf.float64, error=TypeError),
]
else:
CONVERT_TO_TENSOR_TESTS += [
# NumPy should not error when trying to change dtypes of tensors
dict(testcase_name='int32_tensor_with_int64_dtype_should_not_error',
value=1, in_dtype=nptf.int32, dtype=nptf.int64,
out_dtype=nptf.int64),
dict(testcase_name='int32_tensor_with_float64_dtype_should_not_error',
value=1, in_dtype=nptf.int32, dtype=nptf.float64,
out_dtype=nptf.float64),
]


class NumpyTest(test_util.TestCase):

def _base_test_convert_to_tensor(self, nmpy):
convert_to_tensor = nptf.convert_to_tensor
self.assertEqual(
nmpy.complex64,
convert_to_tensor(nmpy.complex64(1 + 2j), dtype_hint=tf.int32).dtype)
self.assertEqual(
nmpy.complex64,
convert_to_tensor(nmpy.complex64(1 + 2j), dtype_hint=tf.float64).dtype)
self.assertEqual(nmpy.float64,
convert_to_tensor(1., dtype_hint=tf.int32).dtype)
self.assertEqual(
nmpy.int32, convert_to_tensor(1, dtype_hint=tf.int32).dtype)
self.assertEqual(nmpy.float32,
convert_to_tensor(1, dtype_hint=tf.float32).dtype)
self.assertEqual(nmpy.complex64,
convert_to_tensor(1., dtype_hint=tf.complex64).dtype)
self.assertEqual(
nmpy.int64, convert_to_tensor(1, dtype_hint=tf.int64).dtype)
self.assertEqual(
nmpy.int32,
convert_to_tensor(nmpy.int32(False), dtype_hint=tf.bool).dtype)

def test_convert_to_tensor(self):
self._base_test_convert_to_tensor(np)

def test_convert_to_tensor_numpy_array(self):
if not JAX_MODE:
self.skipTest('Check non-device arrays in JAX.')
self._base_test_convert_to_tensor(onp)

def test_convert_to_tensor_scalar_default(self):
convert_to_tensor = nptf.convert_to_tensor
self.assertEqual(np.complex128, convert_to_tensor(1. + 2j).dtype)
self.assertEqual(np.float32, convert_to_tensor(1.).dtype)
self.assertEqual(np.int32, convert_to_tensor(1).dtype)

def test_convert_to_tensor_dimension(self):
convert_to_tensor = nptf.convert_to_tensor
shape = tf1.Dimension(1)

tensor_shape = convert_to_tensor(shape)
self.assertNotIsInstance(tensor_shape, tf1.Dimension)

def test_convert_to_tensor_tensorshape(self):
convert_to_tensor = nptf.convert_to_tensor
shape = tf.TensorShape((1, 2))

tensor_shape = convert_to_tensor(shape)
for dim in tensor_shape:
self.assertNotIsInstance(dim, tf1.Dimension)

shape = tf.TensorShape((1, 2, 3))[:2]
tensor_shape = convert_to_tensor(shape)

for dim in tensor_shape:
self.assertNotIsInstance(dim, tf1.Dimension)
@parameterized.named_parameters(CONVERT_TO_TENSOR_TESTS)
def test_convert_to_tensor(self, value=None, out_value=None, out_dtype=None,
in_dtype=None, dtype=None, dtype_hint=None,
error=None):
if in_dtype:
value = nptf.convert_to_tensor(value, dtype=in_dtype)
if not error:
out = nptf.convert_to_tensor(value, dtype=dtype, dtype_hint=dtype_hint)
if out_dtype:
self.assertEqual(out_dtype, out.dtype)
if out_value is not None:
self.assertEqual(out_value, out)
else:
with self.assertRaises(error):
nptf.convert_to_tensor(value, dtype=dtype, dtype_hint=dtype_hint)

def test_concat_infers_dtype(self):
self.assertEqual(np.int32, nptf.concat([[1], []], 0).dtype)
self.assertEqual(np.float32, nptf.concat([[], [1]], 0).dtype)
self.assertEqual(np.float32, nptf.concat([np.array([1], np.float32),
np.array([1], np.float64)],
0).dtype)
self.assertEqual(np.float64, nptf.concat([np.array([1], np.float64),
np.array([1], np.float32)],
0).dtype)
self.assertEqual(np.float32, nptf.concat([[np.float32(1)], [np.float64(1)]],
0).dtype)
self.assertEqual(np.float32, nptf.concat([[np.float64(1)], [np.float32(1)]],
0).dtype)
# TODO(sharadmv): rewrite these tests when convert_to_tensor is fixed
self.assertEqual(np.int32, nptf.concat([[np.int32(1)], [np.int64(1)]],
0).dtype)
self.assertEqual(np.int32, nptf.concat([[np.int64(1)], [np.int32(1)]],
0).dtype)

@test_util.numpy_disable_gradient_test
def test_while_loop_gradients(self):
Expand Down
Loading

0 comments on commit 1b3e111

Please sign in to comment.