RXF OS Implementation (facebookresearch#2455)
Summary:

## What does this PR do?

Implements R3F and R4F coming from Facebook Research: https://arxiv.org/abs/2008.03156

This code was used to generate all the results from the paper excluding the probing results.

Pull Request resolved: facebookresearch#2455

Reviewed By: myleott

Differential Revision: D23444863

Pulled By: AkshatSh

fbshipit-source-id: b724a6d6cc9cebfdb4bd219828afbb5679f2259b
1 parent 698820b · commit f2fa071
Showing 11 changed files with 483 additions and 20 deletions.
@@ -0,0 +1,52 @@
[Better Fine-Tuning by Reducing Representational Collapse](https://arxiv.org/abs/2008.03156)
=====================

This repo contains the code to replicate all experiments from the _Better Fine-Tuning by Reducing Representational Collapse_ paper, excluding the probing results.

The R3F sentence prediction criterion is registered as `sentence_prediction_r3f`, while its label-smoothed counterpart is registered as `label_smoothed_cross_entropy_r3f`. The R4F version of the sentence prediction criterion can be obtained by applying spectral norm to the classification head via the `--spectral-norm-classification-head` parameter.

## Hyper-parameters

Our methods introduce three new hyper-parameters: `--eps`, which sets the standard deviation or range of the distribution we sample noise from; `--r3f-lambda`, which controls how the logistic loss and the noisy KL loss are combined; and `--noise-type`, which selects the parametric distribution used (`normal` or `uniform`).
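
As a rough illustration of how these options are used by the criteria in this commit, the sketch below builds the corresponding noise sampler and perturbs a batch of token embeddings; the tensor shapes and values are placeholders, not taken from any actual model.

```python
import torch

eps = 1e-5              # --eps
noise_type = "uniform"  # --noise-type: 'normal' or 'uniform'

# Same samplers as in the R3F criteria: Normal(0, eps) or Uniform(-eps, eps).
if noise_type == "normal":
    noise_sampler = torch.distributions.normal.Normal(loc=0.0, scale=eps)
else:
    noise_sampler = torch.distributions.uniform.Uniform(low=-eps, high=eps)

# Placeholder embeddings: (batch, sequence length, embedding dim).
token_embeddings = torch.randn(8, 128, 1024)
noise = noise_sampler.sample(sample_shape=token_embeddings.shape).to(token_embeddings)
noised_embeddings = token_embeddings + noise  # fed through a second forward pass
```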

For example, to run R3F on RTE from GLUE:

```
TOTAL_NUM_UPDATES=3120
WARMUP_UPDATES=187
LR=1e-05
NUM_CLASSES=2
MAX_SENTENCES=8   # Batch size.
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin \
    --restore-file $ROBERTA_PATH \
    --max-positions 512 \
    --max-sentences $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 --separator-token 2 \
    --arch roberta_large \
    --criterion sentence_prediction_r3f \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --noise-type uniform --r3f-lambda 0.7 \
    --user-dir examples/rxf;
```

## Citation

```bibtex
@article{aghajanyan2020better,
  title={Better Fine-Tuning by Reducing Representational Collapse},
  author={Aghajanyan, Armen and Shrivastava, Akshat and Gupta, Anchit and Goyal, Naman and Zettlemoyer, Luke and Gupta, Sonal},
  journal={arXiv preprint arXiv:2008.03156},
  year={2020}
}
```
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import src  # noqa
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

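# Importing the criterion modules runs their `@register_criterion` decorators,
# which is what makes `sentence_prediction_r3f` and
# `label_smoothed_cross_entropy_r3f` available to fairseq via `--user-dir`.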
from . import label_smoothed_cross_entropy_r3f, sentence_prediction_r3f  # noqa
@@ -0,0 +1,157 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss


@register_criterion("label_smoothed_cross_entropy_r3f")
class LabelSmoothedCrossEntropyR3FCriterion(FairseqCriterion):
    def __init__(
        self, task, sentence_avg, label_smoothing, eps, r3f_lambda, noise_type
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.label_smoothing = label_smoothing
        self.eps = eps
        self.r3f_lambda = r3f_lambda
        self.noise_type = noise_type
        if self.noise_type in {"normal"}:
            self.noise_sampler = torch.distributions.normal.Normal(
                loc=0.0, scale=self.eps
            )
        elif self.noise_type == "uniform":
            self.noise_sampler = torch.distributions.uniform.Uniform(
                low=-self.eps, high=self.eps
            )
        else:
            raise Exception(f"unrecognized noise type {self.noise_type}")

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--eps', type=float, default=1e-5,
                            help='noise eps')
        parser.add_argument('--r3f-lambda', type=float, default=1.0,
                            help='lambda for combining logistic loss and noisy KL loss')
        parser.add_argument('--noise-type', type=str, default='normal',
                            choices=['normal', 'uniform'],
                            help='type of noises')
        # fmt: on

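    # Symmetric KL between the noised and clean output distributions:
    # KL(noised || clean) + KL(clean || noised), summed over all positions and
    # normalized by the leading (batch) dimension. The positional arguments
    # `None, None, "sum"` passed to F.kl_div are size_average, reduce, and
    # reduction, i.e. both terms use reduction="sum".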
    def _get_symm_kl(self, noised_logits, input_logits):
        return (
            F.kl_div(
                F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
                F.softmax(input_logits, dim=-1, dtype=torch.float32),
                None,
                None,
                "sum",
            )
            + F.kl_div(
                F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
                F.softmax(noised_logits, dim=-1, dtype=torch.float32),
                None,
                None,
                "sum",
            )
        ) / noised_logits.size(0)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        token_embeddings = model.encoder.embed_tokens(sample["net_input"]["src_tokens"])
        input_logits, extra = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(
            model, (input_logits, extra), sample, reduce=reduce
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

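        # R3F term: perturb the token embeddings with sampled noise, run a second
        # forward pass, and measure the symmetric KL between the noised and clean
        # output distributions.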
        if model.training:
            noise = self.noise_sampler.sample(sample_shape=token_embeddings.shape).to(
                token_embeddings
            )
            noised_embeddings = token_embeddings.clone() + noise

            noised_logits, _ = model(
                **sample["net_input"], token_embeddings=noised_embeddings
            )
            symm_kl = self._get_symm_kl(noised_logits, input_logits)

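        # The KL term is scaled by sample_size before mixing with --r3f-lambda,
        # matching the scale of the summed task loss (fairseq later divides the
        # total loss by sample_size when computing gradients).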
        if model.training:
            symm_kl = symm_kl * sample_size
            loss = loss + self.r3f_lambda * symm_kl

        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        if model.training:
            logging_output.update(
                symm_kl=utils.item(symm_kl.data) if reduce else symm_kl.data
            )

        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1, 1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.label_smoothing,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        symm_kl_sum = sum(log.get("symm_kl", 0) for log in logging_outputs)

        metrics.log_scalar("symm_kl", symm_kl_sum / sample_size, sample_size, round=3)
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True