Skip to content

Commit

Permalink
Update TFP internals to support JAX omnistaging.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 324618716
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Aug 3, 2020
1 parent 1a5e8be commit 229e2c0
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 154 deletions.
1 change: 1 addition & 0 deletions tensorflow_probability/python/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ multi_substrate_py_library(
srcs = ["samplers.py"],
deps = [
":dtype_util",
":prefer_static",
# numpy dep,
# tensorflow dep,
],
Expand Down
73 changes: 37 additions & 36 deletions tensorflow_probability/python/internal/distribution_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow.python.util import tf_inspect # pylint: disable=g-direct-tensorflow-import
Expand Down Expand Up @@ -113,8 +113,10 @@ def shapes_from_loc_and_scale(loc, scale, name='shapes_from_loc_and_scale'):
loc = None # scalar loc is irrelevant to determining batch/event shape.
with tf.name_scope(name):
# Get event shape.
event_size = scale.range_dimension_tensor()
event_size_ = tf.get_static_value(event_size)
event_size = tf.compat.dimension_value(scale.range_dimension)
if event_size is None:
event_size = scale.range_dimension_tensor()
event_size_ = tf.get_static_value(ps.convert_to_shape_tensor(event_size))
loc_event_size_ = (None if loc is None
else tf.compat.dimension_value(loc.shape[-1]))

Expand All @@ -130,23 +132,26 @@ def shapes_from_loc_and_scale(loc, scale, name='shapes_from_loc_and_scale'):
if event_size_ is None:
event_shape = event_size[tf.newaxis]
else:
event_shape = tf.convert_to_tensor(
event_shape = ps.convert_to_shape_tensor(
np.reshape(event_size_, [1]), dtype=tf.int32, name='event_shape')

# Get batch shape.
batch_shape = scale.batch_shape_tensor()
batch_shape = scale.batch_shape
if not tensorshape_util.is_fully_defined(batch_shape):
batch_shape = scale.batch_shape_tensor()
else:
batch_shape = ps.convert_to_shape_tensor(batch_shape)
if loc is not None:
loc_batch_shape = tensorshape_util.with_rank_at_least(loc.shape, 1)[:-1]
if tensorshape_util.rank(
loc.shape) is None or not tensorshape_util.is_fully_defined(
loc_batch_shape):
if (tensorshape_util.rank(loc.shape) is None or
not tensorshape_util.is_fully_defined(loc_batch_shape)):
loc_batch_shape = tf.shape(loc)[:-1]
else:
loc_batch_shape = tf.convert_to_tensor(
loc_batch_shape = ps.convert_to_shape_tensor(
loc_batch_shape, dtype=tf.int32, name='loc_batch_shape')
# This is defined in the core util module.
batch_shape = prefer_static_broadcast_shape(batch_shape, loc_batch_shape) # pylint: disable=undefined-variable
batch_shape = tf.convert_to_tensor(
batch_shape = ps.broadcast_shape(batch_shape, loc_batch_shape)
batch_shape = ps.convert_to_shape_tensor(
batch_shape, dtype=tf.int32, name='batch_shape')

return batch_shape, event_shape
Expand Down Expand Up @@ -371,45 +376,45 @@ def move_dimension(x, source_idx, dest_idx):
ndims = prefer_static_rank(x)
dtype = dtype_util.common_dtype([source_idx, dest_idx],
dtype_hint=tf.int32)
source_idx = tf.convert_to_tensor(source_idx, dtype=dtype)
dest_idx = tf.convert_to_tensor(dest_idx, dtype=dtype)
source_idx = ps.convert_to_shape_tensor(source_idx, dtype=dtype)
dest_idx = ps.convert_to_shape_tensor(dest_idx, dtype=dtype)

# Handle negative indexing.
source_idx = pick_scalar_condition(source_idx < 0, ndims + source_idx,
source_idx)
dest_idx = pick_scalar_condition(dest_idx < 0, ndims + dest_idx, dest_idx)
source_idx = ps.where(source_idx < 0, ndims + source_idx, source_idx)
dest_idx = ps.where(dest_idx < 0, ndims + dest_idx, dest_idx)

# Construct the appropriate permutation of dimensions, depending
# whether the source is before or after the destination.
def move_left_permutation():
return prefer_static_value(
tf.concat([
tf.range(0, dest_idx, dtype=dtype), [source_idx],
tf.range(dest_idx, source_idx, dtype=dtype),
tf.range(source_idx + 1, ndims, dtype=dtype)
ps.concat([
ps.range(0, dest_idx, dtype=dtype),
[source_idx],
ps.range(dest_idx, source_idx, dtype=dtype),
ps.range(source_idx + 1, ndims, dtype=dtype)
],
axis=0))

def move_right_permutation():
return prefer_static_value(
tf.concat([
tf.range(0, source_idx, dtype=dtype),
tf.range(source_idx + 1, dest_idx + 1, dtype=dtype), [source_idx],
tf.range(dest_idx + 1, ndims, dtype=dtype)
ps.concat([
ps.range(0, source_idx, dtype=dtype),
ps.range(source_idx + 1, dest_idx + 1, dtype=dtype),
[source_idx],
ps.range(dest_idx + 1, ndims, dtype=dtype)
],
axis=0))

def x_permuted():
return tf.transpose(
a=x,
perm=prefer_static.cond(source_idx < dest_idx,
move_right_permutation,
move_left_permutation))
perm=ps.cond(source_idx < dest_idx,
move_right_permutation,
move_left_permutation))

# One final conditional to handle the special case where source
# and destination indices are equal.
return prefer_static.cond(tf.equal(source_idx, dest_idx),
lambda: x, x_permuted)
return ps.cond(ps.equal(source_idx, dest_idx), lambda: x, x_permuted)


def assert_integer_form(x,
Expand Down Expand Up @@ -1356,6 +1361,7 @@ def expand_to_vector(x, tensor_name=None, op_name=None, validate_args=False):
vector: a 1-D `Tensor`.
"""
with tf.name_scope(op_name or 'expand_to_vector'):
x_orig = x
x = tf.convert_to_tensor(x, name='x')
ndims = tensorshape_util.rank(x.shape)

Expand All @@ -1373,13 +1379,8 @@ def expand_to_vector(x, tensor_name=None, op_name=None, validate_args=False):

elif ndims == 0:
# Definitely expand ndims from 0 to 1.
x_const = tf.get_static_value(x)
if x_const is not None:
return tf.convert_to_tensor(
dtype_util.as_numpy_dtype(x.dtype)([x_const]), name=tensor_name)

else:
return tf.reshape(x, [1])
return ps.convert_to_shape_tensor(
ps.reshape(x_orig, [1]), name=tensor_name)

elif ndims != 1:
raise ValueError('Input is neither scalar nor vector.')
Expand Down
20 changes: 19 additions & 1 deletion tensorflow_probability/python/internal/prefer_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,17 @@ def _numpy_dtype(dtype):

def _get_static_value(pred):
"""Helper function for getting static values from maybe-tensor objects."""
if JAX_MODE:
try:
return np.asarray(pred)
except: # JAX sometimes raises raw Exception in __array__. # pylint: disable=bare-except
return None
if tf.is_tensor(pred):
pred_value = tf.get_static_value(tf.convert_to_tensor(pred))

# TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
# pylint: disable=protected-access
if not JAX_MODE and pred_value is None:
if pred_value is None:
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
pred._as_tf_output())
# pylint: enable=protected-access
Expand All @@ -125,6 +130,16 @@ def _get_static_predicate(pred):
return pred_value


def _convert_to_shape_tensor_jax(value, dtype=None, dtype_hint=None, name=None):  # pylint: disable=unused-argument
  """Converts vectors and scalars of `int`-like to `ndarray`.

  The value is first treated as an iterable of `int`-convertible elements
  and turned into a 1-D array. If that attempt fails for any reason, the
  value is treated as a single `int`-convertible scalar and turned into a
  0-D array.

  Args:
    value: An iterable of `int`-convertible elements, or a single
      `int`-convertible scalar.
    dtype: Optional dtype of the result. Takes precedence over `dtype_hint`.
    dtype_hint: Optional dtype used when `dtype` is not given.
    name: Unused; accepted so the signature lines up with the
      `tf.convert_to_tensor`-style call sites that pass `name=`.

  Returns:
    A numpy `ndarray` of the resolved dtype (`np.int32` when neither `dtype`
    nor `dtype_hint` is given).

  Raises:
    Whatever `int(value)` raises when `value` is neither an iterable of
    `int`-convertible elements nor itself `int`-convertible.
  """
  dtype = dtype_util.as_numpy_dtype(dtype or dtype_hint or np.int32)
  try:
    return np.array([int(v) for v in value], dtype=dtype)
  # JAX throws raw `Exception` in some cases, so the handler has to be broad.
  # It is deliberately not a bare `except:` so that `KeyboardInterrupt` and
  # `SystemExit` propagate instead of being swallowed.
  except Exception:  # pylint: disable=broad-except
    pass
  # Not iterable (or elements not int-convertible): treat as a scalar.
  return np.array(int(value), dtype=dtype)


def smart_where(condition, x_fn, y_fn):
"""As tf.where, but only calls x_fn/y_fn when condition not statically known.
Expand Down Expand Up @@ -415,6 +430,9 @@ def is_numpy(x):
cast = _prefer_static(tf.cast, nptf.cast)
ceil = _prefer_static(tf.math.ceil, nptf.math.ceil)
concat = _prefer_static(tf.concat, nptf.concat)
# Shape-tensor conversion. The second argument depends on the substrate:
# under JAX_MODE it is `_convert_to_shape_tensor_jax`; otherwise both
# arguments are `tf.convert_to_tensor`.
convert_to_shape_tensor = _prefer_static(
tf.convert_to_tensor,
_convert_to_shape_tensor_jax if JAX_MODE else tf.convert_to_tensor)
cumprod = _prefer_static(tf.math.cumprod, nptf.math.cumprod)
cumsum = _prefer_static(tf.math.cumsum, nptf.math.cumsum)
equal = _prefer_static(tf.equal, nptf.equal)
Expand Down
Loading

0 comments on commit 229e2c0

Please sign in to comment.