Adds tfp.math.cholesky_concat, which can be used to implicitly extend an n x n Xn = LnLn^T with Yzm new rows and columns (z = n + m), yielding a new Lz such that Xz = LzLz^T and Yzm forms the right-m and bottom-m rectangles of Xz.

This can be useful for extending an already-known Cholesky factor in a backprop-friendly and efficient (O(N^2) instead of O(N^3)) manner, e.g. for Bayesian optimization.
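
A minimal sketch of the new op in use (the matrices here are illustrative; assumes the signature tfp.math.cholesky_concat(chol, cols) added by this commit, with cols of shape [n+m, m], and TF2-style eager execution). For intuition: splitting Yzm into its top n x m block Y1 and bottom m x m block Y2, the extended factor is Lz = [[Ln, 0], [M, C]], where M = Y1^T Ln^-T comes from one triangular solve against Ln (O(n^2 m)) and C is the Cholesky of the small Schur complement Y2 - MM^T, so no O(N^3) refactorization of Xn is needed.

  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  # A 3x3 symmetric positive-definite Xn and its Cholesky factor Ln (n = 3).
  xn = np.array([[4., 2., 1.],
                 [2., 5., 3.],
                 [1., 3., 6.]], dtype=np.float32)
  ln = tf.linalg.cholesky(xn)

  # Yzm: the right-m columns of the extended Xz, with z = n + m = 5. Its
  # bottom m x m block must be symmetric and keep Xz positive definite.
  yzm = np.array([[0.05, 0.01],
                  [0.02, 0.03],
                  [0.01, 0.02],
                  [7.00, 1.00],
                  [1.00, 8.00]], dtype=np.float32)

  # Extend the factorization: Lz such that Xz = LzLz^T.
  lz = tfp.math.cholesky_concat(ln, yzm)

  # Same answer as a from-scratch Cholesky of the explicitly assembled Xz.
  xz = tf.concat([tf.concat([xn, yzm[:3]], axis=1), tf.transpose(yzm)], axis=0)
  np.testing.assert_allclose(lz, tf.linalg.cholesky(xz), atol=1e-4)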

Along the way, refactors the broadcasting_shapes Hypothesis sampling strategy out of distribution_properties_test and into the internal tfp test_util (a usage sketch follows the test_util.py diff below).

PiperOrigin-RevId: 242551531
brianwa84 authored and tensorflower-gardener committed Apr 8, 2019
1 parent b4a4802 commit 0c46081
Showing 8 changed files with 253 additions and 83 deletions.
3 changes: 2 additions & 1 deletion tensorflow_probability/python/distributions/BUILD
@@ -589,7 +589,6 @@ py_library(
"//tensorflow_probability/python/internal:assert_util",
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:test_case",
],
)

@@ -1241,6 +1240,7 @@ py_test(
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:assert_util",
"//tensorflow_probability/python/internal:test_case",
],
)

@@ -1651,6 +1651,7 @@ py_test(
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:assert_util",
"//tensorflow_probability/python/internal:test_case",
],
)

tensorflow_probability/python/distributions/distribution_properties_test.py
@@ -34,6 +34,7 @@
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
from tensorflow_probability.python.internal import test_util as tfp_test_util
from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import

tfd = tfp.distributions
@@ -49,11 +50,6 @@ def hypothesis_max_examples():
return int(os.environ.get('TFP_HYPOTHESIS_MAX_EXAMPLES', 20))


def derandomize_hypothesis():
# Use --test_env=TFP_DERANDOMIZE_HYPOTHESIS=0 to get random coverage.
return bool(os.environ.get('TFP_DERANDOMIZE_HYPOTHESIS', 1))


MUTEX_PARAMS = [
set(['logits', 'probs']),
set(['rate', 'log_rate']),
@@ -114,55 +110,6 @@ def instantiable_dists():
# pylint: disable=no-value-for-parameter


def rank_only_shapes(mindims, maxdims):
return hps.integers(
min_value=mindims, max_value=maxdims).map(tf.TensorShape(None).with_rank)


def compute_rank_and_fullsize_reqd(draw, batch_shape, current_batch_shape,
is_last_param):
"""Returns a param rank and a list of bools for full-size-required by axis.
Args:
draw: Hypothesis data sampler.
batch_shape: Target broadcasted batch shape.
current_batch_shape: Broadcasted batch shape of params selected thus far.
This is ignored for non-last parameters.
is_last_param: bool indicator of whether this is the last param (in which
case, we must achieve the target batch_shape).
Returns:
param_batch_rank: Sampled rank for this parameter.
force_fullsize_dim: `param_batch_rank`-sized list of bool indicating whether
the corresponding axis of the parameter must be full-sized (True) or is
allowed to be 1 (i.e., broadcast) (False).
"""
batch_rank = batch_shape.ndims
if is_last_param:
# We must force full size dim on any mismatched axes, and proper rank.
full_rank_current = tf.broadcast_static_shape(
current_batch_shape, tf.TensorShape([1] * batch_rank))
# Identify axes in which the target shape is not yet matched.
axis_is_mismatched = [
full_rank_current[i] != batch_shape[i] for i in range(batch_rank)
]
min_rank = batch_rank
if current_batch_shape.ndims == batch_rank:
# Current rank might be already correct, but we could have a case like
# batch_shape=[4,3,2] and current_batch_shape=[4,1,2], in which case
# we must have at least 2 axes on this param's batch shape.
min_rank -= (axis_is_mismatched + [True]).index(True)
param_batch_rank = draw(rank_only_shapes(min_rank, batch_rank)).ndims
# Get the last param_batch_rank (possibly 0!) items.
force_fullsize_dim = axis_is_mismatched[batch_rank - param_batch_rank:]
else:
# There are remaining params to be drawn, so we will be able to force full
# size axes on subsequent params.
param_batch_rank = draw(rank_only_shapes(0, batch_rank)).ndims
force_fullsize_dim = [False] * param_batch_rank
return param_batch_rank, force_fullsize_dim


@hps.composite
def broadcasting_shapes(draw, batch_shape, param_names):
"""Draws a set of parameter batch shapes that broadcast to `batch_shape`.
@@ -182,30 +129,9 @@ def broadcasting_shapes(draw, batch_shape, param_names):
param_batch_shapes: `dict` of `str->tf.TensorShape` where the set of
shapes broadcast to `batch_shape`. The shapes are fully defined.
"""
batch_rank = batch_shape.ndims
result = {}
remaining_params = set(param_names)
current_batch_shape = tf.TensorShape([])
while remaining_params:
next_param = draw(hps.one_of(map(hps.just, remaining_params)))
remaining_params.remove(next_param)
param_batch_rank, force_fullsize_dim = compute_rank_and_fullsize_reqd(
draw,
batch_shape,
current_batch_shape,
is_last_param=not remaining_params)

# Get the last param_batch_rank (possibly 0!) dimensions.
param_batch_shape = batch_shape[batch_rank - param_batch_rank:].as_list()
for i, force_fullsize in enumerate(force_fullsize_dim):
if not force_fullsize and draw(hps.booleans()):
# Choose to make this param broadcast against some other param.
param_batch_shape[i] = 1
param_batch_shape = tf.TensorShape(param_batch_shape)
current_batch_shape = tf.broadcast_static_shape(current_batch_shape,
param_batch_shape)
result[next_param] = param_batch_shape
return result
n = len(param_names)
return dict(zip(draw(hps.permutations(param_names)),
draw(tfp_test_util.broadcasting_shapes(batch_shape, n))))


@hps.composite
@@ -290,8 +216,8 @@ def stringify_slices(slices):

@hps.composite
def batch_shapes(draw, min_ndims=0, max_ndims=3, min_lastdimsize=1):
shape = draw(rank_only_shapes(min_ndims, max_ndims))
rank = shape.ndims
rank = draw(hps.integers(min_value=min_ndims, max_value=max_ndims))
shape = tf.TensorShape(None).with_rank(rank)
if rank > 0:

def resize_lastdim(x):
@@ -619,7 +545,7 @@ def _run_test(self, data):
deadline=None,
max_examples=hypothesis_max_examples(),
suppress_health_check=[hp.HealthCheck.too_slow],
derandomize=derandomize_hypothesis())
derandomize=tfp_test_util.derandomize_hypothesis())
def testDistributions(self, data):
if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'): return
self._run_test(data)
3 changes: 3 additions & 0 deletions tensorflow_probability/python/internal/BUILD
@@ -150,6 +150,7 @@ py_test(

py_library(
name = "test_case",
testonly = 1,
srcs = ["test_case.py"],
deps = [
# numpy dep,
@@ -170,10 +171,12 @@ py_test(

py_library(
name = "test_util",
testonly = 1,
srcs = ["test_util.py"],
srcs_version = "PY2AND3",
deps = [
# absl/flags dep,
# hypothesis dep,
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/distributions:seed_stream",
90 changes: 90 additions & 0 deletions tensorflow_probability/python/internal/test_util.py
@@ -22,20 +22,28 @@

# Dependency imports
from absl import flags
import hypothesis.strategies as hps
import numpy as np
import six

import tensorflow as tf
from tensorflow_probability.python.distributions import seed_stream

__all__ = [
'broadcasting_shapes',
'derandomize_hypothesis',
'test_seed',
'test_seed_stream',
'DiscreteScalarDistributionTestHelpers',
'VectorDistributionTestHelpers',
]


def derandomize_hypothesis():
# Use --test_env=TFP_DERANDOMIZE_HYPOTHESIS=0 to get random coverage.
return os.environ.get('TFP_DERANDOMIZE_HYPOTHESIS', 1) not in (0, '0')


FLAGS = flags.FLAGS

flags.DEFINE_bool('vary_seed', False,
@@ -47,6 +55,88 @@
'Takes precedence over --vary-seed when both appear.'))


def _compute_rank_and_fullsize_reqd(draw, target_shape, current_shape, is_last):
"""Returns a param rank and a list of bools for full-size-required by axis.
Args:
draw: Hypothesis data sampler.
target_shape: `tf.TensorShape`, the target broadcasted shape.
current_shape: `tf.TensorShape`, the broadcasted shape of the shapes
selected thus far. This is ignored for non-last shapes.
is_last: bool indicator of whether this is the last shape (in which case, we
must achieve the target shape).
Returns:
next_rank: Sampled rank for the next shape.
force_fullsize_dim: `next_rank`-sized list of bool indicating whether the
corresponding axis of the shape must be full-sized (True) or is allowed to
be 1 (i.e., broadcast) (False).
"""
target_rank = target_shape.ndims
if is_last:
# We must force full size dim on any mismatched axes, and proper rank.
full_rank_current = tf.broadcast_static_shape(
current_shape, tf.TensorShape([1] * target_rank))
# Identify axes in which the target shape is not yet matched.
axis_is_mismatched = [
full_rank_current[i] != target_shape[i] for i in range(target_rank)
]
min_rank = target_rank
if current_shape.ndims == target_rank:
# Current rank might be already correct, but we could have a case like
# target_shape=[4,3,2] and current_shape=[4,1,2], in which case
# we must have at least 2 axes on this next shape.
min_rank -= (axis_is_mismatched + [True]).index(True)
next_rank = draw(
hps.integers(min_value=min_rank, max_value=target_rank))
# Get the last next_rank (possibly 0!) items.
force_fullsize_dim = axis_is_mismatched[target_rank - next_rank:]
else:
# There are remaining shapes to be drawn, so we will be able to force
# full-size axes on subsequent shapes.
next_rank = draw(hps.integers(min_value=0, max_value=target_rank))
force_fullsize_dim = [False] * next_rank
return next_rank, force_fullsize_dim


@hps.composite
def broadcasting_shapes(draw, target_shape, n):
"""Draws a set of `n` shapes that broadcast to `target_shape`.
For each shape we need to choose its rank, and whether or not each axis i is 1
or target_shape[i]. This function chooses a set of `n` shapes that have
possibly mismatched ranks, and possibly broadcasting axes, with the promise
that the broadcast of the set of all shapes matches `target_shape`.
Args:
draw: Hypothesis sampler.
target_shape: The target (fully-defined) batch shape.
n: `int`, the number of shapes to draw.
Returns:
shapes: Sequence of `tf.TensorShape` such that the set of shapes broadcast
to `target_shape`. The shapes are fully defined.
"""
target_shape = tf.TensorShape(target_shape)
target_rank = target_shape.ndims
result = []
current_shape = tf.TensorShape([])
for is_last in [False] * (n-1) + [True]:
next_rank, force_fullsize_dim = _compute_rank_and_fullsize_reqd(
draw, target_shape, current_shape, is_last=is_last)

# Get the last next_rank (possibly 0!) dimensions.
next_shape = target_shape[target_rank - next_rank:].as_list()
for i, force_fullsize in enumerate(force_fullsize_dim):
if not force_fullsize and draw(hps.booleans()):
# Choose to make this shape broadcast against some other shape.
next_shape[i] = 1
next_shape = tf.TensorShape(next_shape)
current_shape = tf.broadcast_static_shape(current_shape, next_shape)
result.append(next_shape)
return result


def test_seed(hardcoded_seed=None, set_eager_seed=True):
"""Returns a command-line-controllable PRNG seed for unit tests.
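A minimal sketch of how a test might consume the refactored strategy (a hypothetical test body; the target shape and shape count are arbitrary):

  import hypothesis as hp
  import hypothesis.strategies as hps
  import tensorflow as tf
  from tensorflow_probability.python.internal import test_util as tfp_test_util

  @hp.given(hps.data())
  def test_shapes_broadcast_to_target(data):
    target = tf.TensorShape([4, 3, 2])
    # Draw 3 fully-defined shapes whose joint broadcast must equal `target`.
    shapes = data.draw(tfp_test_util.broadcasting_shapes(target, 3))
    current = tf.TensorShape([])
    for shape in shapes:
      current = tf.broadcast_static_shape(current, shape)
    assert current.as_list() == target.as_list()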
3 changes: 3 additions & 0 deletions tensorflow_probability/python/math/BUILD
@@ -130,10 +130,13 @@ py_test(
name = "linalg_test",
size = "small",
srcs = ["linalg_test.py"],
shard_count = 3,
deps = [
# hypothesis dep,
# numpy dep,
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:test_util",
],
)

2 changes: 2 additions & 0 deletions tensorflow_probability/python/math/__init__.py
@@ -24,6 +24,7 @@
from tensorflow_probability.python.math.interpolation import batch_interp_regular_1d_grid
from tensorflow_probability.python.math.interpolation import batch_interp_regular_nd_grid
from tensorflow_probability.python.math.interpolation import interp_regular_1d_grid
from tensorflow_probability.python.math.linalg import cholesky_concat
from tensorflow_probability.python.math.linalg import lu_matrix_inverse
from tensorflow_probability.python.math.linalg import lu_reconstruct
from tensorflow_probability.python.math.linalg import lu_solve
@@ -44,6 +45,7 @@
_allowed_symbols = [
'batch_interp_regular_1d_grid',
'batch_interp_regular_nd_grid',
'cholesky_concat',
'clip_by_value_preserve_gradient',
'custom_gradient',
'dense_to_sparse',
(Diffs for the two remaining changed files, tensorflow_probability/python/math/linalg.py and linalg_test.py, did not load on this page.)
