Skip to content

Commit

Permalink
Merge pull request tensorflow#960 from lukaszkaiser/ngpu-corrections
Browse files Browse the repository at this point in the history
Corrections and explanations for the updated Neural GPU model.
  • Loading branch information
lukaszkaiser authored Jan 27, 2017
2 parents a298143 + a046ec8 commit a00389b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
16 changes: 14 additions & 2 deletions neural_gpu/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# NeuralGPU
Code for the Neural GPU model as described
in [[http://arxiv.org/abs/1511.08228]].
Code for the Neural GPU model described in [[http://arxiv.org/abs/1511.08228]].
The extended version was described in [[https://arxiv.org/abs/1610.08613]].

Requirements:
* TensorFlow (see tensorflow.org for how to install)
Expand Down Expand Up @@ -68,4 +68,16 @@ To interact with a model (experimental, see code) run:
python neural_gpu_trainer.py --problem=bmul --mode=2
```

To train on WMT data, set a larger --nmaps and --vocab_size and avoid curriculum:

```
python neural_gpu_trainer.py --problem=wmt --vocab_size=32768 --nmaps=256
--vec_size=256 --curriculum_seq=1.0 --max_length=60 --data_dir ~/wmt
```

With less memory, try lower batch size, e.g. `--batch_size=4`. With more GPUs
in your system, there will be a batch on every GPU so you can run larger models.
For example, `--batch_size=4 --num_gpus=4 --nmaps=512 --vec_size=512` will
run a large model (512-size) on 4 GPUs, with effective batches of 4*4=16.

Maintained by Lukasz Kaiser (lukaszkaiser)
4 changes: 2 additions & 2 deletions neural_gpu/neural_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def conv_lin(args, suffix, bias_start):
reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
if cutoff > 10:
candidate = tf.tanh_hard(conv_lin(inpts + [reset * mem], "c", 0.0))
candidate = tanh_hard(conv_lin(inpts + [reset * mem], "c", 0.0))
else:
# candidate = tanh_cutoff(conv_lin(inpts + [reset * mem], "c", 0.0), cutoff)
candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
Expand Down Expand Up @@ -273,7 +273,7 @@ def __init__(self, nmaps, vec_size, niclass, noclass, dropout,

if backward:
adam_lr = 0.005 * self.lr
adam = tf.train.AdamOptimizer(adam_lr, epsilon=2e-4)
adam = tf.train.AdamOptimizer(adam_lr, epsilon=1e-3)

def adam_update(grads):
return adam.apply_gradients(zip(grads, tf.trainable_variables()),
Expand Down
6 changes: 4 additions & 2 deletions neural_gpu/neural_gpu_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("curriculum_ppx", 9.9, "Move curriculum if ppl < X.")
tf.app.flags.DEFINE_float("curriculum_seq", 0.3, "Move curriculum if seq < X.")
tf.app.flags.DEFINE_float("dropout", 0.0, "Dropout that much.")
tf.app.flags.DEFINE_float("dropout", 0.1, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 0.0, "Gradient noise scale.")
tf.app.flags.DEFINE_float("max_sampling_rate", 0.1, "Maximal sampling rate.")
tf.app.flags.DEFINE_float("length_norm", 0.0, "Length normalization.")
Expand Down Expand Up @@ -263,7 +263,8 @@ def initialize(sess=None):
data.rev_vocab = rev_fr_vocab
data.print_out("Reading development and training data (limit: %d)."
% FLAGS.max_train_data_size)
dev_set = read_data(en_dev, fr_dev, data.bins)
dev_set = {}
dev_set["wmt"] = read_data(en_dev, fr_dev, data.bins)
def data_read(size, print_out):
read_data_into_global(en_train, fr_train, data.bins, size, print_out)
data_read(50000, False)
Expand Down Expand Up @@ -330,6 +331,7 @@ def job_id_factor(step):
ngpu.CHOOSE_K = FLAGS.soft_mem_size
do_beam_model = FLAGS.train_beam_freq > 0.0001 and FLAGS.beam_size > 1
beam_size = FLAGS.beam_size if FLAGS.mode > 0 and not do_beam_model else 1
beam_size = min(beam_size, FLAGS.beam_size)
beam_model = None
def make_ngpu(cur_beam_size, back):
return ngpu.NeuralGPU(
Expand Down

0 comments on commit a00389b

Please sign in to comment.