Data2vec prelim (facebookresearch#2929)
Summary:
Preliminaries for the data2vec release, including some minor improvements and bug fixes.

The most important change is that we now default to raising an exception when fields in the config do not have a corresponding field in the model dataclass.
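
To illustrate the new default (a minimal sketch, not code from this commit): `merge_with_parent` in `fairseq/dataclass/utils.py` now defaults to `remove_missing=False`, and `models.build_model` forwards its new `from_checkpoint` flag, so strict merging applies when a model is built from a fresh config, while checkpoints loaded through `load_model_ensemble_and_task` keep the older, lenient behavior. The snippet below reproduces the strict case with plain OmegaConf, which backs `merge_with_parent`; `ModelConfig` and `encoder_heads` are made-up names for the example.

```python
# Sketch only: strict merge of a config into a model dataclass, mirroring the
# new remove_missing=False default. ModelConfig / encoder_heads are hypothetical.
from dataclasses import dataclass

from omegaconf import OmegaConf
from omegaconf.errors import OmegaConfBaseException


@dataclass
class ModelConfig:
    encoder_layers: int = 12


# A config carrying a field the model dataclass does not define.
overrides = OmegaConf.create({"encoder_layers": 24, "encoder_heads": 8})

# Structured configs are in struct mode, so unknown keys are rejected on merge.
base = OmegaConf.structured(ModelConfig)
try:
    merged = OmegaConf.merge(base, overrides)
except OmegaConfBaseException as err:
    # The stray field now raises instead of being dropped silently.
    print(f"rejected unknown config field: {err}")
```

When a checkpoint is loaded, `load_model_ensemble_and_task` calls `task.build_model(cfg.model, from_checkpoint=True)` (falling back to the old signature for external tasks), which re-enables the lenient path where fields absent from the dataclass are stripped before merging.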

Pull Request resolved: fairinternal/fairseq-py#2929

Reviewed By: wnhsu

Differential Revision: D33649708

Pulled By: alexeib

fbshipit-source-id: 629bdb4c361550740b451c570c2005bb956c6fcb
alexeib authored and facebook-github-bot committed Jan 20, 2022
1 parent a59cea5 commit 995c204
Showing 44 changed files with 296 additions and 97 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -134,3 +134,8 @@ experimental/*

# Weights and Biases logs
wandb/
+
+ # Hydra artifacts
+ nohup.out
+ multirun
+ outputs
@@ -161,7 +161,9 @@ def make_dataset(type, dictionary, data_split, combine):
split_path = get_path(type, data_split)

dataset = data_utils.load_indexed_dataset(
- split_path, dictionary, combine=combine,
+ split_path,
+ dictionary,
+ combine=combine,
)
return dataset

@@ -241,7 +243,8 @@ def load_split(data_split, metric):
"id": IdDataset(),
"net_input": {
"src_tokens": RightPadDataset(
- src_tokens, pad_idx=self.source_dictionary.pad(),
+ src_tokens,
+ pad_idx=self.source_dictionary.pad(),
),
"src_lengths": src_lengths,
},
@@ -250,11 +253,16 @@ def load_split(data_split, metric):
"target": label,
}

- dataset = NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],)
+ dataset = NestedDictionaryDataset(
+ dataset,
+ sizes=[src_tokens.sizes],
+ )

- assert len(dataset) % self.cfg.mt_beam == 0, (
- "dataset size (%d) is not a multiple of beam size (%d)"
- % (len(dataset), self.cfg.mt_beam)
+ assert (
+ len(dataset) % self.cfg.mt_beam == 0
+ ), "dataset size (%d) is not a multiple of beam size (%d)" % (
+ len(dataset),
+ self.cfg.mt_beam,
)

# no need to shuffle valid/test sets
@@ -270,7 +278,10 @@ def load_split(data_split, metric):
start_idx, (self.cfg.mt_beam, 1)
).transpose().reshape(-1)

- dataset = SortDataset(dataset, sort_order=[shuffle],)
+ dataset = SortDataset(
+ dataset,
+ sort_order=[shuffle],
+ )

logger.info(f"Loaded {split} with #samples: {len(dataset)}")

@@ -313,17 +324,21 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
"id": IdDataset(),
"net_input": {
"src_tokens": RightPadDataset(
- src_tokens, pad_idx=self.source_dictionary.pad(),
+ src_tokens,
+ pad_idx=self.source_dictionary.pad(),
),
"src_lengths": src_lengths,
},
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens, reduce=True),
}

- return NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes],)
+ return NestedDictionaryDataset(
+ dataset,
+ sizes=[src_tokens.sizes],
+ )

- def build_model(self, cfg: FairseqDataclass):
+ def build_model(self, cfg: FairseqDataclass, from_checkpoint: bool = False):
return super().build_model(cfg)

def build_generator(self, args):
2 changes: 1 addition & 1 deletion examples/laser/laser_src/laser_task.py
@@ -112,7 +112,7 @@ def setup_task(cls, args, **kwargs):
return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)

# Experimental overriding for backtranslation
- def build_model(self, args):
+ def build_model(self, args, from_checkpoint=False):
model = models.build_model(args, self)
return model

2 changes: 1 addition & 1 deletion examples/roberta/commonsense_qa/commonsense_qa_task.py
@@ -169,7 +169,7 @@ def binarize(s, append_bos=False):
self.datasets[split] = dataset
return self.datasets[split]

- def build_model(self, args):
+ def build_model(self, args, from_checkpoint=False):
from fairseq import models

model = models.build_model(args, self)
7 changes: 5 additions & 2 deletions examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml
@@ -5,8 +5,8 @@ ax_config:
max_trials: 128
early_stop:
minimize: true
- max_epochs_without_improvement: 32
- epsilon: 1.0e-05
+ max_epochs_without_improvement: 10
+ epsilon: 0.025
experiment:
name: ${dataset.gen_subset}
objective_name: wer
@@ -24,3 +24,6 @@ ax_config:
decoding.wordscore:
type: range
bounds: [-5.0, 5.0]
+ decoding.silweight:
+ type: range
+ bounds: [ -8.0, 0.0 ]
4 changes: 2 additions & 2 deletions examples/speech_recognition/new/conf/infer.yaml
@@ -8,15 +8,15 @@ hydra:
run:
dir: ${common_eval.results_path}/${dataset.gen_subset}
sweep:
- dir: ${common_eval.results_path}
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
subdir: ${dataset.gen_subset}
common_eval:
results_path: null
path: null
post_process: letter
quiet: true
dataset:
- max_tokens: 1000000
+ max_tokens: 3000000
gen_subset: test
distributed_training:
distributed_world_size: 1
2 changes: 2 additions & 0 deletions examples/speech_recognition/new/infer.py
@@ -378,6 +378,8 @@ def main(cfg: InferConfig) -> float:
if not cfg.common.cpu and not torch.cuda.is_available():
raise ValueError("CUDA not found; set `cpu=True` to run without CUDA")

+ logger.info(cfg.common_eval.path)
+
with InferenceProcessor(cfg) as processor:
for sample in processor:
processor.process_sample(sample)
@@ -101,7 +101,7 @@ def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):

super().__init__(cfg, src_dict, tgt_dict)

- def build_model(self, cfg):
+ def build_model(self, cfg, from_checkpoint=False):
from fairseq import models

model = models.build_model(cfg, self)
2 changes: 1 addition & 1 deletion examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py
@@ -441,7 +441,7 @@ def reduce_metrics(self, logging_outputs, criterion):
/ meters["nsentences"].sum,
)

- def build_model(self, cfg: FairseqDataclass):
+ def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
model = super().build_model(cfg)

return model
2 changes: 1 addition & 1 deletion examples/wav2vec/unsupervised/w2vu_generate.py
@@ -645,7 +645,7 @@ def main(cfg: UnsupGenerateConfig, model=None):

lm_ppl = max(cfg.min_lm_ppl, lm_ppl)

- if not cfg.unsupervised_tuning == 0:
+ if not cfg.unsupervised_tuning:
weighted_score = wer
else:
weighted_score = math.log(lm_ppl) * (vt_diff or 1.0)
17 changes: 13 additions & 4 deletions fairseq/checkpoint_utils.py
@@ -6,6 +6,7 @@
import ast
import collections
import contextlib
+ import inspect
import logging
import os
import re
@@ -111,7 +112,7 @@ def is_better(a, b):
checkpoints = [
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
]
- if len(checkpoints) > 0:
+ if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
if cfg.write_checkpoints_asynchronously:
@@ -215,7 +216,9 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
)
first_launch = not PathManager.exists(checkpoint_path)
- if cfg.finetune_from_model is not None and first_launch:
+ if first_launch and cfg.get("continue_once", None) is not None:
+ checkpoint_path = cfg.continue_once
+ elif cfg.finetune_from_model is not None and first_launch:
# if there is no last checkpoint to restore, start the finetune from pretrained model
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
if PathManager.exists(cfg.finetune_from_model):
@@ -460,7 +463,13 @@ def load_model_ensemble_and_task(
)
else:
# model parallel checkpoint or unsharded checkpoint
- model = task.build_model(cfg.model)
+ # support old external tasks
+
+ argspec = inspect.getfullargspec(task.build_model)
+ if "from_checkpoint" in argspec.args:
+ model = task.build_model(cfg.model, from_checkpoint=True)
+ else:
+ model = task.build_model(cfg.model)
if (
"optimizer_history" in state
and len(state["optimizer_history"]) > 0
@@ -605,7 +614,7 @@ def _upgrade_state_dict(state):
# use stateful training data iterator
if "train_iterator" not in state["extra_state"]:
state["extra_state"]["train_iterator"] = {
"epoch": state["extra_state"]["epoch"],
"epoch": state["extra_state"].get("epoch", 0),
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
}

2 changes: 1 addition & 1 deletion fairseq/criterions/ctc.py
@@ -83,7 +83,7 @@ def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask):
cfg.wer_word_score,
) = eval(cfg.wer_args)

- if cfg.wer_kenlm_model is not None:
+ if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

dec_args = Namespace()
21 changes: 19 additions & 2 deletions fairseq/criterions/model_criterion.py
@@ -7,6 +7,8 @@
from dataclasses import dataclass, field
from typing import Dict, List

+ import torch
+
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@@ -49,7 +51,6 @@ def __init__(self, task, loss_weights=None, log_keys=None):
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])

- sample_size = net_output["sample_size"]
scaled_losses = {}

if hasattr(model, "get_losses"):
@@ -71,6 +72,12 @@ def forward(self, model, sample, reduce=True):
scaled_losses[lk] = coef * p.float()

loss = sum(scaled_losses.values())
+
+ if "sample_size" in net_output:
+ sample_size = net_output["sample_size"]
+ else:
+ sample_size = loss.numel()
+
if reduce and loss.numel() > 1:
loss = loss.sum()

@@ -84,12 +91,22 @@ def forward(self, model, sample, reduce=True):

for lk in self.log_keys:
if lk in net_output and net_output[lk] is not None:
- logging_output[lk] = float(net_output[lk])
+ if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1:
+ logging_output[lk] = float(net_output[lk])
+ else:
+ for i, v in enumerate(net_output[lk]):
+ logging_output[f"{lk}_{i}"] = float(v)

if len(scaled_losses) > 1:
for lk, l in scaled_losses.items():
+ if l.numel() > 1:
+ l = l.sum()
logging_output[f"loss_{lk}"] = l.item()

if "logs" in net_output:
for lgw in net_output["logs"]:
logging_output[lgw] = net_output["logs"][lgw]

return loss, sample_size, logging_output

@staticmethod
10 changes: 9 additions & 1 deletion fairseq/data/data_utils.py
@@ -400,6 +400,8 @@ def compute_mask_indices(
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
+ require_same_masks: bool = True,
+ pct_holes: float = 0.0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
@@ -510,8 +512,14 @@ def arrange(s, e, length, keep_length):

min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
- if len(mask_idc) > min_len:
+ if len(mask_idc) > min_len and require_same_masks:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ if pct_holes > 0:
+ num_holes = np.rint(len(mask_idc) * pct_holes).astype(int)
+ mask_idc = np.random.choice(
+ mask_idc, len(mask_idc) - num_holes, replace=False
+ )
+
mask[i, mask_idc] = True

return mask
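
As a usage sketch for the two new `compute_mask_indices` arguments (not part of this commit's diff): `require_same_masks` controls whether every row keeps the same number of masked timesteps, and `pct_holes` randomly un-masks a fraction of the chosen indices per row. The call below assumes the function's existing leading parameters (`shape`, `padding_mask`, `mask_prob`, `mask_length`).

```python
from fairseq.data.data_utils import compute_mask_indices

# Boolean (batch, time) mask for wav2vec-style span masking. With
# require_same_masks=False each row may keep a different number of masked
# steps, and pct_holes=0.25 drops a quarter of the selected indices again.
mask = compute_mask_indices(
    shape=(4, 500),
    padding_mask=None,
    mask_prob=0.65,
    mask_length=10,
    require_same_masks=False,
    pct_holes=0.25,
)
print(mask.shape, mask.sum(axis=1))  # per-row mask counts will differ
```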
6 changes: 6 additions & 0 deletions fairseq/dataclass/configs.py
@@ -639,6 +639,12 @@ class CheckpointConfig(FairseqDataclass):
"(default: <save-dir>/checkpoint_last.pt"
},
)
+ continue_once: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "continues from this checkpoint, unless a checkpoint indicated in 'restore_file' option is present"
+ },
+ )
finetune_from_model: Optional[str] = field(
default=None,
metadata={
2 changes: 1 addition & 1 deletion fairseq/dataclass/utils.py
@@ -474,7 +474,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
cfg[k] = overrides[k]


- def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=True):
+ def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False):
if remove_missing:

if is_dataclass(dc):
5 changes: 3 additions & 2 deletions fairseq/models/__init__.py
@@ -7,6 +7,7 @@
import argparse
import importlib
import os
+
from contextlib import ExitStack

from fairseq.dataclass import FairseqDataclass
@@ -52,7 +53,7 @@
]


- def build_model(cfg: FairseqDataclass, task):
+ def build_model(cfg: FairseqDataclass, task, from_checkpoint=False):

model = None
model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)
@@ -86,7 +87,7 @@ def build_model(cfg: FairseqDataclass, task):
if isinstance(cfg, argparse.Namespace):
cfg = dc.from_namespace(cfg)
else:
- cfg = merge_with_parent(dc(), cfg)
+ cfg = merge_with_parent(dc(), cfg, from_checkpoint)
else:
if model_type in ARCH_CONFIG_REGISTRY:
with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack():
(Diffs for the remaining changed files are not shown.)
