Skip to content

Commit

Permalink
Update rnn_cell.py
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored Apr 4, 2017
1 parent 911f747 commit 5c69214
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/mxnet/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,14 +736,16 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
begin_state = self.begin_state()

p = 0
next_states = []
for i, cell in enumerate(self._cells):
n = len(cell.state_shape)
states = begin_state[p:p+n]
p += n
inputs, states = cell.unroll(length, inputs=inputs, begin_state=states, layout=layout,
merge_outputs=None if i < num_cells-1 else merge_outputs)
next_states.extend(states)

return inputs, states
return inputs, next_states


class DropoutCell(BaseRNNCell):
Expand Down

0 comments on commit 5c69214

Please sign in to comment.