diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index e03199e5293f..88cb966e8cdd 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -736,14 +736,16 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N begin_state = self.begin_state() p = 0 + next_states = [] for i, cell in enumerate(self._cells): n = len(cell.state_shape) states = begin_state[p:p+n] p += n inputs, states = cell.unroll(length, inputs=inputs, begin_state=states, layout=layout, merge_outputs=None if i < num_cells-1 else merge_outputs) + next_states.extend(states) - return inputs, states + return inputs, next_states class DropoutCell(BaseRNNCell):