diff --git a/tflearn/layers/recurrent.py b/tflearn/layers/recurrent.py index 9a796647..0613c164 100644 --- a/tflearn/layers/recurrent.py +++ b/tflearn/layers/recurrent.py @@ -596,9 +596,10 @@ def __call__(self, inputs, state, scope=None): with tf.variable_scope(scope or type(self).__name__): # "GRUCell" with tf.variable_scope("Gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. - r, u = array_ops.split(1, 2, _linear([inputs, state], + r, u = array_ops.split(value=_linear([inputs, state], 2 * self._num_units, True, 1.0, self.weights_init, - self.trainable, self.restore, self.reuse)) + self.trainable, self.restore, self.reuse), + num_or_size_splits=2, axis=1) r, u = self._inner_activation(r), self._inner_activation(u) with tf.variable_scope("Candidate"): c = self._activation(