Skip to content

Commit

Permalink
Robustify AutoCompositeTensor; support deprecated things and obj.para…
Browse files Browse the repository at this point in the history
…meters

- callables decorated with TF's deprecation symbol get nested inside callables
  with *args and **kwargs, which broke our use of (std lib) `inspect`. Use
  `tf_inspect` instead, which correctly recurses into wrapped callables.
- some Distributions have kwargs that aren't attrs, but do appear in
  self.parameters; add this to the list of places we look for values during
  _to_components.

PiperOrigin-RevId: 350844814
  • Loading branch information
csuter authored and tensorflower-gardener committed Jan 8, 2021
1 parent 863271e commit 7b80f1c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tensorflow_probability/python/internal/auto_composite_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from __future__ import print_function

import functools
import inspect

import tensorflow.compat.v2 as tf

from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.saved_model import nested_structure_coder # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util import tf_inspect # pylint: disable=g-direct-tensorflow-import

__all__ = [
'auto_composite_tensor',
Expand All @@ -42,7 +42,7 @@
def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None,
prefer_static_value=()):
"""Extract constructor kwargs to reconstruct `obj`."""
argspec = inspect.getfullargspec(obj.__init__)
argspec = tf_inspect.getfullargspec(obj.__init__)
if argspec.varargs or argspec.varkw:
raise ValueError(
'*args and **kwargs are not supported. Found `{}`'.format(argspec))
Expand All @@ -58,10 +58,15 @@ def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None,
kwargs[k] = getattr(obj, k)
elif hasattr(obj, '_' + k):
kwargs[k] = getattr(obj, '_' + k)
elif hasattr(obj, 'parameters') and k in obj.parameters:
kwargs[k] = obj.parameters[k]
else:
raise ValueError(
'Object did not have an attr corresponding to constructor argument '
'{k}. (Tried both `obj.{k}` and obj._{k}`).'.format(k=k))
f'Could not determine an appropriate value for field `{k}` in object '
' `{obj}`. Looked for \n'
' 1. an attr called `{k}`,\n'
' 2. an attr called `_{k}`,\n'
' 3. an entry in `obj.parameters` with key "{k}".')
if k in prefer_static_value and kwargs[k] is not None:
static_val = tf.get_static_value(kwargs[k])
if static_val is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ def body(obj):
maximum_iterations=3)
self.assertAllClose(3., result._tensor_param)

def test_parameters_lookup(self):

@tfp.experimental.auto_composite_tensor
class ThingWithParametersButNoAttrs(tfp.experimental.AutoCompositeTensor):

def __init__(self, a, b):
self.a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
self.b = tf.convert_to_tensor(b, dtype_hint=tf.float32, name='a')
self.parameters = dict(a=self.a, b=self.b)

t = ThingWithParametersButNoAttrs(1., 2.)
self.assertIsInstance(t, tf.__internal__.CompositeTensor)

ts = t._type_spec
components = ts._to_components(t)
self.assertAllEqualNested(components, dict(a=1., b=2.))

t2 = ts._from_components(components)
self.assertIsInstance(t2, ThingWithParametersButNoAttrs)


if __name__ == '__main__':
tf.enable_v2_behavior()
Expand Down

0 comments on commit 7b80f1c

Please sign in to comment.