Zoneout: weighted sum with previous state to restore random noise expectation at inference time
nikita-smetanin committed May 15, 2018
1 parent 0b3a267 commit 69022f5
Showing 1 changed file with 15 additions and 21 deletions.
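
In short: during training, zoneout keeps each unit of the LSTM state at its previous value with probability z (the binary masks in the diff below). Before this commit, inference simply used the freshly computed state; this commit instead replaces the random mask by its expectation. Since E[m] = z for a Bernoulli(z) mask m, E[m * prev + (1 - m) * new] = z * prev + (1 - z) * new, so inference now uses that deterministic weighted sum and the state statistics match what the network saw during training. A minimal NumPy sketch of the rule (the zoneout helper and its signature are illustrative, not code from this repository):

import numpy as np

def zoneout(prev, new, z, is_training, rng=np.random):
    # prev, new: previous and candidate state arrays; z: zoneout probability.
    if is_training and z > 0.0:
        # Training: freeze each unit at its previous value with probability z.
        mask = (rng.random_sample(new.shape) < z).astype(new.dtype)
        return mask * prev + (1.0 - mask) * new
    # Inference: use the expectation of the random mask instead of sampling it.
    return z * prev + (1.0 - z) * new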
tacotron/models/zoneout_LSTM.py (36 changes: 15 additions & 21 deletions)
@@ -170,38 +170,32 @@ def __call__(self, inputs, state, scope=None):
                                          w_f_diag * c_prev) + \
                      tf.sigmoid(i + w_i_diag * c_prev) * \
                      self.activation(j)
-            if self.is_training and self.zoneout_factor_cell > 0.0:
-                c = binary_mask_cell * c_prev + \
-                    binary_mask_cell_complement * c_temp
-            else:
-                c = c_temp
         else:
             c_temp = c_prev * tf.sigmoid(f + self.forget_bias) + \
                      tf.sigmoid(i) * self.activation(j)
-            if self.is_training and self.zoneout_factor_cell > 0.0:
-                c = binary_mask_cell * c_prev + \
-                    binary_mask_cell_complement * c_temp
-            else:
-                c = c_temp
+
+        if self.is_training and self.zoneout_factor_cell > 0.0:
+            c = binary_mask_cell * c_prev + \
+                binary_mask_cell_complement * c_temp
+        else:
+            c = (1.0 - self.zoneout_factor_cell) * c_temp + \
+                self.zoneout_factor_cell * c_prev
 
         if self.cell_clip is not None:
             c = tf.clip_by_value(c, -self.cell_clip, self.cell_clip)
 
         # apply zoneout for output
         if self.use_peepholes:
             h_temp = tf.sigmoid(o + w_o_diag * c) * self.activation(c)
-            if self.is_training and self.zoneout_factor_output > 0.0:
-                h = binary_mask_output * h_prev + \
-                    binary_mask_output_complement * h_temp
-            else:
-                h = h_temp
         else:
             h_temp = tf.sigmoid(o) * self.activation(c)
-            if self.is_training and self.zoneout_factor_output > 0.0:
-                h = binary_mask_output * h_prev + \
-                    binary_mask_output_complement * h_temp
-            else:
-                h = h_temp
+
+        if self.is_training and self.zoneout_factor_output > 0.0:
+            h = binary_mask_output * h_prev + \
+                binary_mask_output_complement * h_temp
+        else:
+            h = (1.0 - self.zoneout_factor_output) * h_temp + \
+                self.zoneout_factor_output * h_prev
 
         # apply projection
         if self.num_proj is not None:
@@ -262,4 +256,4 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
         bias_term = tf.get_variable(
             "Bias", [output_size],
             initializer=tf.constant_initializer(bias_start))
-    return res + bias_term
\ No newline at end of file
+    return res + bias_term
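
The hunks above apply the same correction twice, once for the cell state c and once for the output h (plus a trailing-newline fix in _linear). To see numerically why the weighted sum restores the expectation of the training-time noise, here is a quick hypothetical NumPy check (not part of the commit): the average of many sampled zoneout outputs converges to the deterministic inference-time value.

import numpy as np

z = 0.1
prev = np.array([1.0, -2.0])
new = np.array([0.5, 3.0])

# Sample many Bernoulli(z) masks, as the training path does.
masks = (np.random.random_sample((100000, 2)) < z).astype(float)
avg_sampled = (masks * prev + (1.0 - masks) * new).mean(axis=0)

# Deterministic weighted sum used at inference after this commit.
expected = z * prev + (1.0 - z) * new

print(avg_sampled)  # approximately [0.55, 2.5]
print(expected)     # exactly [0.55, 2.5]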
