Enable Multitask GP Regression Model to be rebuilt from Cholesky factors (for use in Vizier).

PiperOrigin-RevId: 508465252
emilyfertig authored and tensorflower-gardener committed Feb 9, 2023
1 parent 82f8e01 commit 69aaccd
Showing 3 changed files with 116 additions and 16 deletions.
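
In brief: `precompute_regression_model` pays for an O(n^3) Cholesky factorization up front; this change lets the resulting factors be read off one model instance and passed back into a later call, so a rebuilt posterior (e.g., in a Vizier worker) skips the factorization entirely. A minimal sketch of that round trip, mirroring the new tests below; the public `tfp.experimental` aliases and the toy data are assumptions for illustration, not part of the commit:

    import numpy as np
    import tensorflow_probability as tfp

    kernel = tfp.experimental.psd_kernels.Independent(
        num_tasks=2,
        base_kernel=tfp.math.psd_kernels.ExponentiatedQuadratic(
            amplitude=np.float64(1.), length_scale=np.float64(1.)))
    observation_index_points = np.random.uniform(-1., 1., [12, 1])
    observations = np.random.uniform(-1., 1., [12, 2])
    index_points = np.random.uniform(-1., 1., [5, 1])
    mtgprm_cls = (
        tfp.experimental.distributions.MultiTaskGaussianProcessRegressionModel)

    # First build: pays for the Cholesky of the observation kernel matrix.
    model = mtgprm_cls.precompute_regression_model(
        kernel=kernel,
        index_points=index_points,
        observation_index_points=observation_index_points,
        observations=observations,
        observation_noise_variance=np.float64(0.05))

    # Ship the cheap-to-store factors...
    chol = model._precomputed_divisor_matrix_cholesky
    solve = model._precomputed_solve_on_observation

    # ...and rebuild the posterior elsewhere without re-factorizing.
    rebuilt = mtgprm_cls.precompute_regression_model(
        kernel=kernel,
        index_points=index_points,
        observation_index_points=observation_index_points,
        observations=observations,
        observation_noise_variance=np.float64(0.05),
        _precomputed_divisor_matrix_cholesky=chol,
        _precomputed_solve_on_observation=solve)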
tensorflow_probability/python/experimental/distributions/BUILD
@@ -220,6 +220,7 @@ multi_substrate_py_library(
         # numpy dep,
         # tensorflow dep,
         "//tensorflow_probability/python/distributions:cholesky_util",
+        "//tensorflow_probability/python/experimental/linalg:linear_operator_unitary",
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:tensor_util",
         "//tensorflow_probability/python/internal:tensorshape_util",
tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py
@@ -14,10 +14,6 @@
 # ============================================================================
 """The MultiTaskGaussianProcessRegressionModel distribution class."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 # Dependency imports
 
 import tensorflow.compat.v2 as tf
@@ -26,6 +22,7 @@
 from tensorflow_probability.python.distributions import distribution
 from tensorflow_probability.python.distributions import mvn_linear_operator
 from tensorflow_probability.python.experimental.distributions import multitask_gaussian_process as mtgp
+from tensorflow_probability.python.experimental.linalg import linear_operator_unitary
 from tensorflow_probability.python.experimental.psd_kernels import multitask_kernel
 from tensorflow_probability.python.internal import batch_shape_lib
 from tensorflow_probability.python.internal import distribution_util
@@ -133,6 +130,52 @@ def _compute_observation_scale(
   return observation_scale
 
 
+def _scale_from_precomputed(precomputed_cholesky, kernel):
+  """Rebuilds `observation_scale` from precomputed values."""
+  params, = tuple(precomputed_cholesky.values())
+  if 'tril' in precomputed_cholesky:
+    return tf.linalg.LinearOperatorLowerTriangular(
+        params['chol_tril'], is_non_singular=True)
+  if 'multitask' in precomputed_cholesky:
+    return tf.linalg.LinearOperatorKronecker(
+        [tf.linalg.LinearOperatorLowerTriangular(
+            params['chol_tril'], is_non_singular=True),
+         tf.linalg.LinearOperatorIdentity(
+             kernel.num_tasks, dtype=kernel.dtype)],
+        is_square=True,
+        is_non_singular=True)
+  if 'separable' in precomputed_cholesky:
+    diag_op = tf.linalg.LinearOperatorDiag(
+        params['diag'],
+        is_square=True,
+        is_non_singular=True,
+        is_positive_definite=True)
+    orthogonal_op = tf.linalg.LinearOperatorKronecker(
+        [linear_operator_unitary.LinearOperatorUnitary(orth)
+         for orth in params['kronecker_orths']],
+        is_square=True, is_non_singular=True)
+    return orthogonal_op.matmul(diag_op)
+  # This should not happen.
+  raise ValueError(
+      f'Unexpected value for `precomputed_cholesky`: {precomputed_cholesky}.')
+
+
+def _precomputed_from_scale(observation_scale, kernel):
+  """Extracts expensive precomputed values."""
+  if isinstance(observation_scale, tf.linalg.LinearOperatorLowerTriangular):
+    return {'tril': {'chol_tril': observation_scale.tril}}
+  if isinstance(kernel, multitask_kernel.Independent):
+    base_kernel_chol_op = observation_scale.operators[0]
+    return {'multitask': {'chol_tril': base_kernel_chol_op.tril}}
+  if isinstance(kernel, multitask_kernel.Separable):
+    kronecker_op, diag_op = observation_scale.operators
+    kronecker_orths = [k.matrix for k in kronecker_op.operators]
+    return {'separable': {'kronecker_orths': kronecker_orths,
+                          'diag': diag_op.diag}}
+  # This should not happen.
+  raise ValueError('Unexpected values for kernel and observation_scale.')
+
+
 class MultiTaskGaussianProcessRegressionModel(
     distribution.AutoCompositeTensorDistribution):
   """Posterior predictive in a conjugate Multi-task GP regression model."""
@@ -366,7 +409,9 @@ def precompute_regression_model(
       cholesky_fn=None,
       validate_args=False,
       allow_nan_stats=False,
-      name='PrecomputedMultiTaskGaussianProcessRegressionModel'):
+      name='PrecomputedMultiTaskGaussianProcessRegressionModel',
+      _precomputed_divisor_matrix_cholesky=None,
+      _precomputed_solve_on_observation=None):
     """Returns a MTGaussianProcessRegressionModel with precomputed quantities.
 
     This differs from the constructor by precomputing quantities associated with
@@ -462,6 +507,8 @@ def precompute_regression_model(
         Default value: `False`.
       name: Python `str` name prefixed to Ops created by this class.
         Default value: 'PrecomputedMultiTaskGaussianProcessRegressionModel'.
+      _precomputed_divisor_matrix_cholesky: Internal parameter -- do not use.
+      _precomputed_solve_on_observation: Internal parameter -- do not use.
     Returns:
       An instance of `MultiTaskGaussianProcessRegressionModel` with precomputed
       quantities associated with observations.
@@ -497,7 +544,10 @@ def precompute_regression_model(
       if not callable(mean_fn):
         raise ValueError('`mean_fn` must be a Python callable')
 
-      if observations_is_missing is not None:
+      if _precomputed_divisor_matrix_cholesky is not None:
+        observation_scale = _scale_from_precomputed(
+            _precomputed_divisor_matrix_cholesky, kernel)
+      elif observations_is_missing is not None:
         # If observations are missing, there's nothing we can do to preserve the
         # operator structure, so densify.
 
@@ -537,8 +587,10 @@ def precompute_regression_model(
         vec_diff = tf.where(vec_observations_is_missing,
                             tf.zeros([], dtype=vec_diff.dtype),
                             vec_diff)
-      solve_on_observations = observation_scale.solvevec(
-          observation_scale.solvevec(vec_diff), adjoint=True)
+      solve_on_observations = _precomputed_solve_on_observation
+      if solve_on_observations is None:
+        solve_on_observations = observation_scale.solvevec(
+            observation_scale.solvevec(vec_diff), adjoint=True)
 
       def flattened_conditional_mean_fn(x):
 
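For context on the second cached quantity: the nested `solvevec` calls apply `L^-1` and then `L^-T`, so `solve_on_observations` is `(L L^T)^-1 (y - m) = K^-1 (y - m)`, the solve that the predictive mean reuses at every query point. A toy check of the identity (illustrative values only):

    import numpy as np
    import tensorflow as tf

    k = np.array([[2., 0.5], [0.5, 1.]])  # observation covariance (SPD)
    vec_diff = np.array([0.3, -1.2])      # observations minus prior mean
    scale = tf.linalg.LinearOperatorLowerTriangular(np.linalg.cholesky(k))

    solve_on_observations = scale.solvevec(
        scale.solvevec(vec_diff), adjoint=True)
    np.testing.assert_allclose(
        solve_on_observations, np.linalg.solve(k, vec_diff))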
@@ -566,6 +618,12 @@ def flattened_conditional_mean_fn(x):
           allow_nan_stats=allow_nan_stats,
           name=name)
 
+      # pylint: disable=protected-access
+      mtgprm._precomputed_divisor_matrix_cholesky = (
+          _precomputed_from_scale(observation_scale, kernel))
+      mtgprm._precomputed_solve_on_observation = solve_on_observations
+      # pylint: enable=protected-access
+
       return mtgprm
 
   @property
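
The protected-access block stashes the two factors on the returned instance, which is what lets a caller read them back and feed them to a later `precompute_regression_model` call. For the simplest 'tril' case, the pair of module-level helpers amounts to the following round trip (a sketch with a toy matrix, assuming eager mode):

    import numpy as np
    import tensorflow as tf

    chol = np.linalg.cholesky(np.array([[2., 0.5], [0.5, 1.]]))
    scale = tf.linalg.LinearOperatorLowerTriangular(chol, is_non_singular=True)

    # _precomputed_from_scale keeps only the cheap-to-store Tensor...
    precomputed = {'tril': {'chol_tril': scale.tril}}

    # ...and _scale_from_precomputed rebuilds the operator from it,
    # with no new factorization.
    params, = tuple(precomputed.values())
    rebuilt = tf.linalg.LinearOperatorLowerTriangular(
        params['chol_tril'], is_non_singular=True)
    np.testing.assert_allclose(rebuilt.to_dense(), scale.to_dense())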
tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model_test.py
@@ -14,10 +14,7 @@
 # ============================================================================
 """Tests for MultiTaskGaussianProcessRegressionModel."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from unittest import mock
 # Dependency imports
 
 from absl.testing import parameterized
@@ -510,10 +507,29 @@ def testMeanVarianceAndCovariancePrecomputed(self):
         observation_noise_variance=observation_noise_variance,
         validate_args=True)
 
+    mock_cholesky_fn = mock.Mock(return_value=None)
+    rebuilt_precomputed_mtgprm = mtgprm_lib.MultiTaskGaussianProcessRegressionModel.precompute_regression_model(
+        kernel=multi_task_kernel,
+        index_points=index_points,
+        observation_index_points=observation_index_points,
+        observations=observations,
+        observation_noise_variance=observation_noise_variance,
+        _precomputed_divisor_matrix_cholesky=precomputed_mtgprm._precomputed_divisor_matrix_cholesky,
+        _precomputed_solve_on_observation=precomputed_mtgprm._precomputed_solve_on_observation,
+        cholesky_fn=mock_cholesky_fn,
+        validate_args=True)
+    mock_cholesky_fn.assert_not_called()
+
+    rebuilt_precomputed_mtgprm = rebuilt_precomputed_mtgprm.copy(
+        cholesky_fn=None)
     self.assertAllClose(self.evaluate(precomputed_mtgprm.variance()),
                         self.evaluate(mtgprm.variance()))
     self.assertAllClose(self.evaluate(precomputed_mtgprm.mean()),
                         self.evaluate(mtgprm.mean()))
+    self.assertAllClose(self.evaluate(rebuilt_precomputed_mtgprm.variance()),
+                        self.evaluate(mtgprm.variance()))
+    self.assertAllClose(self.evaluate(rebuilt_precomputed_mtgprm.mean()),
+                        self.evaluate(mtgprm.mean()))
 
   def testPrecomputedWithMasking(self):
     num_tasks = 2
@@ -540,8 +556,15 @@ def testPrecomputedWithMasking(self):
 
     kernel = exponentiated_quadratic.ExponentiatedQuadratic(
         amplitude, length_scale)
-    multi_task_kernel = multitask_kernel.Independent(
-        num_tasks=num_tasks, base_kernel=kernel)
+    task_kernel_matrix = np.array([[6., 2.],
+                                   [2., 7.]],
+                                  dtype=np.float64)
+    task_kernel_matrix_linop = tf.linalg.LinearOperatorFullMatrix(
+        task_kernel_matrix)
+    multi_task_kernel = multitask_kernel.Separable(
+        num_tasks=num_tasks,
+        task_kernel_matrix_linop=task_kernel_matrix_linop,
+        base_kernel=kernel)
     mtgprm = mtgprm_lib.MultiTaskGaussianProcessRegressionModel.precompute_regression_model(
         kernel=multi_task_kernel,
         index_points=index_points,
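
Switching this test from `Independent` to `Separable` exercises the 'separable' rebuild branch, which stores eigenvector blocks (`kronecker_orths`) and a diagonal rather than a triangular factor. The underlying structure, sketched with the test's task matrix (the `K ⊗ T` ordering is an assumption matching the `[base_kernel, task]` operator order in `_scale_from_precomputed`):

    import numpy as np

    # Separable multitask covariance: Cov(f_t(x), f_s(y)) = k(x, y) * T[t, s],
    # so the joint matrix is a Kronecker product plus observation noise.
    t_task = np.array([[6., 2.], [2., 7.]])  # task kernel matrix from the test
    k_xx = np.array([[1., 0.3], [0.3, 1.]])  # base-kernel matrix on two points
    noise = 0.05
    joint = np.kron(k_xx, t_task) + noise * np.eye(4)

    # Eigendecompositions give an orthogonal-times-diagonal factor:
    # K (x) T + s^2 I = (Q_k (x) Q_t) diag(l_k (x) l_t + s^2) (Q_k (x) Q_t)^T,
    # i.e. the `kronecker_orths` and `diag` stored by _precomputed_from_scale.
    l_k, q_k = np.linalg.eigh(k_xx)
    l_t, q_t = np.linalg.eigh(t_task)
    orth = np.kron(q_k, q_t)
    diag = np.sqrt(np.kron(l_k, l_t) + noise)
    scale = orth @ np.diag(diag)
    np.testing.assert_allclose(scale @ scale.T, joint)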
@@ -551,12 +574,30 @@ def testPrecomputedWithMasking(self):
         observation_noise_variance=observation_noise_variance,
         validate_args=True)
 
+    mock_cholesky_fn = mock.Mock(return_value=None)
+    rebuilt_mtgprm = mtgprm_lib.MultiTaskGaussianProcessRegressionModel.precompute_regression_model(
+        kernel=multi_task_kernel,
+        index_points=index_points,
+        observation_index_points=observation_index_points,
+        observations=observations,
+        observation_noise_variance=observation_noise_variance,
+        _precomputed_divisor_matrix_cholesky=mtgprm._precomputed_divisor_matrix_cholesky,
+        _precomputed_solve_on_observation=mtgprm._precomputed_solve_on_observation,
+        cholesky_fn=mock_cholesky_fn,
+        validate_args=True)
+    mock_cholesky_fn.assert_not_called()
+
+    rebuilt_mtgprm = rebuilt_mtgprm.copy(cholesky_fn=None)
     self.assertAllNotNan(mtgprm.mean())
     self.assertAllNotNan(mtgprm.variance())
+    self.assertAllClose(self.evaluate(mtgprm.variance()),
+                        self.evaluate(rebuilt_mtgprm.variance()))
+    self.assertAllClose(self.evaluate(mtgprm.mean()),
+                        self.evaluate(rebuilt_mtgprm.mean()))
 
   @test_util.disable_test_for_backend(
-      disable_numpy=True, disable_jax=True,
-      reason='Numpy and JAX have no notion of CompositeTensor/saved_model')
+      disable_numpy=True,
+      reason='Numpy has no notion of CompositeTensor/Pytree/saved_model')
   def testPrecomputedCompositeTensor(self):
     num_tasks = 3
     amplitude = np.array([1., 2.], np.float64).reshape([2, 1])
