Skip to content

Commit

Permalink
Make the HMM distribution samples reproducible when setting the seed.
Browse files Browse the repository at this point in the history
There were two isses: (a) one of the underlying distributions was not
given the proper seed argument; and (b) sampling under tf.scan
produces non-deterministic behavior unless parallel_iterations=1.

PiperOrigin-RevId: 263012970
  • Loading branch information
axch authored and tensorflower-gardener committed Aug 12, 2019
1 parent 8d1400f commit b45e351
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 22 deletions.
1 change: 1 addition & 0 deletions tensorflow_probability/python/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1702,6 +1702,7 @@ py_test(
name = "hidden_markov_model_test",
size = "medium",
srcs = ["hidden_markov_model_test.py"],
shard_count = 4,
deps = [
# absl/testing:parameterized dep,
# numpy dep,
Expand Down
10 changes: 10 additions & 0 deletions tensorflow_probability/python/distributions/categorical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def testDtype(self):
self.assertEqual(dist.dtype, dtype)
self.assertEqual(dist.dtype, dist.sample(5).dtype)

def testReproducibility(self):
probs = tf.constant([0.6, 0.4], dtype=tf.float32)
dist = tfd.Categorical(probs=probs)
seed = tfp_test_util.test_seed()
s1 = self.evaluate(dist.sample(500, seed=seed))
if tf.executing_eagerly():
tf.random.set_seed(seed)
s2 = self.evaluate(dist.sample(500, seed=seed))
self.assertAllEqual(s1, s2)

def testUnknownShape(self):
logits = lambda l: tf1.placeholder_with_default( # pylint: disable=g-long-lambda
np.float32(l), shape=None)
Expand Down
22 changes: 16 additions & 6 deletions tensorflow_probability/python/distributions/hidden_markov_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf

Expand Down Expand Up @@ -341,7 +342,7 @@ def num_states(self):

def _sample_n(self, n, seed=None):
with tf.control_dependencies(self._runtime_assertions):
seed = seed_stream.SeedStream(seed, salt="HiddenMarkovModel")
strm = seed_stream.SeedStream(seed, salt="HiddenMarkovModel")

num_states = self._num_states

Expand All @@ -358,7 +359,7 @@ def _sample_n(self, n, seed=None):
tf.reduce_prod(self.batch_shape_tensor()) //
tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
init_state = self._initial_distribution.sample(n * init_repeat,
seed=seed())
seed=strm())
init_state = tf.reshape(init_state, [n, batch_size])
# init_state :: n batch_size

Expand All @@ -370,7 +371,7 @@ def generate_step(state, _):
"""Take a single step in Markov chain."""

gen = self._transition_distribution.sample(n * transition_repeat,
seed=seed())
seed=strm())
# gen :: (n * transition_repeat) transition_batch

new_states = tf.reshape(gen,
Expand All @@ -385,9 +386,18 @@ def generate_step(state, _):
return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

def _scan_multiple_steps():
"""Take multiple steps with tf.scan."""
dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
hidden_states = tf.scan(generate_step, dummy_index,
initializer=init_state)
if seed is not None:
# Force parallel_iterations to 1 to ensure reproducibility
# b/139210489
hidden_states = tf.scan(generate_step, dummy_index,
initializer=init_state,
parallel_iterations=1)
else:
# Invoke default parallel_iterations behavior
hidden_states = tf.scan(generate_step, dummy_index,
initializer=init_state)

# TODO(b/115618503): add/use prepend_initializer to tf.scan
return tf.concat([[init_state],
Expand All @@ -411,7 +421,7 @@ def _scan_multiple_steps():
self._observation_distribution.batch_shape_tensor()[:-1]))

possible_observations = self._observation_distribution.sample(
[self._num_steps, observation_repeat * n])
[self._num_steps, observation_repeat * n], seed=strm())

inner_shape = self._observation_distribution.event_shape

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,32 @@ def test_non_scalar_transition_batch(self):
validate_args=True)
self.evaluate(model.mean())

def test_reproducibility(self):
initial_prob_data = tf.constant([0.6, 0.4], dtype=self.dtype)
transition_matrix_data = tf.constant([[0.6, 0.4],
[0.3, 0.7]], dtype=self.dtype)
observation_locs_data = tf.constant([0.0, 1.0], dtype=self.dtype)
observation_scale_data = tf.constant(0.5, dtype=self.dtype)

(initial_prob, transition_matrix,
observation_locs, observation_scale) = self.make_placeholders([
initial_prob_data, transition_matrix_data,
observation_locs_data, observation_scale_data])

[num_steps] = self.make_placeholders([30])
model = tfd.HiddenMarkovModel(tfd.Categorical(probs=initial_prob),
tfd.Categorical(probs=transition_matrix),
tfd.Normal(loc=observation_locs,
scale=observation_scale),
num_steps=num_steps)

seed = tfp_test_util.test_seed()
s1 = self.evaluate(model.sample(5, seed=seed))
if tf.executing_eagerly():
tf.random.set_seed(seed)
s2 = self.evaluate(model.sample(5, seed=seed))
self.assertAllEqual(s1, s2)

def test_consistency(self):
initial_prob_data = tf.constant([0.6, 0.4], dtype=self.dtype)
transition_matrix_data = tf.constant([[0.6, 0.4],
Expand All @@ -119,10 +145,11 @@ def test_consistency(self):
num_steps=num_steps,
validate_args=True)

self.run_test_sample_consistent_log_prob(self.evaluate, model,
num_samples=100000,
center=0.5, radius=0.5,
rtol=0.05)
self.run_test_sample_consistent_log_prob(
self.evaluate, model,
num_samples=100000,
center=0.5, radius=0.5,
rtol=0.05, seed=tfp_test_util.test_seed())

def test_broadcast_initial_probs(self):
initial_prob_data = tf.constant([0.6, 0.4], dtype=self.dtype)
Expand All @@ -143,10 +170,11 @@ def test_broadcast_initial_probs(self):
scale=observation_scale),
num_steps=num_steps)

self.run_test_sample_consistent_log_prob(self.evaluate, model,
num_samples=100000,
center=0.5, radius=1.,
rtol=0.02)
self.run_test_sample_consistent_log_prob(
self.evaluate, model,
num_samples=100000,
center=0.5, radius=1.,
rtol=0.02, seed=tfp_test_util.test_seed())

def test_broadcast_transitions(self):
initial_prob_data = tf.constant([0.6, 0.4], dtype=self.dtype)
Expand All @@ -170,10 +198,11 @@ def test_broadcast_transitions(self):
scale=observation_scale),
num_steps=num_steps)

self.run_test_sample_consistent_log_prob(self.evaluate, model,
num_samples=100000,
center=0.5, radius=1.,
rtol=2e-2)
self.run_test_sample_consistent_log_prob(
self.evaluate, model,
num_samples=100000,
center=0.5, radius=1.,
rtol=2e-2, seed=tfp_test_util.test_seed())

def test_broadcast_observations(self):
initial_prob_data = tf.constant([0.6, 0.4], dtype=self.dtype)
Expand All @@ -197,10 +226,11 @@ def test_broadcast_observations(self):
scale=observation_scale),
num_steps=num_steps)

self.run_test_sample_consistent_log_prob(self.evaluate, model,
num_samples=100000,
center=0.5, radius=1.,
rtol=2e-2)
self.run_test_sample_consistent_log_prob(
self.evaluate, model,
num_samples=100000,
center=0.5, radius=1.,
rtol=2e-2, seed=tfp_test_util.test_seed())

def test_edge_case_sample_n_no_transitions(self):
initial_prob_data = tf.constant([0.5, 0.5], dtype=self.dtype)
Expand Down
9 changes: 9 additions & 0 deletions tensorflow_probability/python/distributions/normal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,15 @@ def testNormalSample(self):
self.assertAllEqual(expected_samples_shape, samples.shape)
self.assertAllEqual(expected_samples_shape, sample_values.shape)

def testReproducibility(self):
dist = tfd.Normal(loc=[0.1, 0.5, 0.3], scale=[1.0, 5.0, 20.0])
seed = tfp_test_util.test_seed()
s1 = self.evaluate(dist.sample(500, seed=seed))
if tf.executing_eagerly():
tf.random.set_seed(seed)
s2 = self.evaluate(dist.sample(500, seed=seed))
self.assertAllEqual(s1, s2)

def testNormalFullyReparameterized(self):
mu = tf.constant(4.0)
sigma = tf.constant(3.0)
Expand Down

0 comments on commit b45e351

Please sign in to comment.