Skip to content

Commit

Permalink
Temporarily suppress GaussianProcessRegressionModel in stochastic_pro…
Browse files Browse the repository at this point in the history
…cess_properties_test.

PiperOrigin-RevId: 281353351
  • Loading branch information
axch authored and tensorflower-gardener committed Nov 19, 2019
1 parent 71882e3 commit e8025e7
Showing 1 changed file with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def gaussian_process_regression_models(draw,
feature_ndims=feature_ndims,
enable_vars=enable_vars,
name='index_points'))
hp.note('Index points: {}'.format(index_points))

observation_index_points = draw(
kernel_hps.kernel_input(
Expand All @@ -206,6 +207,7 @@ def gaussian_process_regression_models(draw,
feature_ndims=feature_ndims,
enable_vars=enable_vars,
name='observation_index_points'))
hp.note('Observation index points: {}'.format(observation_index_points))

observations = draw(kernel_hps.kernel_input(
batch_shape=compatible_batch_shape,
Expand All @@ -218,12 +220,15 @@ def gaussian_process_regression_models(draw,
feature_ndims=0,
enable_vars=enable_vars,
name='observations'))
hp.note('Observations: {}'.format(observations))

params = draw(broadcasting_params(
'GaussianProcessRegressionModel',
compatible_batch_shape,
event_dim=event_dim,
enable_vars=enable_vars))
hp.note('Params: {}'.format(params))

gp = tfd.GaussianProcessRegressionModel(
kernel=k,
index_points=index_points,
Expand Down Expand Up @@ -296,6 +301,10 @@ class StochasticProcessParamsAreVarsTest(test_util.TestCase):
hp.HealthCheck.filter_too_much,
hp.HealthCheck.data_too_large])
def testProcess(self, process_name, data):
if process_name == 'GaussianProcessRegressionModel':
import unittest # pylint: disable=g-import-not-at-top
raise unittest.case.SkipTest('b/144181034')

if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
return
seed = test_util.test_seed()
Expand Down Expand Up @@ -332,6 +341,7 @@ def testProcess(self, process_name, data):
'method `sample` of `{}`'.format(process),
max_permissible=excessive_usage_count(process_name)):
sample = process.sample(seed=seed)
hp.note('Drew sample {}'.format(sample))
if process.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
grads = tape.gradient(sample, process.variables)
for grad, var in zip(grads, process.variables):
Expand Down

0 comments on commit e8025e7

Please sign in to comment.