forked from tensorflow/probability
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rewrite nptf.convert_to_tensor to handle more cases and use a registry.
Main changes: - Add a registry system for tensor conversion. This enables converting objects like TensorShapes and Dimensions without making the convert_to_tensor code bloated and confusing. - Redo the dtype logic, closely following Tensorflow's dtype logic found here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/lib/core/py_seq_tensor.cc The logic mostly lives in a _default_convert_to_tensor function that handles most Python types, like bool, int, float, complex, and list, tuple. - Redo the tests to be parameterized and add additional tests that cover the various cases. - Register TensorShape and Dimension with convert_to_tensor An important consideration is how to handle convert_to_tensor when a "tensor" is passed in itself. In TF land, when we convert a tf.Tensor to a Tensor, the dtype must remain compatible. This restriction is extended to JAX, where if we try converting a DeviceArray to tensor, the dtype must remain compatible. On the other hand, for the NumPy backend, TFP code relies on more flexible dtype conversion for NumPy code, so we relax that restriction. Minor changes: - Copy dtype conversion logic from tf.range to nptf.range. - Adjust rewrite system to keep `import numpy as np` as NumPy in rewritten TFP code. This means that logically in both TF and JAX, np.ndarrays are the same, which is necessary for the dtype conversion logic to be consistent across backends. - Properly import Dimension for nptf.compat.v1. PiperOrigin-RevId: 312175371
- Loading branch information
1 parent
640f6e8
commit 1b3e111
Showing
5 changed files
with
472 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.