Commit

Summary:
# Before submitting

- [x] Was this discussed/approved via a GitHub issue? (not needed for typo fixes or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?
  too many of them actually ^^

## What does this PR do?
This is a rewrite of fairinternal/fairseq-py#1538, following the discussion there and taking into account Myle's proposed fairinternal/fairseq-py#1560.
It brings online backtranslation to fairseq.
It also adds a RobertaEncDec model to fairseq: RobertaEncDec can be built from a pretrained RoBERTa model, which enables transfer learning. This is crucial for backtranslation.
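
As a quick illustration of the transfer-learning path, here is a minimal, untested sketch (the checkpoint path is a placeholder, and it assumes the MLM checkpoint keeps a plain argparse `args` namespace, as `read_args_from_roberta` in the diff below also assumes):

```python
import argparse

import fairseq.checkpoint_utils
from fairseq.models.roberta.enc_dec import (
    RobertaEncDecModel,
    base_enc_dec_architecture,
)

# Load the pretrained RoBERTa MLM checkpoint (placeholder path).
[roberta_mlm], _cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    ["/path/to/roberta_mlm/model.pt"]
)

# Reuse RoBERTa's own hyper-parameters, fill in the enc-dec defaults,
# then wrap the pretrained encoder into the new encoder-decoder model.
args = argparse.Namespace(**vars(roberta_mlm.args))
base_enc_dec_architecture(args)
args.share_all_embeddings = True  # tie encoder, decoder and output embeddings

model = RobertaEncDecModel.from_roberta(roberta_mlm, args, task.source_dictionary)
```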

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in a GitHub issue, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: fairinternal/fairseq-py#1614

Reviewed By: myleott

Differential Revision: D27157296

Pulled By: gwenzek

fbshipit-source-id: 43020bc27743419bd4b138716165bf5764117c21
gwenzek authored and facebook-github-bot committed Mar 30, 2021
1 parent 7dafb05 commit c2e8904
Showing 16 changed files with 1,472 additions and 28 deletions.
2 changes: 2 additions & 0 deletions fairseq/data/noising.py
@@ -296,6 +296,8 @@ def __init__(
**kwargs,
)
)
self.sizes = src_dataset.sizes


def __getitem__(self, index):
"""
2 changes: 1 addition & 1 deletion fairseq/data/round_robin_zip_datasets.py
@@ -141,7 +141,7 @@ def _deep_until_language_pair(dataset):
f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
)
# Since we are modifiying in place the _ordered_indices,
# Since we are modifying in place the _ordered_indices,
# it's not possible anymore to return valid ignored indices.
# Hopefully the extra debug information print above should be enough to debug.
# Ideally we would receive ignore_invalid_inputs so that we could have
3 changes: 3 additions & 0 deletions fairseq/data/transform_eos_lang_pair_dataset.py
@@ -50,6 +50,9 @@ def __len__(self):
def collater(self, samples, **extra_args):
samples = self.dataset.collater(samples, **extra_args)

if 'net_input' not in samples:
return samples

if self.new_src_eos is not None:
if self.dataset.left_pad_source:
assert (
1 change: 1 addition & 0 deletions fairseq/models/roberta/__init__.py
@@ -5,6 +5,7 @@

from .hub_interface import * # noqa
from .model import * # noqa
from .enc_dec import * # noqa
from .model_camembert import * # noqa
from .model_gottbert import * # noqa
from .model_xlmr import * # noqa
192 changes: 192 additions & 0 deletions fairseq/models/roberta/enc_dec.py
@@ -0,0 +1,192 @@
import argparse
import logging

import torch.nn as nn
import fairseq.checkpoint_utils
from fairseq.models import (
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import TransformerDecoder
from fairseq.models.roberta import model as roberta

logger = logging.getLogger(__name__)


@register_model("roberta_enc_dec")
class RobertaEncDecModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
parser.add_argument(
"--pretrained-mlm-checkpoint",
default=None,
type=str,
metavar="PRETRAINED",
help="path to pretrained mlm checkpoint",
)
parser.add_argument(
"--pretrained-decoder", action="store_true", help="reload decoder"
)
parser.add_argument(
"--hack-layernorm-embedding",
action="store_true",
help="hack to reload old models trained with encoder-normalize-before=False (no equivalent to encoder-normalize-before=False and layernorm_embedding=False",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument(
"--share-all-embeddings",
action="store_true",
help="share encoder, decoder and output embeddings"
" (requires shared dictionary and embed dim)",
)

@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""

# make sure all arguments are present
base_enc_dec_architecture(args)
if args.pretrained_mlm_checkpoint:
arg_overrides = None
if args.hack_layernorm_embedding:
arg_overrides = {"layernorm_embedding": False}
loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides
)
([roberta_enc], _cfg, _task) = loaded
else:
# Do we need to edit untie_weights here ?
share_in_out = (
args.share_decoder_input_output_embed or args.share_all_embeddings
)
args.untie_weights_roberta = not share_in_out
if args.hack_layernorm_embedding:
args.layernorm_embedding = False
args.encoder_normalize_before = False
roberta_enc = roberta.RobertaModel.build_model(args, task)

return cls.from_roberta(roberta_enc, args, task.source_dictionary)

@staticmethod
def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
encoder = roberta_enc.encoder.sentence_encoder
vocab_size, embed_dim = encoder.embed_tokens.weight.shape

if args.share_all_embeddings:
lm_head = roberta_enc.encoder.lm_head
assert encoder.embed_tokens.weight is lm_head.weight, (
"Can't use --share-all-embeddings with a model "
"that was pretraiend with --untie-weights-roberta_enc"
)
else:
lm_head = roberta.RobertaLMHead(
embed_dim, vocab_size, roberta_enc.args.activation_fn
)

dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
if args.share_all_embeddings or args.share_decoder_input_output_embed:
# Note: I wasn't able to use the Embedding _weight parameter to achieve this sharing.
dec_embs.weight = lm_head.weight

decoder = TransformerDecoder(
RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
dictionary,
dec_embs,
no_encoder_attn=False,
output_projection=lm_head,
)
if getattr(args, "pretrained_decoder", False):
decoder_dict = encoder.state_dict()

# TODO: hide setting "encoder_attn" layers behind a flag.
for k, w in list(decoder_dict.items()):
if ".self_attn" in k:
k_enc_attn = k.replace(".self_attn", ".encoder_attn")
decoder_dict[k_enc_attn] = w.detach().clone()

for k, w in lm_head.state_dict().items():
decoder_dict["output_projection." + k] = w

missing_keys, unexpected_keys = decoder.load_state_dict(
decoder_dict, strict=False
)
# missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
assert not missing_keys and not unexpected_keys, (
"Failed to load state dict. "
f"Missing keys: {missing_keys}. "
f"Unexpected keys: {unexpected_keys}."
)

if args.share_all_embeddings:
assert decoder.output_projection.weight is decoder.embed_tokens.weight
assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
elif args.share_decoder_input_output_embed:
assert decoder.output_projection.weight is decoder.embed_tokens.weight
assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
else:
assert decoder.output_projection.weight is not decoder.embed_tokens.weight
assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight

return RobertaEncDecModel(encoder, decoder)

@staticmethod
def read_args_from_roberta(roberta_args: argparse.Namespace):
# TODO: this would become easier if encoder/decoder were using a similar
# TransformerConfig object
args = argparse.Namespace(**vars(roberta_args))
attr_map = [
("encoder_attention_heads", "decoder_attention_heads"),
("encoder_embed_dim", "decoder_embed_dim"),
("encoder_embed_dim", "decoder_output_dim"),
("encoder_normalize_before", "decoder_normalize_before"),
("encoder_layers_to_keep", "decoder_layers_to_keep"),
("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
("encoder_layerdrop", "decoder_layerdrop"),
("encoder_layers", "decoder_layers"),
("encoder_learned_pos", "decoder_learned_pos"),
# should this be set from here ?
("max_positions", "max_target_positions"),
]
for k1, k2 in attr_map:
setattr(args, k2, getattr(roberta_args, k1))

args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta
return args

def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
super().upgrade_state_dict_named(state_dict, name)
old_keys = list(state_dict.keys())

# rename decoder -> encoder before upgrading children modules
for k in old_keys:
if k.startswith(prefix + "encoder.lm_head"):
state_dict.pop(k)
continue
new_k = k
new_k = new_k.replace(".sentence_encoder.", ".")
new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.")
if k == new_k:
continue
# print(k, "->", new_k)
state_dict[new_k] = state_dict.pop(k)


@register_model_architecture("roberta_enc_dec", "roberta_enc_dec")
def base_enc_dec_architecture(args):
args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False)
args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None)
args.pretrained_decoder = getattr(args, "pretrained_decoder", None)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)

roberta.base_architecture(args)
6 changes: 3 additions & 3 deletions fairseq/models/roberta/model.py
@@ -204,7 +204,7 @@ def forward(
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
**kwargs
**kwargs,
):
if classification_head_name is not None:
features_only = True
@@ -259,7 +259,7 @@ def from_pretrained(
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="gpt2",
**kwargs
**kwargs,
):
from fairseq import hub_utils

@@ -464,7 +464,7 @@ def forward(
features_only=False,
return_all_hiddens=False,
masked_tokens=None,
**unused
**unused,
):
"""
Args:
40 changes: 28 additions & 12 deletions fairseq/models/transformer.py
@@ -645,7 +645,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
(default: False).
"""

def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
@@ -727,7 +734,11 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
)

self.adaptive_softmax = None
self.output_projection = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(args, dictionary, embed_tokens)

def build_output_projection(self, args, dictionary, embed_tokens):
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
@@ -789,7 +800,7 @@ def forward(
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
@@ -802,6 +813,7 @@
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""

x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
@@ -810,6 +822,7 @@
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)

if not features_only:
x = self.output_layer(x)
return x, extra
@@ -866,9 +879,19 @@ def extract_features_scriptable(
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
bs, slen = prev_output_tokens.size()
if alignment_layer is None:
alignment_layer = self.num_layers - 1

enc: Optional[Tensor] = None
padding_mask: Optional[Tensor] = None
if encoder_out is not None:
enc = encoder_out["encoder_out"][0]
padding_mask = encoder_out["encoder_padding_mask"][0]
assert (
enc.size()[1] == bs
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"

# embed positions
positions = None
if self.embed_positions is not None:
@@ -916,15 +939,8 @@

x, layer_attn, _ = layer(
x,
encoder_out["encoder_out"][0]
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
else None,
encoder_out["encoder_padding_mask"][0]
if (
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
enc,
padding_mask,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
11 changes: 10 additions & 1 deletion fairseq/modules/multihead_attention.py
@@ -147,8 +147,16 @@ def forward(
is_tpu = query.device.type == "xla"

tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, key_embed_dim = key.size()
if not torch.jit.is_scripting():
assert (key_bsz, key_embed_dim) == (bsz, embed_dim)
assert value is not None
assert (src_len, bsz, embed_dim) == value.shape


if (
not self.onnx_trace
@@ -262,6 +270,7 @@ def forward(
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
@@ -290,7 +299,7 @@ def forward(
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
assert k.size(1) == src_len

# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.