[Sockeye 1.10.0] Update to MXNet 0.12 (awslabs#173)
 - Updated MXNet dependency to 0.12 (w/ MKL support by default).
 - Changed `--smoothed-cross-entropy-alpha` to `--label-smoothing`.
 Label smoothing should now require significantly less memory, as it is now handled directly by MXNet's `SoftmaxOutput` operator (a rough sketch follows at the end of these notes).
 - `--weight-normalization` now applies not only to convolutional weight matrices, but also to the output layers of all decoders.
 It is also independent of weight tying (see the first sketch after this list).
 - Transformers now use `--embed-dropout`. Previously they used `--transformer-dropout-prepost` for this.
 - Transformers now scale their embedding vectors before adding fixed positional embeddings.
 This turns out to be crucial for effective learning (see the second sketch after this list).
 - `.param` files now use 5-digit identifiers to reduce the risk of overflow when training for many checkpoints.
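
 For readers unfamiliar with weight normalization, here is a minimal NumPy sketch of the idea behind `--weight-normalization`: each output row of a weight matrix is re-parameterized as a learned scale times a unit-norm direction (after Salimans & Kingma, 2016). Shapes and names are illustrative only and do not mirror Sockeye's actual implementation.

 ```python
 import numpy as np

 def weight_normalize(v: np.ndarray, g: np.ndarray) -> np.ndarray:
     """Re-parameterize a weight matrix row-wise: w = g * v / ||v||.

     v: unnormalized direction parameters, shape (num_outputs, num_inputs)
     g: learned per-output scale, shape (num_outputs,)
     """
     norms = np.linalg.norm(v, axis=1, keepdims=True)   # (num_outputs, 1)
     return (g[:, np.newaxis] / norms) * v

 # Toy usage: a decoder output layer projecting hidden states to vocabulary logits.
 rng = np.random.RandomState(1)
 v = rng.randn(32000, 512)      # direction parameters (vocab_size, num_hidden)
 g = np.ones(32000)             # scale parameters, often initialized from the norms of v
 hidden = rng.randn(10, 512)    # (batch_size, num_hidden)
 logits = hidden.dot(weight_normalize(v, g).T)   # (batch_size, vocab_size)
 ```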
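
 The embedding scaling works roughly as follows. This sketch assumes a scaling factor of sqrt(num_embed) and standard sinusoidal positional embeddings as in Vaswani et al. (2017); details may differ from Sockeye's code. The `--embed-dropout` mentioned above would then be applied to the result during training.

 ```python
 import numpy as np

 def sinusoidal_positional_embeddings(seq_len: int, num_embed: int) -> np.ndarray:
     """Fixed positional embeddings as in Vaswani et al. (2017). Shape: (seq_len, num_embed)."""
     positions = np.arange(seq_len)[:, np.newaxis]     # (seq_len, 1)
     dims = np.arange(num_embed // 2)[np.newaxis, :]   # (1, num_embed / 2)
     angles = positions / np.power(10000.0, 2 * dims / num_embed)
     pos = np.zeros((seq_len, num_embed))
     pos[:, 0::2] = np.sin(angles)
     pos[:, 1::2] = np.cos(angles)
     return pos

 num_embed = 512
 embeddings = np.random.randn(20, num_embed)   # (seq_len, num_embed) looked-up word embeddings
 scaled = embeddings * np.sqrt(num_embed)      # scale *before* adding positional information
 encoder_input = scaled + sinusoidal_positional_embeddings(20, num_embed)
 # --embed-dropout would be applied to encoder_input during training
 ```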

### Added
 - Added CUDA 9.0 requirements file.
 - `--loss-normalization-type`. Added a new flag to control loss normalization. The new default is to normalize
 by the number of valid (non-PAD) tokens instead of the batch size (see the combined sketch after these notes).
 - `--weight-init-xavier-factor-type`. Added a new flag to control the Xavier factor type when `--weight-init=xavier` (see the sketch after this list).
 - `--embed-weight-init`. Added a new flag for initializing embedding matrices.
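
 To illustrate what the Xavier factor type and `--weight-init-scale` (magnitude) control, here is a rough NumPy approximation of how MXNet's `Xavier` initializer derives its sampling scale from a 2-D weight's fan-in and fan-out. Treat it as a sketch of the behaviour, not a definitive reference.

 ```python
 import numpy as np

 def xavier_scale(shape, factor_type: str = "in", magnitude: float = 2.34) -> float:
     """Approximate the scale used by mx.init.Xavier for a 2-D weight of the given shape."""
     fan_out, fan_in = shape[0], shape[1]
     factor = {"in": fan_in,
               "out": fan_out,
               "avg": (fan_in + fan_out) / 2.0}[factor_type]
     return float(np.sqrt(magnitude / factor))

 # With uniform Xavier initialization, weights are then drawn from U(-scale, +scale):
 scale = xavier_scale((512, 1024), factor_type="in", magnitude=2.34)
 weights = np.random.uniform(-scale, scale, size=(512, 1024))
 ```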
 
### Removed
 - `--smoothed-cross-entropy-alpha` argument. See above.
 - `--normalize-loss` argument. See above.
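
 Putting the label-smoothing and loss-normalization changes together, a hypothetical MXNet 0.12 snippet using the relevant `SoftmaxOutput` arguments might look like the following: `smooth_alpha` for label smoothing, and `normalization='valid'` together with `ignore_label`/`use_ignore` to normalize by non-PAD tokens rather than batch size. Argument names are taken from MXNet's operator documentation as I understand it; the values and symbol names are illustrative, and this is not Sockeye's actual loss code.

 ```python
 import mxnet as mx

 logits = mx.sym.Variable("logits")   # (batch_size * target_seq_len, vocab_size)
 labels = mx.sym.Variable("labels")   # (batch_size * target_seq_len,), 0 = PAD

 loss = mx.sym.SoftmaxOutput(data=logits,
                             label=labels,
                             smooth_alpha=0.1,        # --label-smoothing
                             ignore_label=0,          # do not learn on PAD positions
                             use_ignore=True,
                             normalization="valid",   # --loss-normalization-type valid:
                                                      # scale by 1 / #non-PAD tokens
                             name="softmax")
 ```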
fhieber authored Nov 1, 2017
1 parent 2dcafc2 commit d436ae8
Showing 31 changed files with 488 additions and 470 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -13,6 +13,7 @@ python:
- "3.6"

install:
- sudo apt-get -y update && sudo apt-get install -y libgfortran3
- pip install -r requirements.txt
- pip install -r requirements.dev.txt
- pip install -r requirements.docs.txt
23 changes: 23 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,29 @@ Note that Sockeye has checks in place to not translate with an old model that wa

For each item we will potentially have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.10.0]
### Changed
- Updated MXNet dependency to 0.12 (w/ MKL support by default).
- Changed `--smoothed-cross-entropy-alpha` to `--label-smoothing`.
 Label smoothing should now require significantly less memory, as it is now handled directly by MXNet's `SoftmaxOutput` operator.
- `--weight-normalization` now applies not only to convolutional weight matrices, but also to the output layers of all decoders.
 It is also independent of weight tying.
- Transformers now use `--embed-dropout`. Previously they used `--transformer-dropout-prepost` for this.
- Transformers now scale their embedding vectors before adding fixed positional embeddings.
This turns out to be crucial for effective learning.
- `.param` files now use 5-digit identifiers to reduce the risk of overflow when training for many checkpoints.

### Added
- Added CUDA 9.0 requirements file.
- `--loss-normalization-type`. Added a new flag to control loss normalization. The new default is to normalize
 by the number of valid (non-PAD) tokens instead of the batch size.
- `--weight-init-xavier-factor-type`. Added a new flag to control the Xavier factor type when `--weight-init=xavier`.
- `--embed-weight-init`. Added a new flag for initializing embedding matrices.

### Removed
- `--smoothed-cross-entropy-alpha` argument. See above.
- `--normalize-loss` argument. See above.

## [1.9.0]
### Added
- Batch decoding. New options for the translate CLI: ``--batch-size`` and ``--chunk-size``. Translator.translate()
32 changes: 12 additions & 20 deletions README.md
@@ -18,7 +18,7 @@ Recent developments and changes are tracked in our [changelog](https://github.co

Sockeye requires:
- **Python3**
- [MXNet-0.10.0](https://github.com/dmlc/mxnet/tree/v0.10.0)
- [MXNet-0.12.0](https://github.com/dmlc/mxnet/tree/v0.12.0)
- numpy

## Installation
@@ -50,19 +50,13 @@ remaining instructions to work you will need to use `python3` instead of `python

If you want to run sockeye on a GPU you need to make sure your version of Apache MXNet Incubating contains the GPU
bindings.
Depending on your version of CUDA you can do this by running the following for CUDA 8.0:

```bash
> wget https://raw.githubusercontent.com/awslabs/sockeye/master/requirements.gpu-cu80.txt
> pip install sockeye --no-deps -r requirements.gpu-cu80.txt
> rm requirements.gpu-cu80.txt
```
or the following for CUDA 7.5:
Depending on your version of CUDA, you can do this by running the following:
```bash
> wget https://raw.githubusercontent.com/awslabs/sockeye/master/requirements.gpu-cu75.txt
> pip install sockeye --no-deps -r requirements.gpu-cu75.txt
> rm requirements.gpu-cu75.txt
> wget https://raw.githubusercontent.com/awslabs/sockeye/master/requirements.gpu-cu${CUDA_VERSION}.txt
> pip install sockeye --no-deps -r requirements.gpu-cu${CUDA_VERSION}.txt
> rm requirements.gpu-cu${CUDA_VERSION}.txt
```
where `${CUDA_VERSION}` can be `75` (7.5), `80` (8.0), or `90` (9.0).

### Or: From Source

@@ -78,15 +72,12 @@ after cloning the repository from git.

If you want to run sockeye on a GPU you need to make sure your version of Apache MXNet
Incubating contains the GPU bindings. Depending on your version of CUDA you can do this by
running the following for CUDA 8.0:
running the following:

```bash
> python setup.py install -r requirements.gpu-cu80.txt
```
or the following for CUDA 7.5:
```bash
> python setup.py install -r requirements.gpu-cu75.txt
> python setup.py install -r requirements.gpu-cu${CUDA_VERSION}.txt
```
where `${CUDA_VERSION}` can be `75` (7.5), `80` (8.0), or `90` (9.0).

### Optional dependencies
In order to track learning curves during training you can optionally install dmlc's tensorboard fork
@@ -115,7 +106,8 @@ directly. For example *sockeye-train* can also be invoked as

In order to train your first Neural Machine Translation model you will need two sets of parallel files: one for training
and one for validation. The latter will be used for computing various metrics during training.
Each set should consist of two files: one with source sentences and one with target sentences (translations). Both files should have the same number of lines, each line containing a single
Each set should consist of two files: one with source sentences and one with target sentences (translations).
Both files should have the same number of lines, each line containing a single
sentence. Each sentence should be a whitespace delimited list of tokens.

Say you wanted to train a RNN German-to-English translation model, then you would call sockeye like this:
@@ -129,7 +121,7 @@ Say you wanted to train a RNN German-to-English translation model, then you woul
```

After training the directory *<model_dir>* will contain all model artifacts such as parameters and model
configuration.
configuration. The default setting is to train a 1-layer LSTM model with attention.


### Translate
2 changes: 1 addition & 1 deletion requirements.gpu-cu75.txt
@@ -1,3 +1,3 @@
pyyaml
mxnet-cu75==0.10.0
mxnet-cu75mkl==0.12.0
numpy>=1.12
2 changes: 1 addition & 1 deletion requirements.gpu-cu80.txt
@@ -1,3 +1,3 @@
pyyaml
mxnet-cu80==0.10.0
mxnet-cu80mkl==0.12.0
numpy>=1.12
3 changes: 3 additions & 0 deletions requirements.gpu-cu90.txt
@@ -0,0 +1,3 @@
pyyaml
mxnet-cu90mkl==0.12.0
numpy>=1.12
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
pyyaml
mxnet==0.10.0
mxnet-mkl==0.12.0
numpy>=1.12
43 changes: 27 additions & 16 deletions sockeye/arguments.py
@@ -474,8 +474,8 @@ def add_model_parameters(params):
C.LNGLSTM_TYPE))

model_params.add_argument('--weight-normalization', action="store_true",
help="Adds weight normalization to all convolutional weight matrices and the "
"transformation matrix to the output vocab in the convolutional decoder.")
help="Adds weight normalization to decoder output layers "
"(and all convolutional weight matrices for CNN decoders). Default: %(default)s.")


def add_training_args(params):
@@ -506,17 +506,17 @@ def add_training_args(params):

train_params.add_argument('--loss',
default=C.CROSS_ENTROPY,
choices=[C.CROSS_ENTROPY, C.SMOOTHED_CROSS_ENTROPY],
choices=[C.CROSS_ENTROPY],
help='Loss to optimize. Default: %(default)s.')
train_params.add_argument('--smoothed-cross-entropy-alpha',
default=0.3,
train_params.add_argument('--label-smoothing',
default=0.0,
type=float,
help='Smoothing value for smoothed-cross-entropy loss. Default: %(default)s.')
train_params.add_argument('--normalize-loss',
default=False,
action="store_true",
help='If turned on we normalize the loss by dividing by the number of non-PAD tokens.'
'If turned off the loss is only normalized by the number of sentences in a batch.')
help='Smoothing constant for label smoothing. Default: %(default)s.')
train_params.add_argument('--loss-normalization-type',
default=C.LOSS_NORM_VALID,
choices=[C.LOSS_NORM_VALID, C.LOSS_NORM_BATCH],
help='How to normalize the loss. By default we normalize by the number '
'of valid/non-PAD tokens (%s)' % C.LOSS_NORM_VALID)

train_params.add_argument('--metrics',
nargs='+',
@@ -618,12 +618,23 @@ def add_training_args(params):
type=str,
default=C.INIT_XAVIER,
choices=C.INIT_TYPES,
help='Type of weight initialization. Default: %(default)s.')
help='Type of base weight initialization. Default: %(default)s.')
train_params.add_argument('--weight-init-scale',
type=float,
default=0.04,
help='Weight initialization scale (currently only applies to uniform initialization). '
default=2.34,
help='Weight initialization scale. Applies to uniform (scale) and xavier (magnitude). '
'Default: %(default)s.')
train_params.add_argument('--weight-init-xavier-factor-type',
type=str,
default='in',
choices=['in', 'out', 'avg'],
help='Xavier factor type. Default: %(default)s.')
train_params.add_argument('--embed-weight-init',
type=str,
default=C.EMBED_INIT_DEFAULT,
choices=C.EMBED_INIT_TYPES,
help='Type of embedding matrix weight initialization. If normal, initializes embedding '
'weights using a normal distribution with std=vocab_size. Default: %(default)s.')
train_params.add_argument('--initial-learning-rate',
type=float,
default=0.0003,
@@ -757,8 +768,8 @@ def add_inference_args(params):
decode_params.add_argument('--batch-size',
type=int_greater_or_equal(1),
default=1,
help='Batch size during decoding. Determines how many sentences are translated simultaneously.'
'Default: %(default)s.')
help='Batch size during decoding. Determines how many sentences are translated '
'simultaneously. Default: %(default)s.')
decode_params.add_argument('--chunk-size',
type=int_greater_or_equal(1),
default=1,
20 changes: 16 additions & 4 deletions sockeye/constants.py
@@ -76,12 +76,22 @@
LEARNED_POSITIONAL_EMBEDDING = "learned"
POSITIONAL_EMBEDDING_TYPES = [NO_POSITIONAL_EMBEDDING, FIXED_POSITIONAL_EMBEDDING, LEARNED_POSITIONAL_EMBEDDING]


DEFAULT_INIT_PATTERN = ".*"

# init types
INIT_XAVIER='xavier'
INIT_UNIFORM='uniform'
INIT_TYPES=[INIT_XAVIER, INIT_UNIFORM]

# Embedding init types
EMBED_INIT_PATTERN = '(%s|%s|%s)weight' % (SOURCE_EMBEDDING_PREFIX, TARGET_EMBEDDING_PREFIX, SHARED_EMBEDDING_PREFIX)
EMBED_INIT_DEFAULT = 'default'
EMBED_INIT_NORMAL = 'normal'
EMBED_INIT_TYPES = [EMBED_INIT_DEFAULT, EMBED_INIT_NORMAL]

# RNN init types
RNN_INIT_PATTERN = ".*h2h.*"
RNN_INIT_ORTHOGONAL = 'orthogonal'
RNN_INIT_ORTHOGONAL_STACKED = 'orthogonal_stacked'
# use the default initializer used also for all other weights
@@ -168,9 +178,9 @@
VOCAB_TRG_NAME = "vocab.trg"
VOCAB_ENCODING = "utf-8"
PARAMS_PREFIX = "params."
PARAMS_NAME = PARAMS_PREFIX + "%04d"
PARAMS_NAME = PARAMS_PREFIX + "%05d"
PARAMS_BEST_NAME = "params.best"
DECODE_OUT_NAME = "decode.output.%04d"
DECODE_OUT_NAME = "decode.output.%05d"
DECODE_IN_NAME = "decode.source"
DECODE_REF_NAME = "decode.target"
SYMBOL_NAME = "symbol" + JSON_SUFFIX
@@ -268,9 +278,11 @@
METRIC_MAXIMIZE = {ACCURACY: True, BLEU: True, PERPLEXITY: False}
METRIC_WORST = {ACCURACY: 0.0, BLEU: 0.0, PERPLEXITY: np.inf}

# loss names
# loss
CROSS_ENTROPY = 'cross-entropy'
SMOOTHED_CROSS_ENTROPY = 'smoothed-cross-entropy'

LOSS_NORM_BATCH = 'batch'
LOSS_NORM_VALID = "valid"

TARGET_MAX_LENGTH_FACTOR = 2
DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH = 2
1 change: 1 addition & 0 deletions sockeye/convolution.py
@@ -168,6 +168,7 @@ def _post_convolution(self, data_conv):
if self.config.act_type == C.GLU:
# GLU
# two times: (batch_size, num_hidden, seq_len)
# pylint: disable=unbalanced-tuple-unpacking
gate_a, gate_b = mx.sym.split(data_conv, num_outputs=2, axis=1)
# (batch_size, num_hidden, seq_len)
block_output = mx.sym.broadcast_mul(gate_a,
46 changes: 14 additions & 32 deletions sockeye/coverage.py
@@ -88,7 +88,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
prev_coverage: mx.sym.Symbol):
"""
:param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len).
:param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
:return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden).
"""
@@ -120,7 +120,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
prev_coverage: mx.sym.Symbol):
"""
:param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len).
:param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
:return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden).
"""
@@ -164,7 +164,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
prev_coverage: mx.sym.Symbol):
"""
:param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len).
:param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
:return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden).
"""
@@ -240,61 +240,43 @@ def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len
:return: Coverage callable.
"""

# (batch_size * seq_len, coverage_hidden_num)
source_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=source,
shape=(-3, -1),
name="%sflat_source" % self.prefix),
# (batch_size, seq_len, coverage_hidden_num)
source_hidden = mx.sym.FullyConnected(data=source,
weight=self.cov_e2h_weight,
no_bias=True,
num_hidden=self.num_hidden,
flatten=False,
name="%ssource_hidden_fc" % self.prefix)

# (batch_size, seq_len, coverage_hidden_num)
source_hidden = mx.sym.reshape(source_hidden,
shape=(-1, source_seq_len, self.num_hidden),
name="%ssource_hidden" % self.prefix)

def update_coverage(prev_hidden: mx.sym.Symbol,
attention_prob_scores: mx.sym.Symbol,
prev_coverage: mx.sym.Symbol):
"""
:param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len, 1).
:param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len).
:param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
:return: Updated coverage matrix . Shape: (batch_size, source_seq_len, coverage_num_hidden).
"""

# (batch_size * seq_len, coverage_hidden_num)
coverage_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=prev_coverage,
shape=(-3, -1),
name="%sflat_previous" % self.prefix),
# (batch_size, seq_len, coverage_hidden_num)
coverage_hidden = mx.sym.FullyConnected(data=prev_coverage,
weight=self.cov_prev2h_weight,
no_bias=True,
num_hidden=self.num_hidden,
flatten=False,
name="%sprevious_hidden_fc" % self.prefix)

# (batch_size, source_seq_len, coverage_hidden_num)
coverage_hidden = mx.sym.reshape(coverage_hidden,
shape=(-1, source_seq_len, self.num_hidden),
name="%sprevious_hidden" % self.prefix)

# (batch_size, source_seq_len, 1)
attention_prob_score = mx.sym.expand_dims(attention_prob_scores, axis=2)
attention_prob_scores = mx.sym.expand_dims(attention_prob_scores, axis=2)

# (batch_size * source_seq_len, coverage_num_hidden)
attention_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(attention_prob_score,
shape=(-3, 0),
name="%sreshape_att_probs" % self.prefix),
# (batch_size, source_seq_len, coverage_num_hidden)
attention_hidden = mx.sym.FullyConnected(data=attention_prob_scores,
weight=self.cov_a2h_weight,
no_bias=True,
num_hidden=self.num_hidden,
flatten=False,
name="%sattention_fc" % self.prefix)

# (batch_size, source_seq_len, coverage_num_hidden)
attention_hidden = mx.sym.reshape(attention_hidden,
shape=(-1, source_seq_len, self.num_hidden),
name="%sreshape_att" % self.prefix)

# (batch_size, coverage_num_hidden)
prev_hidden = mx.sym.FullyConnected(data=prev_hidden, weight=self.cov_dec2h_weight, no_bias=True,
num_hidden=self.num_hidden, name="%sdecoder_hidden")
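
A side note on the pattern simplified above: before MXNet 0.12, applying `FullyConnected` across a time axis required flattening a 3-D input to 2-D and reshaping back, whereas `flatten=False` (assumed available as of 0.12) applies the projection over the last axis directly while preserving the leading axes. A minimal sketch with illustrative shapes:

```python
import mxnet as mx
import numpy as np

batch_size, seq_len, num_in, num_hidden = 2, 5, 8, 4
data = mx.nd.array(np.random.randn(batch_size, seq_len, num_in))
weight = mx.nd.array(np.random.randn(num_hidden, num_in))

# (batch_size, seq_len, num_hidden): leading axes are preserved,
# the projection is applied over the last axis only.
out = mx.nd.FullyConnected(data=data, weight=weight, no_bias=True,
                           num_hidden=num_hidden, flatten=False)
print(out.shape)  # (2, 5, 4)
```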
