Skip to content

Commit

Permalink
RNN docstring fixes; pass maximum_iteration argument to `while_loop…
Browse files Browse the repository at this point in the history
…` in TF.
  • Loading branch information
fchollet committed May 7, 2018
1 parent c8728e4 commit db57395
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 46 deletions.
55 changes: 27 additions & 28 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,48 +2679,46 @@ def rnn(step_function, inputs, initial_states,
"""Iterates over the time dimension of a tensor.
# Arguments
step_function: RNN step function.
step_function:
Parameters:
inputs: tensor with shape `(samples, ...)` (no time dimension),
inputs: Tensor with shape (samples, ...) (no time dimension),
representing input for the batch of samples at a certain
time step.
states: list of tensors.
states: List of tensors.
Returns:
outputs: tensor with shape `(samples, output_dim)`
(no time dimension).
new_states: list of tensors, same length and shapes
as 'states'. The first state in the list must be the
output tensor at the previous timestep.
inputs: tensor of temporal data of shape `(samples, time, ...)`
outputs: Tensor with shape (samples, ...) (no time dimension),
new_states: Tist of tensors, same length and shapes
as 'states'.
inputs: Tensor of temporal data of shape (samples, time, ...)
(at least 3D).
initial_states: tensor with shape (samples, output_dim)
(no time dimension),
initial_states: Tensor with shape (samples, ...) (no time dimension),
containing the initial values for the states used in
the step function.
go_backwards: boolean. If True, do the iteration over the time
go_backwards: Boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: binary tensor with shape `(samples, time, 1)`,
mask: Binary tensor with shape (samples, time),
with a zero for every element that is masked.
constants: a list of constant values passed at each step.
unroll: whether to unroll the RNN or to use a symbolic loop (`while_loop` or `scan` depending on backend).
input_length: not relevant in the TensorFlow implementation.
Must be specified if using unrolling with Theano.
constants: A list of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic loop
(`while_loop` or `scan` depending on backend).
input_length: Static number of timesteps in the input.
# Returns
A tuple, `(last_output, outputs, new_states)`.
last_output: the latest output of the rnn, of shape `(samples, ...)`
outputs: tensor with shape `(samples, time, ...)` where each
entry `outputs[s, t]` is the output of the step function
at time `t` for sample `s`.
new_states: list of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
last_output: The latest output of the rnn, of shape `(samples, ...)`
outputs: Tensor with shape `(samples, time, ...)` where each
entry `outputs[s, t]` is the output of the step function
at time `t` for sample `s`.
new_states: List of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
# Raises
ValueError: if input dimension is less than 3.
ValueError: if `unroll` is `True` but input timestep is not a fixed number.
ValueError: if `mask` is provided (not `None`) but states is not provided
(`len(states)` == 0).
ValueError: If input dimension is less than 3.
ValueError: If `unroll` is `True`
but input timestep is not a fixed number.
ValueError: If `mask` is provided (not `None`)
but states is not provided (`len(states)` == 0).
"""
ndim = len(inputs.get_shape())
if ndim < 3:
Expand Down Expand Up @@ -2906,7 +2904,8 @@ def _step(time, output_ta_t, *states):
body=_step,
loop_vars=(time, output_ta) + states,
parallel_iterations=32,
swap_memory=True)
swap_memory=True,
maximum_iterations=input_length)
last_time = final_outputs[0]
output_ta = final_outputs[1]
new_states = final_outputs[2:]
Expand Down
39 changes: 21 additions & 18 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,37 +1287,40 @@ def rnn(step_function, inputs, initial_states,
"""Iterates over the time dimension of a tensor.
# Arguments
inputs: tensor of temporal data of shape (samples, time, ...)
(at least 3D).
step_function:
Parameters:
inputs: tensor with shape (samples, ...) (no time dimension),
inputs: Tensor with shape (samples, ...) (no time dimension),
representing input for the batch of samples at a certain
time step.
states: list of tensors.
states: List of tensors.
Returns:
outputs: tensor with shape (samples, ...) (no time dimension),
new_states: list of tensors, same length and shapes
outputs: Tensor with shape (samples, ...) (no time dimension),
new_states: Tist of tensors, same length and shapes
as 'states'.
initial_states: tensor with shape (samples, ...) (no time dimension),
inputs: Tensor of temporal data of shape (samples, time, ...)
(at least 3D).
initial_states: Tensor with shape (samples, ...) (no time dimension),
containing the initial values for the states used in
the step function.
go_backwards: boolean. If True, do the iteration over the time
go_backwards: Boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
mask: binary tensor with shape (samples, time),
mask: Binary tensor with shape (samples, time),
with a zero for every element that is masked.
constants: a list of constant values passed at each step.
unroll: whether to unroll the RNN or to use a symbolic loop (`while_loop` or `scan` depending on backend).
input_length: must be specified if using `unroll`.
constants: A list of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic loop
(`while_loop` or `scan` depending on backend).
input_length: Static number of timesteps in the input.
Must be specified if using `unroll`.
# Returns
A tuple (last_output, outputs, new_states).
last_output: the latest output of the rnn, of shape (samples, ...)
outputs: tensor with shape (samples, time, ...) where each
entry outputs[s, t] is the output of the step function
at time t for sample s.
new_states: list of tensors, latest states returned by
the step function, of shape (samples, ...).
last_output: The latest output of the rnn, of shape `(samples, ...)`
outputs: Tensor with shape `(samples, time, ...)` where each
entry `outputs[s, t]` is the output of the step function
at time `t` for sample `s`.
new_states: List of tensors, latest states returned by
the step function, of shape `(samples, ...)`.
"""
ndim = inputs.ndim
assert ndim >= 3, 'Input should be at least 3D.'
Expand Down

0 comments on commit db57395

Please sign in to comment.