Dynamic RNN support
- Refactor all RNN models
- Remove dynamic_rnn (dynamic computation is now directly available as an option inside any RNN layer)
- Custom RNN cells available (with deeper customization)
aymericdamien committed Jul 1, 2016
1 parent 48d743c commit 4acd614
Showing 11 changed files with 471 additions and 1,090 deletions.
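In short: dynamic computation over variable-length sequences moves from a standalone dynamic_rnn layer to a dynamic flag on every RNN layer, and the underlying cell classes are re-exported for customization. A minimal sketch of the resulting API, assembled from the example files changed below (how a custom cell plugs into a layer is not shown in this diff):

import tflearn

# Build a sequence classifier as in examples/nlp/dynamic_lstm.py below.
net = tflearn.input_data([None, 100])
net = tflearn.embedding(net, input_dim=10000, output_dim=128)
# dynamic=True makes the layer step only through each sequence's real
# (non-padded) timesteps, instead of requiring a separate dynamic_rnn layer.
net = tflearn.lstm(net, 128, dropout=0.8, dynamic=True)

# Cells re-exported by this commit, available for deeper customization:
from tflearn.layers.recurrent import BasicRNNCell, BasicLSTMCell, GRUCell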
17 changes: 9 additions & 8 deletions docs/templates/examples.md
@@ -8,12 +8,12 @@
 - [Using HDF5](https://github.com/tflearn/tflearn/blob/master/examples/basics/use_hdf5.py). Use HDF5 to handle large datasets.
 - [Using DASK](https://github.com/tflearn/tflearn/blob/master/examples/basics/use_dask.py). Use DASK to handle large datasets.
 
-## Extending Tensorflow
-- [Layers](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/layers.py). Use TFLearn layers along with Tensorflow.
-- [Trainer](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/trainer.py). Use TFLearn trainer class to train any Tensorflow graph.
-- [Built-in Ops](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/builtin_ops.py). Use TFLearn built-in operations along with Tensorflow.
-- [Summaries](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/summaries.py). Use TFLearn summarizers along with Tensorflow.
-- [Variables](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/variables.py). Use TFLearn variables along with Tensorflow.
+## Extending TensorFlow
+- [Layers](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/layers.py). Use TFLearn layers along with TensorFlow.
+- [Trainer](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/trainer.py). Use TFLearn trainer class to train any TensorFlow graph.
+- [Built-in Ops](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/builtin_ops.py). Use TFLearn built-in operations along with TensorFlow.
+- [Summaries](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/summaries.py). Use TFLearn summarizers along with TensorFlow.
+- [Variables](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/variables.py). Use TFLearn variables along with TensorFlow.
 
 ## Computer Vision
 - [Multi-layer perceptron](https://github.com/tflearn/tflearn/blob/master/examples/images/dnn.py). A multi-layer perceptron implementation for MNIST classification task.
@@ -30,8 +30,9 @@
 - [Auto Encoder](https://github.com/tflearn/tflearn/blob/master/examples/images/autoencoder.py). An auto encoder applied to MNIST handwritten digits.
 
 ## Natural Language Processing
-- [Reccurent Network (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm.py). Apply an LSTM to IMDB sentiment dataset classification task.
-- [Bi-Directional LSTM](https://github.com/tflearn/tflearn/blob/master/examples/nlp/bidirectional_lstm.py). Apply a bi-directional LSTM to IMDB sentiment dataset classification task.
+- [Recurrent Neural Network (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm.py). Apply an LSTM to IMDB sentiment dataset classification task.
+- [Bi-Directional RNN (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/bidirectional_lstm.py). Apply a bi-directional LSTM to IMDB sentiment dataset classification task.
+- [Dynamic RNN (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/bidirectional_lstm.py). Apply a dynamic LSTM to classify variable length text from IMDB dataset.
 - [City Name Generation](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm_generator_cityname.py). Generates new US-cities name, using LSTM network.
 - [Shakespeare Scripts Generation](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm_generator_shakespeare.py). Generates new Shakespeare scripts, using LSTM network.

17 changes: 9 additions & 8 deletions examples/README.md
@@ -8,12 +8,12 @@
 - [Using HDF5](https://github.com/tflearn/tflearn/blob/master/examples/basics/use_hdf5.py). Use HDF5 to handle large datasets.
 - [Using DASK](https://github.com/tflearn/tflearn/blob/master/examples/basics/use_dask.py). Use DASK to handle large datasets.
 
-## Extending Tensorflow
-- [Layers](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/layers.py). Use TFLearn layers along with Tensorflow.
-- [Trainer](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/trainer.py). Use TFLearn trainer class to train any Tensorflow graph.
-- [Built-in Ops](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/builtin_ops.py). Use TFLearn built-in operations along with Tensorflow.
-- [Summaries](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/summaries.py). Use TFLearn summarizers along with Tensorflow.
-- [Variables](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/variables.py). Use TFLearn variables along with Tensorflow.
+## Extending TensorFlow
+- [Layers](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/layers.py). Use TFLearn layers along with TensorFlow.
+- [Trainer](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/trainer.py). Use TFLearn trainer class to train any TensorFlow graph.
+- [Built-in Ops](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/builtin_ops.py). Use TFLearn built-in operations along with TensorFlow.
+- [Summaries](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/summaries.py). Use TFLearn summarizers along with TensorFlow.
+- [Variables](https://github.com/tflearn/tflearn/blob/master/examples/extending_tensorflow/variables.py). Use TFLearn variables along with TensorFlow.
 
 ## Computer Vision
 - [Multi-layer perceptron](https://github.com/tflearn/tflearn/blob/master/examples/images/dnn.py). A multi-layer perceptron implementation for MNIST classification task.
@@ -30,8 +30,9 @@
 - [Auto Encoder](https://github.com/tflearn/tflearn/blob/master/examples/images/autoencoder.py). An auto encoder applied to MNIST handwritten digits.
 
 ## Natural Language Processing
-- [Reccurent Network (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm.py). Apply an LSTM to IMDB sentiment dataset classification task.
-- [Bi-Directional LSTM](https://github.com/tflearn/tflearn/blob/master/examples/nlp/bidirectional_lstm.py). Apply a bi-directional LSTM to IMDB sentiment dataset classification task.
+- [Recurrent Neural Network (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm.py). Apply an LSTM to IMDB sentiment dataset classification task.
+- [Bi-Directional RNN (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/bidirectional_lstm.py). Apply a bi-directional LSTM to IMDB sentiment dataset classification task.
+- [Dynamic RNN (LSTM)](https://github.com/tflearn/tflearn/blob/master/examples/nlp/bidirectional_lstm.py). Apply a dynamic LSTM to classify variable length text from IMDB dataset.
 - [City Name Generation](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm_generator_cityname.py). Generates new US-cities name, using LSTM network.
 - [Shakespeare Scripts Generation](https://github.com/tflearn/tflearn/blob/master/examples/nlp/lstm_generator_shakespeare.py). Generates new Shakespeare scripts, using LSTM network.

6 changes: 2 additions & 4 deletions examples/nlp/bidirectional_lstm.py
@@ -28,10 +28,9 @@
 from tflearn.layers.estimator import regression
 
 # IMDB Dataset loading
-train, val, test = imdb.load_data(path='imdb.pkl', maxlen=200,
-                                  n_words=20000)
+train, test, _ = imdb.load_data(path='imdb.pkl', n_words=10000,
+                                valid_portion=0.1)
 trainX, trainY = train
-valX, valY = val
 testX, testY = test
 
 # Data preprocessing
@@ -40,7 +39,6 @@
 testX = pad_sequences(testX, maxlen=200, value=0.)
 # Converting labels to binary vectors
 trainY = to_categorical(trainY, nb_classes=2)
-valY = to_categorical(valY, nb_classes=2)
 testY = to_categorical(testY, nb_classes=2)
 
 # Network building
56 changes: 56 additions & 0 deletions examples/nlp/dynamic_lstm.py
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
"""
Simple example using a Dynamic RNN (LSTM) to classify IMDB sentiment dataset.
Dynamic computation are performed over sequences with variable length.
References:
- Long Short Term Memory, Sepp Hochreiter & Jurgen Schmidhuber, Neural
Computation 9(8): 1735-1780, 1997.
- Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng,
and Christopher Potts. (2011). Learning Word Vectors for Sentiment
Analysis. The 49th Annual Meeting of the Association for Computational
Linguistics (ACL 2011).
Links:
- http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
- http://ai.stanford.edu/~amaas/data/sentiment/
"""
from __future__ import division, print_function, absolute_import

import tflearn
from tflearn.data_utils import to_categorical, pad_sequences
from tflearn.datasets import imdb

# IMDB Dataset loading
train, test, _ = imdb.load_data(path='imdb.pkl', n_words=10000,
                                valid_portion=0.1)
trainX, trainY = train
testX, testY = test

# Data preprocessing
# NOTE: Padding is required for dimension consistency. This will pad sequences
# with 0 at the end, until it reaches the max sequence length. 0 is used as a
# masking value by dynamic RNNs in TFLearn; a sequence length will be
# retrieved by counting non zero elements in a sequence. Then dynamic RNN step
# computation is performed according to that length.
trainX = pad_sequences(trainX, maxlen=100, value=0.)
testX = pad_sequences(testX, maxlen=100, value=0.)
# Converting labels to binary vectors
trainY = to_categorical(trainY, nb_classes=2)
testY = to_categorical(testY, nb_classes=2)

# Network building
net = tflearn.input_data([None, 100])
# Masking is not required for embedding, sequence length is computed prior to
# the embedding op and assigned as 'seq_length' attribute to the returned Tensor.
net = tflearn.embedding(net, input_dim=10000, output_dim=128)
net = tflearn.lstm(net, 128, dropout=0.8, dynamic=True)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                         loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, tensorboard_verbose=0)
model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True,
          batch_size=32)
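The NOTE above says a sequence's length is recovered by counting its non-zero elements. A minimal sketch of such an op, assuming plain TensorFlow of the same era; the actual retrieve_seq_length_op implementation (used by embedding_ops.py below) is not part of this diff:

import tensorflow as tf

def seq_length_from_padding(data):
    # data: 3-D Tensor [batch, max_steps, features], zero-padded at the end.
    # A timestep counts as used if any of its features is non-zero.
    used = tf.sign(tf.reduce_max(tf.abs(data), reduction_indices=2))
    # Summing the resulting 0/1 mask over timesteps gives the true length.
    return tf.cast(tf.reduce_sum(used, reduction_indices=1), tf.int32)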
19 changes: 9 additions & 10 deletions examples/nlp/lstm.py
@@ -23,29 +23,28 @@
 from tflearn.datasets import imdb
 
 # IMDB Dataset loading
-train, val, test = imdb.load_data(path='imdb.pkl', maxlen=200,
-                                  n_words=20000)
+train, test, _ = imdb.load_data(path='imdb.pkl', n_words=10000,
+                                valid_portion=0.1)
 trainX, trainY = train
 testX, testY = test
 
 # Data preprocessing
 # Sequence padding
-trainX = pad_sequences(trainX, maxlen=200, value=0.)
-testX = pad_sequences(testX, maxlen=200, value=0.)
+trainX = pad_sequences(trainX, maxlen=100, value=0.)
+testX = pad_sequences(testX, maxlen=100, value=0.)
 # Converting labels to binary vectors
 trainY = to_categorical(trainY, nb_classes=2)
 testY = to_categorical(testY, nb_classes=2)
 
 # Network building
-net = tflearn.input_data([None, 200])
-net = tflearn.embedding(net, input_dim=20000, output_dim=128)
+net = tflearn.input_data([None, 100])
+net = tflearn.embedding(net, input_dim=10000, output_dim=128)
 net = tflearn.lstm(net, 128, dropout=0.8)
-net = tflearn.dropout(net, 0.5)
 net = tflearn.fully_connected(net, 2, activation='softmax')
-net = tflearn.regression(net, optimizer='adam',
+net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                          loss='categorical_crossentropy')
 
 # Training
-model = tflearn.DNN(net, clip_gradients=0., tensorboard_verbose=0)
+model = tflearn.DNN(net, tensorboard_verbose=0)
 model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True,
-          batch_size=128)
+          batch_size=32)
3 changes: 2 additions & 1 deletion tflearn/__init__.py
@@ -50,7 +50,8 @@
     flatten, activation, fully_connected, single_unit, highway, one_hot_encoding
 from .layers.normalization import batch_normalization, local_response_normalization
 from .layers.estimator import regression
-from .layers.recurrent import lstm, gru, simple_rnn, bidirectional_rnn, dynamic_rnn, RNNCell, BasicLSTMCell, GRUCell, BasicRNNCell, DropoutWrapper
+from .layers.recurrent import lstm, gru, simple_rnn, bidirectional_rnn, \
+    BasicRNNCell, BasicLSTMCell, GRUCell
 from .layers.embedding_ops import embedding
 from .layers.merge_ops import merge, merge_outputs
4 changes: 2 additions & 2 deletions tflearn/data_utils.py
@@ -48,8 +48,8 @@ def to_categorical(y, nb_classes):
 # =====================
 
 
-def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre',
-                  truncating='pre', value=0.):
+def pad_sequences(sequences, maxlen=None, dtype='int32', padding='post',
+                  truncating='post', value=0.):
     """ pad_sequences.
 
     Pad each sequence to the same length: the length of the longest sequence.
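The flipped defaults (pre to post) line up with the dynamic RNN support above: dynamic_lstm.py pads with 0 at the end, so real tokens must come before the padding for step computation to stop at the right place. A quick illustration with hypothetical toy sequences:

from tflearn.data_utils import pad_sequences

seqs = [[3, 7, 2], [5, 1]]
padded = pad_sequences(seqs, maxlen=4, value=0.)
# padding='post' (new default): [[3, 7, 2, 0], [5, 1, 0, 0]]
# padding='pre' (old default):  [[0, 3, 7, 2], [0, 0, 5, 1]]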
3 changes: 2 additions & 1 deletion tflearn/layers/__init__.py
@@ -6,6 +6,7 @@
     activation, fully_connected, single_unit
 from .normalization import batch_normalization, local_response_normalization
 from .estimator import regression
-from .recurrent import lstm, gru, simple_rnn, bidirectional_rnn, dynamic_rnn, RNNCell, BasicLSTMCell, GRUCell, BasicRNNCell
+from .recurrent import lstm, gru, simple_rnn, bidirectional_rnn, \
+    BasicRNNCell, BasicLSTMCell, GRUCell
 from .embedding_ops import embedding
 from .merge_ops import merge, merge_outputs
29 changes: 18 additions & 11 deletions tflearn/layers/embedding_ops.py
@@ -4,14 +4,15 @@
 import numpy as np
 import tensorflow as tf
 
+from .recurrent import retrieve_seq_length_op
 from .. import variables as vs
 from .. import utils
 from .. import initializations
 
 
-def embedding(incoming, input_dim, output_dim, weights_init='truncated_normal',
-              trainable=True, restore=True, reuse=False, scope=None,
-              name="Embedding"):
+def embedding(incoming, input_dim, output_dim, validate_indices=False,
+              weights_init='truncated_normal', trainable=True, restore=True,
+              reuse=False, scope=None, name="Embedding"):
     """ Embedding.
 
     Embedding layer for a sequence of ids.
@@ -26,18 +27,23 @@ def embedding(incoming, input_dim, output_dim, weights_init='truncated_normal',
         incoming: Incoming 2-D Tensor.
         input_dim: list of `int`. Vocabulary size (number of ids).
         output_dim: list of `int`. Embedding size.
+        validate_indices: `bool`. Whether or not to validate gather indices.
         weights_init: `str` (name) or `Tensor`. Weights initialization.
             (see tflearn.initializations) Default: 'truncated_normal'.
         trainable: `bool`. If True, weights will be trainable.
         restore: `bool`. If True, this layer weights will be restored when
            loading a model
+        reuse: `bool`. If True and 'scope' is provided, this layer variables
+            will be reused (shared).
+        scope: `str`. Define this layer scope (optional). A scope can be
+            used to share varibales between layers. Note that scope will
+            override name.
         name: A name for this layer (optional). Default: 'Embedding'.
     """
 
     input_shape = utils.get_incoming_shape(incoming)
     assert len(input_shape) == 2, "Incoming Tensor shape must be 2-D"
-    n_inputs = int(np.prod(input_shape[1:]))
 
     W_init = weights_init
     if isinstance(weights_init, str):
@@ -52,13 +58,14 @@
     tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, W)
 
     inference = tf.cast(incoming, tf.int32)
-    inference = tf.nn.embedding_lookup(W, inference)
-    inference = tf.transpose(inference, [1, 0, 2])
-    inference = tf.reshape(inference, shape=[-1, output_dim])
-    inference = tf.split(0, n_inputs, inference)
+    inference = tf.nn.embedding_lookup(W, inference,
+                                       validate_indices=validate_indices)
 
-    # TODO: easy access those var
-    # inference.W = W
-    # inference.scope = scope
+    inference.W = W
+    inference.scope = scope
+    # Embedding doesn't support masking, so we save sequence length prior
+    # to the lookup. Expand dim to 3d.
+    shape = [-1] + inference.get_shape().as_list()[1:3] + [1]
+    inference.seq_length = retrieve_seq_length_op(tf.reshape(incoming, shape))
 
     return inference
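A sketch of how the saved seq_length attribute could be consumed downstream, assuming TensorFlow's stock dynamic_rnn; the RNN-layer side of this wiring lives in tflearn/layers/recurrent.py, which is not shown in this diff:

import tensorflow as tf
import tflearn

net = tflearn.input_data([None, 100])
net = tflearn.embedding(net, input_dim=10000, output_dim=128)

# net.seq_length holds the per-example lengths computed around the lookup;
# passing it as sequence_length stops the RNN at each sequence's true end.
cell = tf.nn.rnn_cell.BasicLSTMCell(128)
outputs, state = tf.nn.dynamic_rnn(cell, net, dtype=tf.float32,
                                   sequence_length=net.seq_length)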