forked from facebookresearch/fairseq
Opensource code for Deep Transformer with Latent Depth (facebookresearch#2703)

Summary:

# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Opensources code for Deep Transformer with Latent Depth (https://arxiv.org/pdf/2009.13102.pdf).

New features and design choices made:

- New feature: allow the non-residual block to be weighted by a sample z (generated per batch) instead of `x = residual + x`.
- Design choice: move `x = residual + x` in transformer_layer.py into a function that the subclass (with latent depth) can override to `x = residual + z*x`.
- New feature: allow TransformerEncoder or TransformerDecoder to carry additional logits parameters which generate the samples z.
- Design choice: added the subclasses LatentTransformerEncoder and LatentTransformerDecoder, which hold the additional logits parameters and instantiate the corresponding LatentTransformerEncoderLayer and LatentTransformerDecoderLayer.
- New feature: allow the multilingual_translation task to train with latent depth (results in the paper).
- Design choice:
  - added additional arguments to the multilingual_translation task.
  - added an option for multilingual_transformer to use LatentTransformerEncoder and LatentTransformerDecoder in addition to the standard TransformerEncoder and TransformerDecoder.
  - added an option in the multilingual_translation task's `train_step` to generate the samples z and compute the KL (and sparsity) loss per batch.

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding.

Pull Request resolved: facebookresearch#2703
Reviewed By: myleott
Differential Revision: D24155059
Pulled By: xianxl
fbshipit-source-id: f3e41639429f9664ec5565839709aa857a643668
1 parent 3544f5f, commit 573c2f4
Showing 15 changed files with 672 additions and 12 deletions.
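The central refactor described in the summary, moving the residual add into a hook that the latent-depth subclass overrides to `x = residual + z*x`, amounts to roughly the following pattern. The class names here are illustrative only; the real code is fairseq's transformer_layer.py plus the latent_transformer.py added in this commit.

```python
import torch
from torch import Tensor, nn


class GatedResidualLayer(nn.Module):
    """Illustrative base layer: the residual add is isolated in a method so a
    subclass can change how the non-residual block is combined with x."""

    def __init__(self, dim: int):
        super().__init__()
        self.block = nn.Linear(dim, dim)  # stand-in for the attention/FFN sublayers

    def residual_connection(self, x: Tensor, residual: Tensor) -> Tensor:
        return residual + x  # standard Transformer behaviour

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        x = self.block(x)
        return self.residual_connection(x, residual)


class LatentDepthLayer(GatedResidualLayer):
    """Illustrative subclass: weight the non-residual block by a sample z."""

    def __init__(self, dim: int, z: Tensor):
        super().__init__(dim)
        self.z = z  # per-layer sample, generated per batch

    def residual_connection(self, x: Tensor, residual: Tensor) -> Tensor:
        return residual + self.z * x


# Example: a layer whose output is down-weighted by its selection sample.
layer = LatentDepthLayer(dim=512, z=torch.tensor(0.7))
out = layer(torch.randn(2, 512))
```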
@@ -0,0 +1,77 @@
# Deep Transformers with Latent Depth (Li et al., 2020)

[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102).

## Introduction

We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
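The gist, as a minimal Python sketch: each layer in a deep stack is weighted by a per-language sample of its selection variable. A plain sigmoid gate stands in here for the learned layer-selection posterior; the actual mechanism is the `LayerSelect` module in `src/modules/latent_layers.py`.

```python
import torch
import torch.nn as nn

# Stand-in sublayers and gates, for illustration only: layer l is weighted by
# a sample z_l of its selection variable for the current language pair, so
# rarely selected layers are softly pruned for that pair.
num_layers, dim = 24, 512
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
z = torch.sigmoid(torch.randn(num_layers))  # sampled layer-selection gates in (0, 1)

x = torch.randn(10, dim)
for l, layer in enumerate(layers):
    x = x + z[l] * layer(x)  # residual + gated non-residual block
```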
## Training a multilingual model with latent depth

Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) provided by the authors.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>

python fairseq_cli/train.py ${databin_dir} \
  --user-dir examples/latent_depth/src \
  --lang-pairs "${lang_pairs_str}" \
  --arch multilingual_transformer_iwslt_de_en \
  --task multilingual_translation_latent_depth \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --share-encoders \
  --share-decoders \
  --decoder-langtok \
  --share-decoder-input-output-embed \
  --dropout 0.3 --attention-dropout 0.3 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
  --max-tokens 4096 --update-freq 1 \
  --lr 0.0015 \
  --clip-norm 1.0 \
  --seed 2 \
  --ddp-backend=no_c10d \
  --encoder-layers 12 \
  --decoder-layers 24 \
  --decoder-latent-layer \
  --sparsity-weight 0.1 \
  --anneal-updates 5000 \
  --soft-update 500 \
  --target-layers 12 \
  --share-weight 0.1
```
## Inference command

```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>

python fairseq_cli/generate.py ${databin_dir} \
  --path ${model_path} \
  --task multilingual_translation_latent_depth \
  --decoder-latent-layer \
  --lang-pairs "${lang_pairs_str}" \
  -s ${src_lang} -t ${tgt_lang} \
  --gen-subset $gen_data \
  --scoring sacrebleu \
  --remove-bpe 'sentencepiece' \
  --lenpen 1.0 \
  --beam 5 \
  --decoder-langtok \
  --max-tokens 4096
```

## Citation
```bibtex
@article{li2020deep,
  title={Deep Transformers with Latent Depth},
  author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
  journal={arXiv preprint arXiv:2009.13102},
  year={2020}
}
```
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .models import latent_multilingual_transformer  # noqa
from .modules import latent_layers  # noqa
from .loss import latent_depth  # noqa
from . import multilingual_translation_latent_depth  # noqa
Empty file.
@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
    """KL divergence between the per-language layer-selection posterior and a prior."""

    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7
        if prior == "uniform":
            # uniform prior
            kl_loss = (samples * (
                torch.log(samples + eps) - math.log(0.5)
            )).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (samples * (
                torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalize by the number of layers
        kl_loss /= layer_samples[0].size()[0]
        # anneal the KL weight up to sparsity_weight over anneal_updates steps
        kl_weight = min(
            self.args.sparsity_weight,
            (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
        )
        kl_loss *= kl_weight * sample_size
        return kl_loss


class LatentLayersSparsityLoss(_Loss):
    """Regularizer encouraging layer sharing across languages and a target number of selected layers."""

    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
        if self.args.target_layers <= 0:
            return False
        return update_num > (self.args.soft_update + self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        batch_loss = 0
        share_loss = 0
        global_sparsity_loss = 0
        layer_samples = torch.stack(layer_samples_list, dim=0)
        if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
                update_num > (self.args.soft_update + self.args.anneal_updates)):
            # anneal sparsity weight
            if update_num < (self.args.anneal_updates + self.args.soft_update):
                weight_anneal = 0
            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                weight_anneal = (
                    (update_num - self.args.soft_update - self.args.anneal_updates)
                    * self.args.share_weight / self.args.anneal_updates
                )
            else:
                weight_anneal = 1
            # compute per-layer utilization, averaged over languages
            layer_utilization = torch.sum(layer_samples, dim=0)
            layer_utilization /= layer_samples.size()[0]
            if self.args.share_weight > 0:
                # encourage sharing across languages
                share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
                batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
            if self.args.target_layers > 0:
                # compute the expected number of layers selected
                expected_layers = sum(layer_utilization)
                # compute the l2 loss w.r.t. the target number of layers
                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
                batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
        return batch_loss
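As a rough sketch of how these two losses could be exercised in isolation: an `argparse.Namespace` stands in for the parsed fairseq arguments, with values mirroring the README flags above, and the module path is an assumption based on the package layout in this commit. In the actual code they are driven from the multilingual_translation_latent_depth task's `train_step`.

```python
import argparse

import torch

# Assumed module path based on this commit's layout; requires the fairseq
# repository root on PYTHONPATH.
from examples.latent_depth.src.loss.latent_depth import (
    LatentLayersKLLoss,
    LatentLayersSparsityLoss,
)

# Hypothetical hyperparameters mirroring the training command in the README.
args = argparse.Namespace(
    prior="uniform",
    sparsity_weight=0.1,
    share_weight=0.1,
    anneal_updates=5000,
    soft_update=500,
    target_layers=12,
)
kl_loss_fn = LatentLayersKLLoss(args)
sparsity_loss_fn = LatentLayersSparsityLoss(args)

num_langs, num_layers = 8, 24
# One vector of layer-selection samples z in (0, 1) per language pair.
layer_samples = [torch.rand(num_layers) for _ in range(num_langs)]

update_num, sample_size = 6000, 4096
kl = kl_loss_fn(layer_samples, lang_idx=0, update_num=update_num, sample_size=sample_size)
sparsity = sparsity_loss_fn(layer_samples, update_num, sample_size)
regularizer = kl + sparsity  # combined with the translation loss during training
```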
Empty file.
examples/latent_depth/src/models/latent_multilingual_transformer.py (60 additions, 0 deletions)
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    base_architecture,
    TransformerEncoder,
    TransformerDecoder,
)
from fairseq.models.multilingual_transformer import MultilingualTransformerModel

from .latent_transformer import (
    LatentTransformerEncoder,
    LatentTransformerDecoder,
)


@register_model('latent_multilingual_transformer')
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """A variant of the standard multilingual Transformer model whose encoder
    and/or decoder supports latent depth, as in "Deep Transformer with Latent
    Depth" (https://arxiv.org/abs/2009.13102).
    """
    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs))
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
            if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                return LatentTransformerDecoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 24)
    args.share_encoders = getattr(args, 'share_encoders', True)
    args.share_decoders = getattr(args, 'share_decoders', True)
    args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
    args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)

    base_architecture(args)
@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from torch import Tensor

from ..modules.latent_layers import LayerSelect


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """
    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_encoder_layer(args, idx)
            for idx in range(args.encoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_encoder_layer(self, args, idx=None):
        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)

    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        self.layer_select.sample(self.lang_idx)
        return super().forward(src_tokens, src_lengths, return_all_hiddens)


class LatentTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer whose (non-residual) block is weighted by a Bernoulli or
    Gumbel-Sigmoid sample.

    Args:
        args (argparse.Namespace): parsed command-line arguments from the
            standard TransformerEncoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module
            with logits parameters and sampling method.
    """
    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_decoder_layer(args, no_encoder_attn, idx)
            for idx in range(args.decoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(
            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        self.layer_select.sample(self.lang_idx)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            alignment_layer=alignment_layer,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )


class LatentTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer whose (non-residual) block is weighted by a Bernoulli or
    Gumbel-Sigmoid sample.

    Args:
        args (argparse.Namespace): parsed command-line arguments from the
            standard TransformerDecoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module
            with logits parameters and sampling method.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(
        self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)
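The `LayerSelect` module imported above (from `../modules/latent_layers.py`) is not shown in this view. A minimal sketch of the interface these layers rely on, sample once per forward pass and then index by layer id, might look like the following; the Gumbel-Sigmoid relaxation here is an assumption for illustration, not the exact fairseq implementation.

```python
import torch
import torch.nn as nn


class SimpleLayerSelect(nn.Module):
    """Hypothetical stand-in for LayerSelect: one logit per (language, layer)
    pair, a soft gate z sampled per layer for the current language, and
    z[idx] returned when called with a layer index."""

    def __init__(self, num_layers, num_logits, tau=1.0):
        super().__init__()
        self.layer_logits = nn.Parameter(torch.zeros(num_logits, num_layers))
        self.tau = tau
        self.samples = None

    def sample(self, logit_idx):
        # Gumbel-Sigmoid (binary concrete) relaxation of Bernoulli layer gates.
        logits = self.layer_logits[logit_idx]
        u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
        noise = torch.log(u) - torch.log1p(-u)  # Logistic(0, 1) noise
        self.samples = torch.sigmoid((logits + noise) / self.tau)

    def forward(self, idx):
        return self.samples[idx]


# Usage mirroring the layers above: sample once per forward pass, then weight
# each layer's non-residual block by its gate.
select = SimpleLayerSelect(num_layers=24, num_logits=8)
select.sample(logit_idx=0)  # draw gates for language pair 0
z3 = select(3)              # gate for layer 3, used as residual + z3 * x
```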
Empty file.