Commit

fix bug in basic_rnn when layer_num > 1
wan-wei committed Jan 23, 2018
1 parent d3d025e commit de9dbe4
Showing 1 changed file with 31 additions and 26 deletions.
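The bug being fixed: get_cell built one cell object and stacked it as MultiRNNCell([cell] * layer_num), so every layer shared the same cell instance. In TensorFlow 1.1 and later, applying such a stack raises a ValueError along the lines of "Attempt to reuse RNNCell ... Please create a new instance of the cell", because the second layer calls the already-built cell under a different variable scope and with a different input width (hidden_size instead of the embedding size). A minimal sketch of the failure mode, assuming TF 1.x with tf.contrib available; the 64/10/128 sizes are illustrative:

    import tensorflow as tf
    import tensorflow.contrib as tc

    # Pre-fix pattern: one cell object replicated across layers.
    cell = tc.rnn.LSTMCell(num_units=64, state_is_tuple=True)
    stacked = tc.rnn.MultiRNNCell([cell] * 2, state_is_tuple=True)

    inputs = tf.placeholder(tf.float32, [None, 10, 128])  # [batch, time, dim]
    # In TF 1.1+ the next line raises a ValueError: layer 1 tries to reuse the
    # same cell (and its variable scope) on a 64-dim input after layer 0 built
    # it for a 128-dim input; TF asks for a new cell instance per layer.
    outputs, states = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)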
57 changes: 31 additions & 26 deletions tensorflow/layers/basic_rnn.py
@@ -39,28 +39,31 @@ def rnn(rnn_type, inputs, length, hidden_size, layer_num=1, dropout_keep_prob=None):
     """
     if not rnn_type.startswith('bi'):
         cell = get_cell(rnn_type, hidden_size, layer_num, dropout_keep_prob)
-        outputs, state = tf.nn.dynamic_rnn(cell, inputs, sequence_length=length, dtype=tf.float32)
+        outputs, states = tf.nn.dynamic_rnn(cell, inputs, sequence_length=length, dtype=tf.float32)
         if rnn_type.endswith('lstm'):
-            c, h = state
-            state = h
+            c = [state.c for state in states]
+            h = [state.h for state in states]
+            states = h
     else:
         cell_fw = get_cell(rnn_type, hidden_size, layer_num, dropout_keep_prob)
         cell_bw = get_cell(rnn_type, hidden_size, layer_num, dropout_keep_prob)
-        outputs, state = tf.nn.bidirectional_dynamic_rnn(
+        outputs, states = tf.nn.bidirectional_dynamic_rnn(
             cell_bw, cell_fw, inputs, sequence_length=length, dtype=tf.float32
         )
-        state_fw, state_bw = state
+        states_fw, states_bw = states
         if rnn_type.endswith('lstm'):
-            c_fw, h_fw = state_fw
-            c_bw, h_bw = state_bw
-            state_fw, state_bw = h_fw, h_bw
+            c_fw = [state_fw.c for state_fw in states_fw]
+            h_fw = [state_fw.h for state_fw in states_fw]
+            c_bw = [state_bw.c for state_bw in states_bw]
+            h_bw = [state_bw.h for state_bw in states_bw]
+            states_fw, states_bw = h_fw, h_bw
         if concat:
             outputs = tf.concat(outputs, 2)
-            state = tf.concat([state_fw, state_bw], 1)
+            states = tf.concat([states_fw, states_bw], 1)
         else:
             outputs = outputs[0] + outputs[1]
-            state = state_fw + state_bw
-    return outputs, state
+            states = states_fw + states_bw
+    return outputs, states


 def get_cell(rnn_type, hidden_size, layer_num=1, dropout_keep_prob=None):
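The rnn() side of the fix follows from the state structure: with a MultiRNNCell, tf.nn.dynamic_rnn returns the final state as a tuple holding one entry per layer, and for LSTM each entry is an LSTMStateTuple(c, h). The old single unpack (c, h = state) therefore breaks as soon as layer_num > 1, which is what the list comprehensions above handle. A minimal sketch of that structure, under the same TF 1.x assumption and with illustrative sizes:

    import tensorflow as tf
    import tensorflow.contrib as tc

    # Fresh cell per layer, as the fixed get_cell now does.
    cells = [tc.rnn.LSTMCell(num_units=64, state_is_tuple=True) for _ in range(2)]
    stacked = tc.rnn.MultiRNNCell(cells, state_is_tuple=True)

    inputs = tf.placeholder(tf.float32, [None, 10, 128])
    outputs, states = tf.nn.dynamic_rnn(stacked, inputs, dtype=tf.float32)

    # states is a 2-tuple, one LSTMStateTuple per layer:
    #   states[0].c, states[0].h  -> layer-0 states, each [batch, 64]
    #   states[1].c, states[1].h  -> layer-1 states, each [batch, 64]
    h = [layer_state.h for layer_state in states]  # the pattern used in the diff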
@@ -74,20 +74,22 @@ def get_cell(rnn_type, hidden_size, layer_num=1, dropout_keep_prob=None):
     Returns:
         An RNN Cell
     """
-    if rnn_type.endswith('lstm'):
-        cell = tc.rnn.LSTMCell(num_units=hidden_size, state_is_tuple=True)
-    elif rnn_type.endswith('gru'):
-        cell = tc.rnn.GRUCell(num_units=hidden_size)
-    elif rnn_type.endswith('rnn'):
-        cell = tc.rnn.BasicRNNCell(num_units=hidden_size)
-    else:
-        raise NotImplementedError('Unsupported rnn type: {}'.format(rnn_type))
-    if dropout_keep_prob is not None:
-        cell = tc.rnn.DropoutWrapper(cell,
-                                     input_keep_prob=dropout_keep_prob,
-                                     output_keep_prob=dropout_keep_prob)
-    if layer_num > 1:
-        cell = tc.rnn.MultiRNNCell([cell]*layer_num, state_is_tuple=True)
-    return cell
+    cells = []
+    for i in range(layer_num):
+        if rnn_type.endswith('lstm'):
+            cell = tc.rnn.LSTMCell(num_units=hidden_size, state_is_tuple=True)
+        elif rnn_type.endswith('gru'):
+            cell = tc.rnn.GRUCell(num_units=hidden_size)
+        elif rnn_type.endswith('rnn'):
+            cell = tc.rnn.BasicRNNCell(num_units=hidden_size)
+        else:
+            raise NotImplementedError('Unsupported rnn type: {}'.format(rnn_type))
+        if dropout_keep_prob is not None:
+            cell = tc.rnn.DropoutWrapper(cell,
+                                         input_keep_prob=dropout_keep_prob,
+                                         output_keep_prob=dropout_keep_prob)
+        cells.append(cell)
+    cells = tc.rnn.MultiRNNCell(cells, state_is_tuple=True)
+    return cells
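Note that get_cell now wraps the stack in MultiRNNCell even when layer_num is 1, so callers always see the per-layer tuple state, and each layer gets its own fresh cell and DropoutWrapper. A hypothetical call site, using get_cell as defined in the diff above (the 'lstm' type and the sizes are illustrative, not taken from the diff):

    cell = get_cell('lstm', hidden_size=64, layer_num=2, dropout_keep_prob=0.8)

    inputs = tf.placeholder(tf.float32, [None, 10, 128])
    length = tf.placeholder(tf.int32, [None])
    outputs, states = tf.nn.dynamic_rnn(cell, inputs, sequence_length=length,
                                        dtype=tf.float32)
    # states is a per-layer tuple even for layer_num=1, which is why rnn() now
    # iterates over it instead of unpacking a single LSTMStateTuple.
    h = [layer_state.h for layer_state in states]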

