Commit
Summary:

# Before submitting

- [x] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests? Too many of them, actually ^^

## What does this PR do?

This is a rewrite of fairinternal/fairseq-py#1538 following the discussion there, taking into account the proposed fairinternal/fairseq-py#1560 from Myle. It brings online backtranslation to fairseq.

It also adds a RobertaEncDec model to fairseq. RobertaEncDec can be built from a pretrained RoBERTa model, which enables transfer learning. This is crucial for backtranslation.

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there is a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding!

Pull Request resolved: fairinternal/fairseq-py#1614

Reviewed By: myleott

Differential Revision: D27157296

Pulled By: gwenzek

fbshipit-source-id: 43020bc27743419bd4b138716165bf5764117c21
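To make the transfer-learning path concrete, here is a minimal sketch of how the new architecture could be instantiated from a pretrained RoBERTa checkpoint. It only relies on the entry points added in this commit (`--pretrained-mlm-checkpoint`, `RobertaEncDecModel.build_model`); the module path, the checkpoint path, and the pre-built `task` object are assumptions, not part of the commit.

```python
# Sketch only: assumes the new model lives under fairseq.models.roberta.enc_dec,
# that /path/to/roberta.pt is a compatible MLM checkpoint, and that `task` is a
# fairseq translation task whose source_dictionary matches that checkpoint.
import argparse

from fairseq.models.roberta.enc_dec import RobertaEncDecModel  # path is an assumption

args = argparse.Namespace(
    pretrained_mlm_checkpoint="/path/to/roberta.pt",  # placeholder path
    pretrained_decoder=True,   # also warm-start the decoder from the encoder weights
    share_all_embeddings=True, # tie encoder, decoder and output embeddings
                               # (requires a checkpoint trained without --untie-weights-roberta)
    hack_layernorm_embedding=False,
)

# build_model fills in the remaining defaults via base_enc_dec_architecture,
# loads the MLM checkpoint, and wraps it into an encoder-decoder model.
model = RobertaEncDecModel.build_model(args, task)
```

From there, training would use the `roberta_enc_dec` architecture registered below.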
1 parent 7dafb05 · commit c2e8904 · showing 16 changed files with 1,472 additions and 28 deletions.
```diff
@@ -296,6 +296,8 @@ def __init__(
                 **kwargs,
             )
         )
+        self.sizes = src_dataset.sizes
+
     def __getitem__(self, index):
         """
```
New file, 192 added lines:

```python
import argparse
import logging

import torch.nn as nn
import fairseq.checkpoint_utils
from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerDecoder
from fairseq.models.roberta import model as roberta

logger = logging.getLogger(__name__)


@register_model("roberta_enc_dec")
class RobertaEncDecModel(FairseqEncoderDecoderModel):
    @staticmethod
    def add_args(parser):
        parser.add_argument(
            "--pretrained-mlm-checkpoint",
            default=None,
            type=str,
            metavar="PRETRAINED",
            help="path to pretrained mlm checkpoint",
        )
        parser.add_argument(
            "--pretrained-decoder", action="store_true", help="reload decoder"
        )
        parser.add_argument(
            "--hack-layernorm-embedding",
            action="store_true",
            help="hack to reload old models trained with encoder-normalize-before=False "
            "(no equivalent to encoder-normalize-before=False and layernorm_embedding=False)",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-all-embeddings",
            action="store_true",
            help="share encoder, decoder and output embeddings"
            " (requires shared dictionary and embed dim)",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_enc_dec_architecture(args)
        if args.pretrained_mlm_checkpoint:
            arg_overrides = None
            if args.hack_layernorm_embedding:
                arg_overrides = {"layernorm_embedding": False}
            loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides
            )
            ([roberta_enc], _cfg, _task) = loaded
        else:
            # Do we need to edit untie_weights here?
            share_in_out = (
                args.share_decoder_input_output_embed or args.share_all_embeddings
            )
            args.untie_weights_roberta = not share_in_out
            if args.hack_layernorm_embedding:
                args.layernorm_embedding = False
                args.encoder_normalize_before = False
            roberta_enc = roberta.RobertaModel.build_model(args, task)

        return cls.from_roberta(roberta_enc, args, task.source_dictionary)

    @staticmethod
    def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
        encoder = roberta_enc.encoder.sentence_encoder
        vocab_size, embed_dim = encoder.embed_tokens.weight.shape

        if args.share_all_embeddings:
            lm_head = roberta_enc.encoder.lm_head
            assert encoder.embed_tokens.weight is lm_head.weight, (
                "Can't use --share-all-embeddings with a model "
                "that was pretrained with --untie-weights-roberta"
            )
        else:
            lm_head = roberta.RobertaLMHead(
                embed_dim, vocab_size, roberta_enc.args.activation_fn
            )

        dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
        if args.share_all_embeddings or args.share_decoder_input_output_embed:
            # Note: I wasn't able to use the Embedding _weight parameter to achieve this sharing.
            dec_embs.weight = lm_head.weight

        decoder = TransformerDecoder(
            RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
            dictionary,
            dec_embs,
            no_encoder_attn=False,
            output_projection=lm_head,
        )
        if getattr(args, "pretrained_decoder", False):
            decoder_dict = encoder.state_dict()

            # TODO: hide setting "encoder_attn" layers behind a flag.
            for k, w in list(decoder_dict.items()):
                if ".self_attn" in k:
                    k_enc_attn = k.replace(".self_attn", ".encoder_attn")
                    decoder_dict[k_enc_attn] = w.detach().clone()

            for k, w in lm_head.state_dict().items():
                decoder_dict["output_projection." + k] = w

            missing_keys, unexpected_keys = decoder.load_state_dict(
                decoder_dict, strict=False
            )
            # missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
            assert not missing_keys and not unexpected_keys, (
                "Failed to load state dict. "
                f"Missing keys: {missing_keys}. "
                f"Unexpected keys: {unexpected_keys}."
            )

        if args.share_all_embeddings:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
        elif args.share_decoder_input_output_embed:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
        else:
            assert decoder.output_projection.weight is not decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight

        return RobertaEncDecModel(encoder, decoder)

    @staticmethod
    def read_args_from_roberta(roberta_args: argparse.Namespace):
        # TODO: this would become easier if encoder/decoder were using a similar
        # TransformerConfig object
        args = argparse.Namespace(**vars(roberta_args))
        attr_map = [
            ("encoder_attention_heads", "decoder_attention_heads"),
            ("encoder_embed_dim", "decoder_embed_dim"),
            ("encoder_embed_dim", "decoder_output_dim"),
            ("encoder_normalize_before", "decoder_normalize_before"),
            ("encoder_layers_to_keep", "decoder_layers_to_keep"),
            ("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
            ("encoder_layerdrop", "decoder_layerdrop"),
            ("encoder_layers", "decoder_layers"),
            ("encoder_learned_pos", "decoder_learned_pos"),
            # should this be set from here?
            ("max_positions", "max_target_positions"),
        ]
        for k1, k2 in attr_map:
            setattr(args, k2, getattr(roberta_args, k1))

        args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
        args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
        args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta
        return args

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        super().upgrade_state_dict_named(state_dict, name)
        old_keys = list(state_dict.keys())

        # drop the pretrained LM head and remap keys from the RoBERTa checkpoint
        # layout to the encoder-decoder layout
        for k in old_keys:
            if k.startswith(prefix + "encoder.lm_head"):
                state_dict.pop(k)
                continue
            new_k = k
            new_k = new_k.replace(".sentence_encoder.", ".")
            new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.")
            if k == new_k:
                continue
            # print(k, "->", new_k)
            state_dict[new_k] = state_dict.pop(k)


@register_model_architecture("roberta_enc_dec", "roberta_enc_dec")
def base_enc_dec_architecture(args):
    args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False)
    args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None)
    args.pretrained_decoder = getattr(args, "pretrained_decoder", None)
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )

    roberta.base_architecture(args)
```
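A side note on the embedding sharing in `from_roberta`: assigning the `nn.Parameter` directly (rather than passing `_weight` to `nn.Embedding`, which the in-code note says did not work) is what makes the identity checks at the end of that method hold. A tiny self-contained illustration, not part of the commit and with made-up sizes:

```python
import torch
import torch.nn as nn

vocab_size, embed_dim, pad_idx = 32, 8, 1  # toy sizes, chosen arbitrarily
lm_head_weight = nn.Parameter(torch.randn(vocab_size, embed_dim))

dec_embs = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
dec_embs.weight = lm_head_weight            # same Parameter object, not a copy
assert dec_embs.weight is lm_head_weight    # mirrors the asserts in from_roberta

# gradients flow into the single shared tensor
dec_embs(torch.tensor([2, 3])).sum().backward()
assert lm_head_weight.grad is not None
```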