Modifies [Exp]RelaxedOneHotCategorical to properly broadcast logits over temperature.

PiperOrigin-RevId: 238130648
brianwa84 authored and tensorflower-gardener committed Mar 13, 2019
1 parent 07c406a commit f3ccb19
Showing 2 changed files with 43 additions and 31 deletions.
@@ -221,10 +221,13 @@ def probs(self):
     return self._probs

   def _batch_shape_tensor(self):
-    return tf.shape(input=self._logits)[:-1]
+    return tf.broadcast_dynamic_shape(
+        tf.shape(input=self.temperature),
+        tf.shape(input=self.logits)[:-1])

   def _batch_shape(self):
-    return self.logits.shape[:-1]
+    return tf.broadcast_static_shape(self.temperature.shape,
+                                     self.logits.shape[:-1])

   def _event_shape_tensor(self):
     return tf.shape(input=self.logits)[-1:]
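For context, the distribution's batch shape is now the broadcast of the temperature's shape against the logits' batch shape (every dimension but the last, event, dimension). A minimal standalone sketch of that computation, with illustrative values borrowed from the new testParamBroadcasting below (not part of the commit):

import tensorflow as tf

# Two temperatures broadcast against a single 3-class logits vector.
temperature = tf.constant([1.0, 1.4])   # shape [2]
logits = tf.constant([2.0, 3.0, -4.0])  # shape [3]

# Static: broadcast TensorShape([2]) with logits.shape[:-1] == TensorShape([]).
print(tf.broadcast_static_shape(temperature.shape, logits.shape[:-1]))  # [2]

# Dynamic twin, used when shapes are only known at graph run time.
print(tf.broadcast_dynamic_shape(tf.shape(temperature), tf.shape(logits)[:-1]))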
@@ -233,27 +233,24 @@ def _event_shape(self):
     return self.logits.shape.with_rank_at_least(1)[-1:]

   def _sample_n(self, n, seed=None):
-    sample_shape = tf.concat([[n], tf.shape(input=self.logits)], 0)
-    logits = self.logits * tf.ones(sample_shape, dtype=self.dtype)
-    logits_2d = tf.reshape(logits, [-1, self.event_size])
     # Uniform variates must be sampled from the open-interval `(0, 1)` rather
     # than `[0, 1)`. To do so, we use `np.finfo(self.dtype.as_numpy_dtype).tiny`
     # because it is the smallest, positive, "normal" number. A "normal" number
     # is such that the mantissa has an implicit leading 1. Normal, positive
     # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In
     # this case, a subnormal number (i.e., np.nextafter) can cause us to sample
     # 0.
+    uniform_shape = tf.concat(
+        [[n], self.batch_shape_tensor(), self.event_shape_tensor()], 0)
     uniform = tf.random.uniform(
-        shape=tf.shape(input=logits_2d),
+        shape=uniform_shape,
         minval=np.finfo(self.dtype.as_numpy_dtype).tiny,
         maxval=1.,
         dtype=self.dtype,
         seed=seed)
     gumbel = -tf.math.log(-tf.math.log(uniform))
-    noisy_logits = (gumbel + logits_2d) / self._temperature_2d
-    samples = tf.nn.log_softmax(noisy_logits)
-    ret = tf.reshape(samples, sample_shape)
-    return ret
+    noisy_logits = (gumbel + self.logits) / self.temperature[..., tf.newaxis]
+    return tf.nn.log_softmax(noisy_logits)

   def _log_prob(self, x):
     x = self._assert_valid_sample(x)
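The reworked _sample_n is the usual Gumbel-softmax construction: draw open-interval uniforms over the broadcast batch shape plus the event shape, map them to Gumbel noise via -log(-log(u)), add the logits, divide by temperature, and log-softmax over the event axis. A rough NumPy rendering of that path, as a sketch only (the function name and seeding are ours, not the commit's):

import numpy as np

def sample_exp_concrete(logits, temperature, n, seed=0):
  # Illustrative re-implementation of the new sampling path above.
  rng = np.random.default_rng(seed)
  logits = np.asarray(logits, dtype=np.float64)
  temperature = np.asarray(temperature, dtype=np.float64)
  # Broadcast temperature against the logits' batch dims, then append the
  # event dimension -- the analogue of `uniform_shape` above.
  batch_shape = np.broadcast_shapes(temperature.shape, logits.shape[:-1])
  shape = (n,) + batch_shape + logits.shape[-1:]
  # Open-interval uniforms: `tiny` keeps -log(-log(u)) finite.
  u = rng.uniform(low=np.finfo(np.float64).tiny, high=1., size=shape)
  gumbel = -np.log(-np.log(u))
  noisy = (gumbel + logits) / temperature[..., np.newaxis]
  # Numerically stable log-softmax over the event axis.
  m = np.max(noisy, axis=-1, keepdims=True)
  return noisy - m - np.log(np.sum(np.exp(noisy - m), axis=-1, keepdims=True))

# sample_exp_concrete([2., 3., -4.], [1.0, 1.4], n=5).shape == (5, 2, 3)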
@@ -264,22 +264,17 @@ def _log_prob(self, x):
         x.shape != logits.shape):
       logits = tf.ones_like(x, dtype=logits.dtype) * logits
       x = tf.ones_like(logits, dtype=x.dtype) * x
-    logits_shape = tf.shape(input=tf.reduce_sum(input_tensor=logits, axis=[-1]))
-    logits_2d = tf.reshape(logits, [-1, self.event_size])
-    x_2d = tf.reshape(x, [-1, self.event_size])
     # compute the normalization constant
     k = tf.cast(self.event_size, x.dtype)
     log_norm_const = (
         tf.math.lgamma(k) + (k - 1.) * tf.math.log(self.temperature))
     # compute the unnormalized density
-    log_softmax = tf.nn.log_softmax(logits_2d - x_2d * self._temperature_2d)
+    log_softmax = tf.nn.log_softmax(
+        self.logits - x * self.temperature[..., tf.newaxis])
     log_unnorm_prob = tf.reduce_sum(
         input_tensor=log_softmax, axis=[-1], keepdims=False)
     # combine unnormalized density with normalization constant
-    log_prob = log_norm_const + log_unnorm_prob
-    # Reshapes log_prob to be consistent with shape of user-supplied logits
-    ret = tf.reshape(log_prob, logits_shape)
-    return ret
+    return log_norm_const + log_unnorm_prob

   def _assert_valid_sample(self, x):
     if not self.validate_args:
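For reference (not introduced by the commit), the surviving body of _log_prob is the ExpConcrete log density from Maddison et al. (2016). With $K$ the event size, temperature $\lambda$, logits $\alpha$, and $x$ on the log-simplex:

$$\log p(x) = \log\Gamma(K) + (K-1)\log\lambda + \sum_{i=1}^{K}\Big[(\alpha_i - \lambda x_i) - \log\sum_{j=1}^{K} e^{\alpha_j - \lambda x_j}\Big]$$

The first two terms are log_norm_const; the bracketed summand is exactly tf.nn.log_softmax(logits - x * temperature) reduced over the event axis, which is why the broadcasting form above no longer needs to flatten everything to 2-D.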
@@ -62,6 +62,20 @@ def testPdf(self):
     self.assertAllClose(expected_pdf, pdf)


+def analytical_pdf(x, temperature, logits):
+  # analytical density of RelaxedOneHotCategorical
+  temperature = np.reshape(temperature, (-1, 1))
+  if len(x.shape) == 1:
+    x = np.expand_dims(x, 0)
+  k = logits.shape[-1]
+  p = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
+  term1 = gamma(k) * np.power(temperature, k-1)
+  term2 = np.sum(p / (np.power(x, temperature)), axis=-1, keepdims=True)
+  term3 = np.prod(p / (np.power(x, temperature+1)), axis=-1, keepdims=True)
+  expected_pdf = term1 * np.power(term2, -k) * term3
+  return expected_pdf
+
+
 @test_util.run_all_in_graph_and_eager_modes
 class RelaxedOneHotCategoricalTest(tf.test.TestCase):

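Hoisting analytical_pdf to module level lets tests outside RelaxedOneHotCategoricalTest reuse it, and switching from axis=1 to axis=-1 (and shape[1] to shape[-1]) lets it accept logits that broadcast over the batch. The helper encodes the Concrete density of Maddison et al. (2016): with $K$ classes, class probabilities $p = \mathrm{softmax}(\text{logits})$, and temperature $\lambda$,

$$p(x) = \Gamma(K)\,\lambda^{K-1}\Big(\sum_{i=1}^{K}\frac{p_i}{x_i^{\lambda}}\Big)^{-K}\prod_{i=1}^{K}\frac{p_i}{x_i^{\lambda+1}}$$

term1, term2, and term3 above are the three factors, and reshaping temperature to (-1, 1) lets a vector of temperatures run down the batch axis.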
@@ -73,6 +87,13 @@ def testLogits(self):
     self.assertAllClose(logits, self.evaluate(dist._distribution.logits))
     self.assertAllEqual([3], dist._distribution.logits.shape)

+  def testParamBroadcasting(self):
+    temperature = [1.0, 1.4]
+    logits = [2.0, 3.0, -4.0]
+    dist = tfd.RelaxedOneHotCategorical(temperature, logits)
+    self.assertAllEqual([2], dist.batch_shape)
+    self.assertAllEqual([3], dist.event_shape)
+
   def testSample(self):
     temperature = 1.4
     # single logit
@@ -92,19 +113,6 @@ def testSample(self):
     self.assertAllEqual([5, 4, 1, 3], self.evaluate(dist.sample(5)).shape)

   def testPdf(self):
-    def analytical_pdf(x, temperature, logits):
-      # analytical density of RelaxedOneHotCategorical
-      temperature = np.reshape(temperature, (-1, 1))
-      if len(x.shape) == 1:
-        x = np.expand_dims(x, 0)
-      k = logits.shape[1]
-      p = np.exp(logits)/np.sum(np.exp(logits), axis=1, keepdims=True)
-      term1 = gamma(k)*np.power(temperature, k-1)
-      term2 = np.sum(p/(np.power(x, temperature)), axis=1, keepdims=True)
-      term3 = np.prod(p/(np.power(x, temperature+1)), axis=1, keepdims=True)
-      expected_pdf = term1*np.power(term2, -k)*term3
-      return expected_pdf
-
     temperature = .4
     logits = np.array([[.3, .1, .4]]).astype(np.float32)
     dist = tfd.RelaxedOneHotCategorical(temperature, logits)
@@ -122,6 +130,15 @@ def analytical_pdf(x, temperature, logits):
     expected_pdf = analytical_pdf(x, temperatures, logits)
     self.assertAllClose(expected_pdf.flatten(), pdf, rtol=1e-4)

+    # broadcast logits over temperatures
+    logits = np.array([.3, .1, .4]).astype(np.float32)
+    temperatures = np.array([0.4, 2.3]).astype(np.float32)
+    dist = tfd.RelaxedOneHotCategorical(temperatures, logits)
+    x = self.evaluate(dist.sample())
+    pdf = self.evaluate(dist.prob(x))
+    expected_pdf = analytical_pdf(x, temperatures, logits)
+    self.assertAllClose(expected_pdf.flatten(), pdf, rtol=1e-4)
+
   def testShapes(self):
     for batch_shape in ([], [1], [2, 3, 4]):
       dist = make_relaxed_categorical(batch_shape, 10)
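Taken together, the change makes a vector of temperatures usable with a single logits vector end to end. A small usage sketch (hypothetical values mirroring the new test; assumes eager execution):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# A vector of temperatures now broadcasts against one set of logits.
dist = tfd.RelaxedOneHotCategorical(temperature=[0.4, 2.3],
                                    logits=[0.3, 0.1, 0.4])
print(dist.batch_shape)       # (2,)  -- broadcast of [2] with []
print(dist.event_shape)       # (3,)
print(dist.sample(5).shape)   # (5, 2, 3)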
