hydra fairseq 3 - inherit from legacy for fairseq classes
Summary: hydra fairseq 3 - inherit from legacy for fairseq classes

Reviewed By: alexeib

Differential Revision: D23375457

fbshipit-source-id: ef9d19f2d02f2326eea44a70f1f6e1668b420840
Mu Tian authored and facebook-github-bot committed Sep 10, 2020
1 parent df45f42 commit 42c5dcb
Showing 40 changed files with 257 additions and 73 deletions.
113 changes: 113 additions & 0 deletions docs/hydra_integration.md
@@ -0,0 +1,113 @@


## Hydra

Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads.

## Train models with the hydra interface

#### Provide parameters in `.yaml` files
For example, if we'd like to train a transformer language model, we can provide the parameters in `.yaml` files. Note that the modules used in training (task, model, criterion, optimizer, lr scheduler) must already be migrated to the hydra interface (see the section below).

- Provide the top-level choices of which generic parameter file and which modules to use in `config/config.yaml`, which will look like, for example:

```
defaults:
- params: training_params
- task: language_modeling
- model: transformer_lm
- criterion: cross_entropy
- optimizer: adam
- lr_scheduler: inverse_sqrt
```

- Provide generic parameters common across different training jobs: `config/params/training_params.yaml`
- Provide task parameters: `config/task/language_modeling.yaml`
- Provide model parameters: `config/model/transformer_lm.yaml`
- Provide criterion parameters: `config/criterion/cross_entropy.yaml`
- Provide optimizer parameters: `config/optimizer/adam.yaml`
- Provide lr_scheduler parameters: `config/lr_scheduler/inverse_sqrt.yaml`

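As an illustration, one of these group files might look like the sketch below. The field names and values are assumptions for this example rather than the file shipped with fairseq; the two fields shown simply mirror the command-line overrides used later in this document.

```
# hypothetical contents of config/lr_scheduler/inverse_sqrt.yaml
warmup_updates: 4000
warmup_init_lr: 1e-07
```
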
#### Command line overriding
`train_hydra.py` is the main entry point for training with the hydra interface. If we specify all the parameters we want in `.yaml` files, we can simply run:

```
# task.data is a required field, marked by `???` in the yaml
python fairseq_cli/train_hydra.py \
task.data=/private/home/abaevski/data/wiki103
```

Alternatively, if we need to override certain parameters from the command line, we can do so as below (note the config group that each parameter sits under):

```
python fairseq_cli/train_hydra.py \
params=training_params \
task=language_modeling \
task.data=/private/home/abaevski/data/wiki103 \
task.tokens_per_sample=512 \
task.sample_break_mode=none \
model=transformer_lm \
model.share_decoder_input_output_embed=true \
model.dropout=0.1 \
optimizer=adam \
optimizer.adam_betas="'(0.9, 0.98)'" \
optimizer.weight_decay=0.01 \
lr_scheduler=inverse_sqrt \
lr_scheduler.warmup_updates=4000 \
lr_scheduler.warmup_init_lr=1e-07 \
criterion=cross_entropy \
params.common.fp16=true \
params.common.log_format=json \
params.common.log_interval=1 \
params.dataset.max_tokens=1024 \
params.dataset.num_workers=4 \
params.optimization.update_freq=[16] \
params.optimization.max_update=50000 \
params.optimization.clip_norm=0.0 \
params.optimization.lr=[0.0005] \
params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
params.checkpoint.save_interval_updates=10
```

## Migrate existing modules / create new modules with the hydra interface

For each module we want to migrate to (or create with) the hydra interface, we fundamentally need to:

- Provide a dataclass that lays out the parameters used in the module.

- Modify the builder and/or constructor that previously took an `argparse.Namespace` argument `args` so that it takes an `omegaconf.DictConfig` config object instead. For now we allow `Union[omegaconf.DictConfig, argparse.Namespace]` to support backward compatibility.

- For `add_args()`, extract the arguments from the dataclass defined in the same file and append them to `parser`; this also preserves backward compatibility. This is easily done with the `gen_parser_from_dataclass` API; see the example files below and the sketch after this list.

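A minimal sketch of these three steps for a hypothetical module is shown below. It assumes `gen_parser_from_dataclass` can be imported from fairseq's dataclass utilities (the exact import path may differ) and uses illustrative field names; it is not one of the migrated files listed in the next section.

```
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Union

from omegaconf import DictConfig

# assumed import path for the helper mentioned above
from fairseq.dataclass.utils import gen_parser_from_dataclass


@dataclass
class MyCriterionConfig:
    # 1) dataclass laying out the parameters used by the module
    label_smoothing: float = field(
        default=0.0, metadata={"help": "epsilon for label smoothing"}
    )


class MyCriterion:
    # 2) constructor takes a DictConfig; Namespace is still accepted for compatibility
    def __init__(self, cfg: Union[DictConfig, Namespace]):
        self.eps = cfg.label_smoothing

    @staticmethod
    def add_args(parser):
        # 3) keep argparse support by generating arguments from the dataclass
        gen_parser_from_dataclass(parser, MyCriterionConfig())
```
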
#### Migrated examples:

- Task: `fairseq/tasks/language_modeling.py`

- Model: `fairseq/models/transformer_lm.py`

- Criterion: `fairseq/criterions/adaptive_loss.py` and `fairseq/criterions/cross_entropy.py`

- Optimizer: `fairseq/optim/adam.py` and `fairseq/optim/nag.py`

- LR scheduler: `fairseq/optim/lr_scheduler/cosine_lr_scheduler.py` and `fairseq/optim/lr_scheduler/inverse_square_root_schedule.py`


## Interpolate parameters across different places

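Hydra configs support OmegaConf-style interpolation, where one field references another with the `${...}` syntax, so a value defined in one config group can be reused elsewhere instead of being repeated. A minimal sketch, assuming the config groups above (the specific keys are illustrative):

```
# hypothetical fragment of config/model/transformer_lm.yaml
# reuse the value defined in the task group instead of repeating it
tokens_per_sample: ${task.tokens_per_sample}
```
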
## Support of legacy interface
If you would still like to pass legacy-style arguments on the command line, `fairseq_cli/train.py` supports this. Internally it converts `args` into hydra config objects wherever migrated modules are available.

```
python fairseq_cli/train.py --task language_modeling \
/private/home/abaevski/data/wiki103 \
--save-dir /checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
--arch transformer_lm --share-decoder-input-output-embed \
--dropout 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--tokens-per-sample 512 --sample-break-mode none \
--max-tokens 1024 --update-freq 16 \
--fp16 \
--max-update 50000 --log-format json --log-interval 1 --num-workers 4 \
--save-interval-updates 10
```
4 changes: 2 additions & 2 deletions examples/roberta/commonsense_qa/commonsense_qa_task.py
@@ -22,11 +22,11 @@
RightPadDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask


@register_task('commonsense_qa')
class CommonsenseQATask(FairseqTask):
class CommonsenseQATask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Commonsense QA."""

@staticmethod
4 changes: 2 additions & 2 deletions examples/roberta/wsc/wsc_task.py
@@ -24,13 +24,13 @@
PadDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask

from . import wsc_utils


@register_task('wsc')
class WSCTask(FairseqTask):
class WSCTask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""

@staticmethod
4 changes: 2 additions & 2 deletions examples/speech_recognition/tasks/speech_recognition.py
@@ -10,7 +10,7 @@

import torch
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol

@@ -66,7 +66,7 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict):


@register_task("speech_recognition")
class SpeechRecognitionTask(FairseqTask):
class SpeechRecognitionTask(LegacyFairseqTask):
"""
Task for training speech recognition model.
"""
4 changes: 2 additions & 2 deletions fairseq/benchmark/dummy_lm.py
@@ -9,14 +9,14 @@
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask


logger = logging.getLogger(__name__)


@register_task('dummy_lm')
class DummyLMTask(FairseqTask):
class DummyLMTask(LegacyFairseqTask):

@staticmethod
def add_args(parser):
4 changes: 2 additions & 2 deletions fairseq/benchmark/dummy_masked_lm.py
@@ -9,14 +9,14 @@
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask


logger = logging.getLogger(__name__)


@register_task('dummy_masked_lm')
class DummyMaskedLMTask(FairseqTask):
class DummyMaskedLMTask(LegacyFairseqTask):

@staticmethod
def add_args(parser):
4 changes: 2 additions & 2 deletions fairseq/benchmark/dummy_mt.py
@@ -9,14 +9,14 @@
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask


logger = logging.getLogger(__name__)


@register_task('dummy_mt')
class DummyMTTask(FairseqTask):
class DummyMTTask(LegacyFairseqTask):

@staticmethod
def add_args(parser):
2 changes: 1 addition & 1 deletion fairseq/optim/__init__.py
@@ -7,7 +7,7 @@
import os

from fairseq import registry
from fairseq.optim.fairseq_optimizer import FairseqOptimizer
from fairseq.optim.fairseq_optimizer import FairseqOptimizer, LegacyFairseqOptimizer # noqa
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from fairseq.optim.bmuf import FairseqBMUF # noqa
from fairseq.optim.shard import shard_
4 changes: 2 additions & 2 deletions fairseq/optim/adadelta.py
@@ -5,11 +5,11 @@

import torch.optim

from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer


@register_optimizer('adadelta')
class Adadelta(FairseqOptimizer):
class Adadelta(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
4 changes: 2 additions & 2 deletions fairseq/optim/adafactor.py
@@ -7,11 +7,11 @@
import torch
import torch.optim

from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer


@register_optimizer('adafactor')
class FairseqAdafactor(FairseqOptimizer):
class FairseqAdafactor(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = Adafactor(params, **self.optimizer_config)
4 changes: 2 additions & 2 deletions fairseq/optim/adagrad.py
@@ -5,11 +5,11 @@

import torch.optim

from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer


@register_optimizer('adagrad')
class Adagrad(FairseqOptimizer):
class Adagrad(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
4 changes: 2 additions & 2 deletions fairseq/optim/adamax.py
@@ -6,11 +6,11 @@
import torch
import torch.optim

from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer


@register_optimizer('adamax')
class FairseqAdamax(FairseqOptimizer):
class FairseqAdamax(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = Adamax(params, **self.optimizer_config)
6 changes: 6 additions & 0 deletions fairseq/optim/fairseq_optimizer.py
@@ -140,3 +140,9 @@ def supports_flat_params(self):

def average_params(self):
pass


class LegacyFairseqOptimizer(FairseqOptimizer):

def __init__(self, args):
self.args = args
4 changes: 2 additions & 2 deletions fairseq/optim/fused_lamb.py
@@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.optim import FairseqOptimizer, register_optimizer
from fairseq.optim import register_optimizer, LegacyFairseqOptimizer


@register_optimizer('lamb')
class FairseqLAMB(FairseqOptimizer):
class FairseqLAMB(LegacyFairseqOptimizer):
"""LAMB optimizer."""

def __init__(self, args, params):
2 changes: 1 addition & 1 deletion fairseq/optim/lr_scheduler/__init__.py
@@ -7,7 +7,7 @@
import os

from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler, LegacyFairseqLRScheduler # noqa


build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry(
11 changes: 11 additions & 0 deletions fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from .. import FairseqOptimizer
from argparse import Namespace


class FairseqLRScheduler(object):
@@ -40,3 +41,13 @@ def step(self, epoch, val_loss=None):
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.optimizer.get_lr()


class LegacyFairseqLRScheduler(FairseqLRScheduler):

def __init__(self, args: Namespace, optimizer):
if not isinstance(optimizer, FairseqOptimizer):
raise ValueError('optimizer must be an instance of FairseqOptimizer')
self.args = args
self.optimizer = optimizer
self.best = None
4 changes: 2 additions & 2 deletions fairseq/optim/lr_scheduler/fixed_schedule.py
@@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler


@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
class FixedSchedule(LegacyFairseqLRScheduler):
"""Decay the LR on a fixed schedule."""

def __init__(self, args, optimizer):
4 changes: 2 additions & 2 deletions fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
@@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler


@register_lr_scheduler('polynomial_decay')
class PolynomialDecaySchedule(FairseqLRScheduler):
class PolynomialDecaySchedule(LegacyFairseqLRScheduler):
"""Decay the LR on a fixed schedule."""

def __init__(self, args, optimizer):
4 changes: 2 additions & 2 deletions fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
@@ -5,11 +5,11 @@

import torch.optim.lr_scheduler

from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler


@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
class ReduceLROnPlateau(LegacyFairseqLRScheduler):
"""
Decay the LR by a factor every time the validation loss plateaus.
Also comes with optional warmup phase, where we linearly increase
4 changes: 2 additions & 2 deletions fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py
@@ -3,12 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler
import math


@register_lr_scheduler('tri_stage')
class TriStageLRSchedule(FairseqLRScheduler):
class TriStageLRSchedule(LegacyFairseqLRScheduler):
"""Tristage learning rate schedulr
Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf