Skip to content

Commit

Permalink
Adding Hydra based trainer target to fairseq in fbcode
Browse files Browse the repository at this point in the history
Summary: Adding the fairseq entrypoint section of the e2e pipeline so that a FairseqConfig passed to hydra_main runs smoothly

Reviewed By: jieru-hu

Differential Revision: D29714729

fbshipit-source-id: e3694e0037bb4c4f69208c1d6ec7df91d42fb588
  • Loading branch information
EdanSneh authored and facebook-github-bot committed Aug 3, 2021
1 parent db4f96b commit fe15926
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
2 changes: 1 addition & 1 deletion fairseq/dataclass/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def add_defaults(cfg: DictConfig) -> None:
field_cfg = DictConfig({"_name": field_cfg})
field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"]

name = field_cfg.get("_name")
name = getattr(field_cfg, "_name", None)

if k == "task":
dc = TASK_DATACLASS_REGISTRY.get(name)
Expand Down
22 changes: 15 additions & 7 deletions fairseq_cli/hydra_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fairseq_cli.train import main as pre_main
from fairseq import distributed_utils, metrics
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import omegaconf_no_object_check
from fairseq.utils import reset_logging

import hydra
Expand All @@ -24,25 +25,32 @@

@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    """Hydra entry point for fairseq training.

    Delegates all work to ``_hydra_main`` and returns its result.

    Args:
        cfg: Full fairseq configuration populated by Hydra.

    Returns:
        Whatever ``_hydra_main`` returns (annotated ``float``; presumably an
        objective value for Hydra sweepers — TODO confirm against _hydra_main).
    """
    # Bug fix: the original dropped the return value even though the function
    # is annotated "-> float". Hydra's sweeper plugins read the value returned
    # by the decorated entry point, so it must be propagated.
    return _hydra_main(cfg)


def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
add_defaults(cfg)

if cfg.common.reset_logging:
reset_logging() # Hydra hijacks logging, fix that
else:
with open_dict(cfg):
# make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)

cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
# check if directly called or called through hydra_main
if HydraConfig.initialized():
with open_dict(cfg):
# make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)

with omegaconf_no_object_check():
cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
OmegaConf.set_struct(cfg, True)

try:
if cfg.common.profile:
with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(cfg, pre_main)
distributed_utils.call_main(cfg, pre_main, **kwargs)
else:
distributed_utils.call_main(cfg, pre_main)
distributed_utils.call_main(cfg, pre_main, **kwargs)
except BaseException as e:
if not cfg.common.suppress_crashes:
raise
Expand Down

0 comments on commit fe15926

Please sign in to comment.