Open-source code for Deep Transformer with Latent Depth (facebookresearch#2703)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typo fixes or doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Open-source code for Deep Transformer with Latent Depth (https://arxiv.org/pdf/2009.13102.pdf).

New features and design choices made:

- New feature: allow the non-residual block to be weighted by a sample z (generated per batch), i.e. `x = residual + z*x` instead of `x = residual + x`.
- Design choice: move `x = residual + x` in transformer_layer.py into a function so that the latent-depth subclass can override it with `x = residual + z*x` (a minimal sketch follows this list).

- New feature: allow TransformerEncoder or TransformerDecoder to hold additional logits parameters from which the samples z are generated.
- Design choice: added the subclasses LatentTransformerEncoder and LatentTransformerDecoder, which have additional attributes for the logits parameters and instantiate the corresponding LatentTransformerEncoderLayer and LatentTransformerDecoderLayer.

- New feature: allow the multilingual_translation task to train with latent depth (used for the results in the paper).
- Design choice:
  - added additional arguments to the multilingual_translation task.
  - added an option for multilingual_transformer to use LatentTransformerEncoder and LatentTransformerDecoder besides the standard TransformerEncoder/TransformerDecoder.
  - added an option in the multilingual_translation task's `train_step` to generate the samples z and compute the KL (and sparsity) loss per batch.
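To make the residual-weighting design concrete, here is a minimal, self-contained sketch of the override pattern described above. The `Toy*` class names are illustrative only; the real hook lives in fairseq's transformer_layer.py and in the latent-depth layers added by this PR.

```python
import torch
import torch.nn as nn


class ToyEncoderLayer(nn.Module):
    """Standard layer: the residual hook mirrors `x = residual + x`."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def residual_connection(self, x, residual):
        return residual + x

    def forward(self, x):
        residual = x
        x = self.fc(x)
        return self.residual_connection(x, residual)


class ToyLatentEncoderLayer(ToyEncoderLayer):
    """Latent-depth variant: the non-residual branch is weighted by a
    per-batch sample z, i.e. `x = residual + z * x`."""

    def __init__(self, dim):
        super().__init__(dim)
        self.z = torch.tensor(1.0)  # set once per batch by the layer-selection sampler

    def set_sample(self, z):
        self.z = z

    def residual_connection(self, x, residual):
        return residual + self.z * x


layer = ToyLatentEncoderLayer(dim=8)
layer.set_sample(torch.tensor(0.3))
out = layer(torch.randn(2, 8))  # shape (2, 8)
```

The only behavioural difference between the two toy classes is the overridden `residual_connection`, which is exactly the design choice described above.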

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: facebookresearch#2703

Reviewed By: myleott

Differential Revision: D24155059

Pulled By: xianxl

fbshipit-source-id: f3e41639429f9664ec5565839709aa857a643668
Xian Li authored and facebook-github-bot committed Oct 15, 2020
1 parent 3544f5f commit 573c2f4
Showing 15 changed files with 672 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -44,6 +44,7 @@ We provide reference implementations of various sequence modeling papers:
- [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
- [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
- [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
- [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
@@ -55,6 +56,7 @@ We provide reference implementations of various sequence modeling papers:

### What's New:

- October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
- October 2020: [Added CRISS models and code](examples/criss/README.md)
- September 2020: [Added Linformer code](examples/linformer/README.md)
- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
77 changes: 77 additions & 0 deletions examples/latent_depth/README.md
@@ -0,0 +1,77 @@
# Deep Transformers with Latent Depth (Li et al., 2020)

[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102)

## Introduction

We present a probabilistic framework that automatically learns which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation, with a different layer-selection posterior for each language pair.
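For intuition, the layer-selection weights that the latent-depth layers apply to their non-residual blocks can be sampled with a Gumbel-Sigmoid (binary-concrete) relaxation. Below is a minimal sketch of such a sampler; the function name, signature, and shapes are illustrative assumptions and do not reproduce the `LayerSelect` module shipped with this example.

```python
import torch


def sample_layer_selection(logits, tau=1.0, training=True, eps=1e-7):
    """Draw soft layer-selection weights z in (0, 1) for one language pair.

    `logits` is assumed to be a 1-D tensor with one entry per layer; this is an
    illustrative stand-in for the logits held by the layer-selection module,
    not its actual interface.
    """
    if training:
        # Binary-concrete / Gumbel-Sigmoid relaxation: add logistic noise, then squash.
        u = torch.rand_like(logits).clamp(eps, 1 - eps)
        noise = torch.log(u) - torch.log(1 - u)
        return torch.sigmoid((logits + noise) / tau)
    # At inference time, use the deterministic value instead of a noisy sample.
    return torch.sigmoid(logits)


# Example: selection weights for 24 decoder layers of one language pair.
logits = torch.zeros(24, requires_grad=True)
z = sample_layer_selection(logits)
print(z.shape)  # torch.Size([24])
```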

## Training a multilingual model with latent depth

Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) provided by the authors.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>

python fairseq_cli/train.py ${databin_dir} \
--user-dir examples/latent_depth/src \
--lang-pairs "${lang_pairs_str}" \
--arch multilingual_transformer_iwslt_de_en \
--task multilingual_translation_latent_depth \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--share-encoders \
--share-decoders \
--decoder-langtok \
--share-decoder-input-output-embed \
--dropout 0.3 --attention-dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
--max-tokens 4096 --update-freq 1 \
--lr 0.0015 \
--clip-norm 1.0 \
--seed 2 \
--ddp-backend=no_c10d \
--encoder-layers 12 \
--decoder-layers 24 \
--decoder-latent-layer \
--sparsity-weight 0.1 \
--anneal-updates 5000 \
--soft-update 500 \
--target-layers 12 \
--share-weight 0.1
```
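Note: the latent-depth-specific flags above correspond to the new task and loss arguments added in this example: `--decoder-latent-layer` selects the LatentTransformerDecoder, `--sparsity-weight` bounds the KL weight, `--target-layers` and `--share-weight` feed the sparsity/sharing loss, and `--anneal-updates` / `--soft-update` control how these weights are annealed over training updates.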
## Inference command

```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>

python fairseq_cli/generate.py ${databin_dir} \
--path ${model_path} \
--task multilingual_translation_latent_depth \
--decoder-latent-layer \
--lang-pairs "${lang_pairs_str}" \
-s ${src_lang} -t ${tgt_lang} \
--gen-subset $gen_data \
--scoring sacrebleu \
--remove-bpe 'sentencepiece' \
--lenpen 1.0 \
--beam 5 \
--decoder-langtok \
--max-tokens 4096
```


## Citation
```bibtex
@article{li2020deep,
title={Deep Transformers with Latent Depth},
author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
journal={arXiv preprint arXiv:2009.13102},
year={2020}
}
```
9 changes: 9 additions & 0 deletions examples/latent_depth/src/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa
from .loss import latent_depth # noqa
from . import multilingual_translation_latent_depth # noqa
Empty file.
86 changes: 86 additions & 0 deletions examples/latent_depth/src/loss/latent_depth.py
@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import math
from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7
        if prior == "uniform":
            # uniform prior
            kl_loss = (samples * (
                torch.log(samples + eps) - math.log(0.5)
            )).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (samples * (
                torch.log(samples + eps) - torch.log(normed_agg_q + eps)
            )).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalize by the number of layers
        kl_loss /= layer_samples[0].size()[0]
        # linearly anneal the KL weight up to sparsity_weight over anneal_updates
        # steps after the soft_update warm-up updates
        kl_weight = min(
            self.args.sparsity_weight,
            (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
        )
        kl_loss *= kl_weight * sample_size
        return kl_loss


class LatentLayersSparsityLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
        if self.args.target_layers <= 0:
            return False
        return update_num > (self.args.soft_update + self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        batch_loss = 0
        share_loss = 0
        global_sparsity_loss = 0
        layer_samples = torch.stack(layer_samples_list, dim=0)
        if (
            (self.args.target_layers > 0 or self.args.share_weight > 0)
            and update_num > (self.args.soft_update + self.args.anneal_updates)
        ):
            # anneal sparsity weight
            if update_num < (self.args.anneal_updates + self.args.soft_update):
                weight_anneal = 0
            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                weight_anneal = (
                    (update_num - self.args.soft_update - self.args.anneal_updates)
                    * self.args.share_weight / self.args.anneal_updates
                )
            else:
                weight_anneal = 1
            # compute the layer utilization ratio across languages
            layer_utilization = torch.sum(layer_samples, dim=0)
            layer_utilization /= layer_samples.size()[0]
            if self.args.share_weight > 0:
                # encourage sharing across languages
                share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
                batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
            if self.args.target_layers > 0:
                # compute the expected number of layers selected
                expected_layers = sum(layer_utilization)
                # L2 loss w.r.t. the target number of layers
                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
                batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
        return batch_loss
Empty file.
60 changes: 60 additions & 0 deletions examples/latent_depth/src/models/latent_multilingual_transformer.py
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
base_architecture,
TransformerEncoder,
TransformerDecoder,
)
from fairseq.models.multilingual_transformer import MultilingualTransformerModel

from .latent_transformer import (
LatentTransformerEncoder,
LatentTransformerDecoder,
)


@register_model('latent_multilingual_transformer')
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """A variant of the standard multilingual Transformer model whose encoder
    and/or decoder supports latent depth, as in "Deep Transformers with Latent
    Depth" (https://arxiv.org/abs/2009.13102).
    """

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
            if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                return LatentTransformerDecoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 24)
    args.share_encoders = getattr(args, 'share_encoders', True)
    args.share_decoders = getattr(args, 'share_decoders', True)
    args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
    args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)

    base_architecture(args)
130 changes: 130 additions & 0 deletions examples/latent_depth/src/models/latent_transformer.py
@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from ..modules.latent_layers import LayerSelect
from torch import Tensor


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """

    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_encoder_layer(args, idx)
            for idx in range(args.encoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_encoder_layer(self, args, idx=None):
        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)

    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        # Draw the layer-selection sample z for the current language once per
        # forward pass (i.e. per batch) before running the standard encoder forward.
        self.layer_select.sample(self.lang_idx)
        return super().forward(src_tokens, src_lengths, return_all_hiddens)


class LatentTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer with each (non-residual) block weighted by a sample from a
    Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from the standard
            TransformerEncoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of the LayerSelect module with
            logits parameters and sampling method.
    """

    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_decoder_layer(args, no_encoder_attn, idx)
            for idx in range(args.decoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(
            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        # Draw the layer-selection sample z for the current language once per
        # forward call, then delegate to the standard decoder forward.
        self.layer_select.sample(self.lang_idx)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )


class LatentTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer with each (non-residual) block weighted by a sample from a
    Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from the standard
            TransformerDecoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of the LayerSelect module with
            logits parameters and sampling method.
        no_encoder_attn (bool, optional): if True, do not attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)
Empty file.