Make RNN api public.
Change: 124305799
ebrevdo authored and tensorflower-gardener committed Jun 8, 2016
1 parent 31bf8c1 commit bab4fc4
Showing 23 changed files with 1,800 additions and 4 deletions.
RELEASE.md (2 additions, 0 deletions):
# Changes Since Last Release

## Features & Improvements
* The RNN API is finally "official" (see, e.g., `tf.nn.dynamic_rnn`,
  `tf.nn.rnn`, and the classes in `tf.nn.rnn_cell`).
* TensorBoard now has an Audio Dashboard, with associated audio summaries.
* TensorBoard now has a reload button and supports auto-reloading.
* TensorBoard scalar charts now show tooltips with more information.
### `tf.nn.bidirectional_rnn(cell_fw, cell_bw, inputs, initial_state_fw=None, initial_state_bw=None, dtype=None, sequence_length=None, scope=None)` {#bidirectional_rnn}

Creates a bidirectional recurrent neural network.

Similar to the unidirectional case (`rnn`) above, but builds independent
forward and backward RNNs over the same inputs, with the final forward and
backward outputs depth-concatenated, so that the output has the format
[time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
the forward and backward cells must match. The initial state for both
directions is zero by default (but can be set optionally), and no
intermediate states are ever returned: the network is fully unrolled for the
given (passed in) length(s) of the sequence(s), or completely unrolled if no
length is given.

##### Args:


* <b>`cell_fw`</b>: An instance of RNNCell, to be used for forward direction.
* <b>`cell_bw`</b>: An instance of RNNCell, to be used for backward direction.
* <b>`inputs`</b>: A length T list of inputs, each a tensor of shape
[batch_size, input_size].
* <b>`initial_state_fw`</b>: (optional) An initial state for the forward RNN.
This must be a tensor of appropriate type and shape
`[batch_size x cell_fw.state_size]`.
If `cell_fw.state_size` is a tuple, this should be a tuple of
tensors having shapes `[batch_size, s] for s in cell_fw.state_size`.
* <b>`initial_state_bw`</b>: (optional) Same as for `initial_state_fw`, but using
the corresponding properties of `cell_bw`.
* <b>`dtype`</b>: (optional) The data type for the initial state. Required if
either of the initial states are not provided.
* <b>`sequence_length`</b>: (optional) An int32/int64 vector, size `[batch_size]`,
containing the actual lengths for each of the sequences.
*  <b>`scope`</b>: VariableScope for the created subgraph; defaults to "BiRNN".

##### Returns:

  A tuple (outputs, output_state_fw, output_state_bw) where:
  - outputs is a length T list of outputs (one for each input), which
    are depth-concatenated forward and backward outputs
  - output_state_fw is the final state of the forward RNN
  - output_state_bw is the final state of the backward RNN

##### Raises:


* <b>`TypeError`</b>: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
* <b>`ValueError`</b>: If inputs is None or an empty list.
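The depth-concatenation described above can be sketched in plain Python with a toy cell (a hypothetical `toy_bidirectional_rnn`, not the TensorFlow implementation; batching is omitted, so each input is a single vector):

```python
# Toy sketch: a "cell" here is just a function (input, state) -> (output, state).
def toy_cell(x, state):
    out = [xi + state[i] for i, xi in enumerate(x)]
    return out, out  # output doubles as the new state in this toy cell

def toy_bidirectional_rnn(cell_fw, cell_bw, inputs, state_size):
    zero = [0.0] * state_size
    fw_outputs, state = [], zero
    for x in inputs:                 # forward unroll over time
        out, state = cell_fw(x, state)
        fw_outputs.append(out)
    state_fw = state
    bw_outputs, state = [], zero
    for x in reversed(inputs):       # backward unroll over time
        out, state = cell_bw(x, state)
        bw_outputs.append(out)
    state_bw = state
    bw_outputs.reverse()             # re-align with forward time order
    # depth-concatenate: each output has size fw.output_size + bw.output_size
    outputs = [f + b for f, b in zip(fw_outputs, bw_outputs)]
    return outputs, state_fw, state_bw
```

With two inputs of size 2, each concatenated output has size 4, matching the [time][batch][cell_fw.output_size + cell_bw.output_size] format above.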

### `tf.nn.rnn(cell, inputs, initial_state=None, dtype=None, sequence_length=None, scope=None)` {#rnn}

Creates a recurrent neural network specified by RNNCell `cell`.

##### The simplest form of RNN network generated is:

    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)

However, a few other options are available:

An initial state can be provided.
If the sequence_length vector is provided, dynamic calculation is performed.
This method of calculation does not compute the RNN steps past the maximum
sequence length of the minibatch (thus saving computational time),
and properly propagates the state at an example's sequence length
to the final state output.

The dynamic calculation performed is, at time `t` for batch row `b`:

    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))

##### Args:


* <b>`cell`</b>: An instance of RNNCell.
* <b>`inputs`</b>: A length T list of inputs, each a tensor of shape
[batch_size, input_size].
* <b>`initial_state`</b>: (optional) An initial state for the RNN.
If `cell.state_size` is an integer, this must be
a tensor of appropriate type and shape `[batch_size x cell.state_size]`.
If `cell.state_size` is a tuple, this should be a tuple of
tensors having shapes `[batch_size, s] for s in cell.state_size`.
* <b>`dtype`</b>: (optional) The data type for the initial state. Required if
initial_state is not provided.
* <b>`sequence_length`</b>: Specifies the length of each sequence in inputs.
An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
* <b>`scope`</b>: VariableScope for the created subgraph; defaults to "RNN".

##### Returns:

A pair (outputs, state) where:
- outputs is a length T list of outputs (one for each input)
- state is the final state

##### Raises:


* <b>`TypeError`</b>: If `cell` is not an instance of RNNCell.
* <b>`ValueError`</b>: If `inputs` is `None` or an empty list, or if the input depth
(column size) cannot be inferred from inputs via shape inference.
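The dynamic calculation above can be sketched in plain Python with a toy scalar cell (hypothetical helpers, not the TensorFlow code): past an example's `sequence_length`, the output is zeroed and the state stays frozen at its last valid value.

```python
def scalar_cell(x, s):
    # Toy cell over scalars: accumulate inputs into the state.
    return x + s, x + s

def toy_dynamic_rnn(cell, inputs, sequence_length, zero_output=0.0):
    # inputs: a length-T list of per-timestep lists, one scalar per batch row.
    batch = len(inputs[0])
    state = [0.0] * batch
    final_state = [0.0] * batch
    outputs = []
    for t, x_t in enumerate(inputs):
        out_t = []
        for b in range(batch):
            if t >= sequence_length[b]:
                # Past the end of this example: emit zeros, freeze the state.
                out_t.append(zero_output)
            else:
                o, s = cell(x_t[b], state[b])
                state[b] = s
                final_state[b] = s   # state at sequence_length(b) - 1 survives
                out_t.append(o)
        outputs.append(out_t)
    return outputs, final_state
```

Batch row 0 with length 2 stops updating at t = 2 while row 1 with length 3 runs to the end, so the final states differ per row, as described above.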

Operator adding input embedding to the given cell.

Note: in many cases it may be more efficient to not use this wrapper,
but instead concatenate the whole sequence of your inputs in time,
do the embedding on this batch-concatenated sequence, then split it and
feed into your RNN.
- - -

#### `tf.nn.rnn_cell.EmbeddingWrapper.__init__(cell, embedding_classes, embedding_size, initializer=None)` {#EmbeddingWrapper.__init__}

Create a cell with an added input embedding.

##### Args:


* <b>`cell`</b>: an RNNCell, an embedding will be put before its inputs.
* <b>`embedding_classes`</b>: integer, how many symbols will be embedded.
* <b>`embedding_size`</b>: integer, the size of the vectors we embed into.
* <b>`initializer`</b>: an initializer to use when creating the embedding;
if None, the initializer from variable scope or a default one is used.

##### Raises:


* <b>`TypeError`</b>: if cell is not an RNNCell.
* <b>`ValueError`</b>: if embedding_classes is not positive.


- - -

#### `tf.nn.rnn_cell.EmbeddingWrapper.output_size` {#EmbeddingWrapper.output_size}

Integer: size of outputs produced by this cell.


- - -

#### `tf.nn.rnn_cell.EmbeddingWrapper.state_size` {#EmbeddingWrapper.state_size}




- - -

#### `tf.nn.rnn_cell.EmbeddingWrapper.zero_state(batch_size, dtype)` {#EmbeddingWrapper.zero_state}

Return zero-filled state tensor(s).

##### Args:


* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
* <b>`dtype`</b>: the data type to use for the state.

##### Returns:

If `state_size` is an int, then the return value is a `2-D` tensor of
shape `[batch_size x state_size]` filled with zeros.

If `state_size` is a nested list or tuple, then the return value is
a nested list or tuple (of the same structure) of `2-D` tensors with
the shapes `[batch_size x s]` for each s in `state_size`.
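Conceptually, the wrapper looks up each integer symbol id in an embedding table and feeds the resulting vector to the inner cell. A plain-Python sketch with a fixed table and a toy cell (hypothetical names; the real wrapper creates a trainable embedding variable):

```python
def sum_cell(x, state):
    # Toy cell: running elementwise sum of its input vectors.
    new = [a + b for a, b in zip(x, state)]
    return new, new

def make_embedding_cell(cell, embedding):
    # embedding: a list of embedding_classes vectors, each of embedding_size.
    def wrapped(symbol_id, state):
        return cell(embedding[symbol_id], state)  # look up, then step the cell
    return wrapped
```

As the note above says, embedding the whole time-concatenated input batch at once is often more efficient than doing one lookup per step inside the cell.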


Operator adding an output projection to the given cell.

Note: in many cases it may be more efficient to not use this wrapper,
but instead concatenate the whole sequence of your outputs in time,
do the projection on this batch-concatenated sequence, then split it
if needed or directly feed into a softmax.
- - -

#### `tf.nn.rnn_cell.OutputProjectionWrapper.__init__(cell, output_size)` {#OutputProjectionWrapper.__init__}

Create a cell with output projection.

##### Args:


* <b>`cell`</b>: an RNNCell, a projection to output_size is added to it.
* <b>`output_size`</b>: integer, the size of the output after projection.

##### Raises:


* <b>`TypeError`</b>: if cell is not an RNNCell.
* <b>`ValueError`</b>: if output_size is not positive.


- - -

#### `tf.nn.rnn_cell.OutputProjectionWrapper.output_size` {#OutputProjectionWrapper.output_size}




- - -

#### `tf.nn.rnn_cell.OutputProjectionWrapper.state_size` {#OutputProjectionWrapper.state_size}




- - -

#### `tf.nn.rnn_cell.OutputProjectionWrapper.zero_state(batch_size, dtype)` {#OutputProjectionWrapper.zero_state}

Return zero-filled state tensor(s).

##### Args:


* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
* <b>`dtype`</b>: the data type to use for the state.

##### Returns:

If `state_size` is an int, then the return value is a `2-D` tensor of
shape `[batch_size x state_size]` filled with zeros.

If `state_size` is a nested list or tuple, then the return value is
a nested list or tuple (of the same structure) of `2-D` tensors with
the shapes `[batch_size x s]` for each s in `state_size`.
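Conceptually, the wrapper applies a linear map `W * output + b` to the inner cell's output at every step. A plain-Python sketch with fixed weights and a toy identity cell (hypothetical names; the real wrapper creates trainable W and b):

```python
def matvec(W, v):
    # Multiply matrix W (list of rows) by vector v.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def make_projection_cell(cell, W, b):
    def wrapped(x, state):
        out, new_state = cell(x, state)
        # Project the output to len(b) units; the state passes through untouched.
        projected = [o + bi for o, bi in zip(matvec(W, out), b)]
        return projected, new_state
    return wrapped
```

As with EmbeddingWrapper, the note above applies: projecting the whole time-concatenated output batch once is often cheaper than projecting per step.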


The most basic RNN cell.
- - -

#### `tf.nn.rnn_cell.BasicRNNCell.__init__(num_units, input_size=None, activation=tanh)` {#BasicRNNCell.__init__}




- - -

#### `tf.nn.rnn_cell.BasicRNNCell.output_size` {#BasicRNNCell.output_size}




- - -

#### `tf.nn.rnn_cell.BasicRNNCell.state_size` {#BasicRNNCell.state_size}




- - -

#### `tf.nn.rnn_cell.BasicRNNCell.zero_state(batch_size, dtype)` {#BasicRNNCell.zero_state}

Return zero-filled state tensor(s).

##### Args:


* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
* <b>`dtype`</b>: the data type to use for the state.

##### Returns:

If `state_size` is an int, then the return value is a `2-D` tensor of
shape `[batch_size x state_size]` filled with zeros.

If `state_size` is a nested list or tuple, then the return value is
a nested list or tuple (of the same structure) of `2-D` tensors with
the shapes `[batch_size x s]` for each s in `state_size`.
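The `zero_state` contract described above (int vs. tuple `state_size`) can be sketched with a hypothetical plain-Python helper:

```python
def toy_zero_state(batch_size, state_size):
    # int state_size -> one [batch_size x state_size] block of zeros;
    # tuple state_size -> a tuple of such blocks, recursing into nesting.
    if isinstance(state_size, int):
        return [[0.0] * state_size for _ in range(batch_size)]
    return tuple(toy_zero_state(batch_size, s) for s in state_size)
```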


Operator adding dropout to inputs and outputs of the given cell.
- - -

#### `tf.nn.rnn_cell.DropoutWrapper.__init__(cell, input_keep_prob=1.0, output_keep_prob=1.0, seed=None)` {#DropoutWrapper.__init__}

Create a cell with added input and/or output dropout.

Dropout is never used on the state.

##### Args:


*  <b>`cell`</b>: an RNNCell, dropout is added to its inputs and/or outputs.
* <b>`input_keep_prob`</b>: unit Tensor or float between 0 and 1, input keep
probability; if it is float and 1, no input dropout will be added.
* <b>`output_keep_prob`</b>: unit Tensor or float between 0 and 1, output keep
probability; if it is float and 1, no output dropout will be added.
* <b>`seed`</b>: (optional) integer, the randomness seed.

##### Raises:


* <b>`TypeError`</b>: if cell is not an RNNCell.
* <b>`ValueError`</b>: if keep_prob is not between 0 and 1.


- - -

#### `tf.nn.rnn_cell.DropoutWrapper.output_size` {#DropoutWrapper.output_size}




- - -

#### `tf.nn.rnn_cell.DropoutWrapper.state_size` {#DropoutWrapper.state_size}




- - -

#### `tf.nn.rnn_cell.DropoutWrapper.zero_state(batch_size, dtype)` {#DropoutWrapper.zero_state}

Return zero-filled state tensor(s).

##### Args:


* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
* <b>`dtype`</b>: the data type to use for the state.

##### Returns:

If `state_size` is an int, then the return value is a `2-D` tensor of
shape `[batch_size x state_size]` filled with zeros.

If `state_size` is a nested list or tuple, then the return value is
a nested list or tuple (of the same structure) of `2-D` tensors with
the shapes `[batch_size x s]` for each s in `state_size`.
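The keep-probability semantics above can be sketched in plain Python (hypothetical helpers; TensorFlow itself implements this with `tf.nn.dropout`): kept units are scaled by `1 / keep_prob`, a float `keep_prob` of exactly 1 disables dropout entirely, and the state is never touched.

```python
import random

def dropout(vec, keep_prob, rng):
    if keep_prob == 1.0:
        return vec  # float 1.0 means "no dropout at all"
    # Keep each unit with probability keep_prob, scaling survivors by 1/keep_prob.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in vec]

def make_dropout_cell(cell, input_keep_prob=1.0, output_keep_prob=1.0, seed=None):
    rng = random.Random(seed)
    def wrapped(x, state):
        out, new_state = cell(dropout(x, input_keep_prob, rng), state)
        # Dropout touches inputs and outputs only; the state passes through.
        return dropout(out, output_keep_prob, rng), new_state
    return wrapped
```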


Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
- - -

#### `tf.nn.rnn_cell.GRUCell.__init__(num_units, input_size=None, activation=tanh)` {#GRUCell.__init__}




- - -

#### `tf.nn.rnn_cell.GRUCell.output_size` {#GRUCell.output_size}




- - -

#### `tf.nn.rnn_cell.GRUCell.state_size` {#GRUCell.state_size}




- - -

#### `tf.nn.rnn_cell.GRUCell.zero_state(batch_size, dtype)` {#GRUCell.zero_state}

Return zero-filled state tensor(s).

##### Args:


* <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
* <b>`dtype`</b>: the data type to use for the state.

##### Returns:

If `state_size` is an int, then the return value is a `2-D` tensor of
shape `[batch_size x state_size]` filled with zeros.

If `state_size` is a nested list or tuple, then the return value is
a nested list or tuple (of the same structure) of `2-D` tensors with
the shapes `[batch_size x s]` for each s in `state_size`.
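As a rough sketch of the GRU update from the paper linked above (scalar case with hand-picked, tied weights; the real cell learns separate weight matrices for the reset, update, and candidate gates): the new state is a `u`-weighted mix of the old state and a `tanh` candidate computed from the reset-gated state.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def toy_gru_step(x, h, w):
    # w holds one scalar weight per gate: w = dict(r=..., u=..., c=...).
    r = sigmoid(w["r"] * x + w["r"] * h)          # reset gate
    u = sigmoid(w["u"] * x + w["u"] * h)          # update gate
    c = math.tanh(w["c"] * x + w["c"] * (r * h))  # candidate state
    return u * h + (1.0 - u) * c                  # mix old state and candidate
```

Because the candidate is bounded by `tanh` and the update gate interpolates, the new state stays in (-1, 1) whenever the old state does.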

