Skip to content

Commit

Permalink
Add flag to GaussianProcess, such that it can retain an event_shape of [1].
Browse files Browse the repository at this point in the history

Previously, if the number of (per-batch) index_points was statically
determinable and found to be 1, we would yield a Normal marginal distribution,
instead of an MVN. In some cases, users want a consistent event rank,
irrespective of the number of index_points. We enable this behavior in a
backward-compatible way by introducing a flag,
`always_yield_multivariate_normal`, which defaults to False.

PiperOrigin-RevId: 434905157
  • Loading branch information
csuter authored and tensorflower-gardener committed Mar 16, 2022
1 parent 7ff8499 commit c7ab80d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def __init__(self,
marginal_fn=None,
cholesky_fn=None,
jitter=1e-6,
always_yield_multivariate_normal=False,
validate_args=False,
allow_nan_stats=False,
parameters=None,
Expand Down Expand Up @@ -294,6 +295,10 @@ def __init__(self,
`marginal_fn` and `cholesky_fn` is None.
This argument is ignored if `cholesky_fn` is set.
Default value: `1e-6`.
always_yield_multivariate_normal: If `False` (the default), we produce a
scalar `Normal` distribution when the number of `index_points` is
statically known to be `1`. If `True`, we avoid this behavior, ensuring
that the event shape will retain the `1` from `index_points`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
Expand Down Expand Up @@ -353,6 +358,7 @@ def __init__(self,
else:
self._marginal_fn = marginal_fn

self._always_yield_multivariate_normal = always_yield_multivariate_normal
with tf.name_scope('init'):
super(GaussianProcess, self).__init__(
dtype=dtype,
Expand All @@ -375,6 +381,9 @@ def _is_univariate_marginal(self, index_points):
multivariate. In the case of dynamic shape in the number of index points,
defaults to "multivariate" since that's the best we can do.
"""
if self._always_yield_multivariate_normal:
return False

num_index_points = tf.compat.dimension_value(
index_points.shape[-(self.kernel.feature_ndims + 1)])
if num_index_points is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def __init__(self,
mean_fn=None,
cholesky_fn=None,
jitter=1e-6,
always_yield_multivariate_normal=False,
validate_args=False,
allow_nan_stats=False,
name='GaussianProcessRegressionModel',
Expand Down Expand Up @@ -456,6 +457,10 @@ def __init__(self,
matrix to ensure positive definiteness of the covariance matrix.
This argument is ignored if `cholesky_fn` is set.
Default value: `1e-6`.
always_yield_multivariate_normal: If `False` (the default), we produce a
scalar `Normal` distribution when the number of `index_points` is
statically known to be `1`. If `True`, we avoid this behavior, ensuring
that the event shape will retain the `1` from `index_points`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
Expand Down Expand Up @@ -571,6 +576,7 @@ def conditional_mean_fn(x):
index_points=index_points,
cholesky_fn=cholesky_fn,
jitter=jitter,
always_yield_multivariate_normal=always_yield_multivariate_normal,
# What the GP super class calls "observation noise variance" we call
# here the "predictive noise variance". We use the observation noise
# variance for the fit/solve process above, and predictive for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,23 @@ def testUnivariateLogProbWithIsMissing(self):
tf.convert_to_tensor([[lp[0, 0], 0.0], [0.0, 0.0], [0., lp[2, 1]]]),
gp.log_prob(x, is_missing=[[False, True], [True, True], [True, False]]))

def testAlwaysYieldMultivariateNormal(self):
  """Checks event rank with and without `always_yield_multivariate_normal`.

  With a single index point per batch, the default (`False`) yields a
  scalar-event `Normal`; setting the flag to `True` keeps an MVN with
  event shape `[1]`.
  """
  # (flag value, expected event shape) pairs — batch shape is [5] either way.
  for flag, expected_event_shape in ((False, []), (True, [1])):
    gp = tfd.GaussianProcess(
        kernel=psd_kernels.ExponentiatedQuadratic(),
        index_points=tf.ones([5, 1, 2]),
        always_yield_multivariate_normal=flag,
    )
    self.assertAllEqual([5], self.evaluate(gp.batch_shape_tensor()))
    self.assertAllEqual(
        expected_event_shape, self.evaluate(gp.event_shape_tensor()))


@test_util.test_all_tf_execution_regimes
class GaussianProcessStaticTest(_GaussianProcessTest, test_util.TestCase):
Expand Down

0 comments on commit c7ab80d

Please sign in to comment.