Skip to content

Commit

Permalink
Internal refactoring
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 250919397
  • Loading branch information
edloper authored and tensorflower-gardener committed May 31, 2019
1 parent a4e1cb7 commit 4fa667e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
43 changes: 43 additions & 0 deletions tensorflow_probability/python/layers/internal/tensor_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,27 @@
import tensorflow as tf

from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import type_spec # pylint: disable=g-direct-tensorflow-import


__all__ = [
'TensorTuple',
]


# TODO(b/133606651) Delete _to_components, _from_components, etc. once
# CompositeTensor is refactored to use _type_spec.
class TensorTuple(composite_tensor.CompositeTensor):
"""`Tensor`-like `tuple`-like for custom `Tensor` conversion masquerading."""

def __init__(self, sequence):
super(TensorTuple, self).__init__()
self._sequence = tuple(tf.convert_to_tensor(value=x) for x in sequence)

@property
def _type_spec(self):
return TensorTupleSpec(map(type_spec.TypeSpec.from_value, self._sequence))

def _to_components(self):
return self._sequence

Expand All @@ -45,6 +52,7 @@ def _from_components(cls, components):
def _shape_invariant_to_components(self, shape=None):
raise NotImplementedError('TensorTuple._shape_invariant_to_components')

@property
def _is_graph_tensor(self):
return any(hasattr(x, 'graph') for x in self._sequence)

Expand All @@ -62,3 +70,38 @@ def __repr__(self):

def __str__(self):
return str(self._sequence)


class TensorTupleSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `TensorTuple`."""
__slots__ = ['_specs']

@property
def value_type(self):
return TensorTuple

def __init__(self, tensor_specs):
self._specs = tuple(tensor_specs)

def _serialize(self):
return (self._specs,)

@property
def _component_specs(self):
return self._specs

def _to_components(self, value):
return value._sequence # pylint: disable=protected-access

def _from_components(self, tensor_list):
return TensorTuple(tensor_list)

def _batch(self, batch_size):
# pylint: disable=protected-access
return TensorTupleSpec([spec._batch(batch_size) for spec in self._specs])

def _unbatch(self):
# pylint: disable=protected-access
return TensorTupleSpec([spec._unbatch() for spec in self._specs])


Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_getitem(self):
self.assertLen(y, 3)
for i in range(3):
self.assertAllEqual(x[i], tf.get_static_value(y[i]))
self.assertEqual(not(tf.executing_eagerly()), y._is_graph_tensor())
self.assertEqual(not(tf.executing_eagerly()), y._is_graph_tensor)

def test_to_from(self):
x = MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
Expand Down

0 comments on commit 4fa667e

Please sign in to comment.