Skip to content

Commit

Permalink
delegate namespace conversion to DC
Browse files Browse the repository at this point in the history
Summary: `populate_dataclass` is very basic in how it populates the dataclass. We might want more specific behaviour for some config dataclasses (like the hierarchical behaviour in TransformerConfig, see rest of stack). This diff move the populate logic to a `from_namespace` method in `FairseqDataclass` so that the a specific Dataclass can reimplement it.

Reviewed By: myleott

Differential Revision: D29521388

fbshipit-source-id: f3a6dc80e4ddfc9563c6e85c37c563173f193f4d
  • Loading branch information
Mortimerp9 authored and facebook-github-bot committed Jul 16, 2021
1 parent aa15dc9 commit 7ebdc24
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 20 deletions.
16 changes: 16 additions & 0 deletions fairseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ def _get_argparse_alias(self, attribute_name: str) -> Any:
def _get_choices(self, attribute_name: str) -> Any:
return self._get_meta(attribute_name, "choices")

@classmethod
def from_namespace(cls, args):
if isinstance(args, cls):
return args
else:
config = cls()
for k in config.__dataclass_fields__.keys():
if k.startswith("_"):
# private member, skip
continue
if hasattr(args, k):
setattr(config, k, getattr(args, k))

return config



@dataclass
class CommonConfig(FairseqDataclass):
Expand Down
14 changes: 0 additions & 14 deletions fairseq/dataclass/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,20 +419,6 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
return cfg


def populate_dataclass(
dataclass: FairseqDataclass,
args: Namespace,
) -> FairseqDataclass:
for k in dataclass.__dataclass_fields__.keys():
if k.startswith("_"):
# private member, skip
continue
if hasattr(args, k):
setattr(dataclass, k, getattr(args, k))

return dataclass


def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
# this will be deprecated when we get rid of argparse and model_overrides logic

Expand Down
4 changes: 2 additions & 2 deletions fairseq/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from contextlib import ExitStack

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import merge_with_parent, populate_dataclass
from fairseq.dataclass.utils import merge_with_parent
from hydra.core.config_store import ConfigStore
from omegaconf import open_dict, OmegaConf

Expand Down Expand Up @@ -84,7 +84,7 @@ def build_model(cfg: FairseqDataclass, task):
dc = MODEL_DATACLASS_REGISTRY[model_type]

if isinstance(cfg, argparse.Namespace):
cfg = populate_dataclass(dc(), cfg)
cfg = dc.from_namespace(cfg)
else:
cfg = merge_with_parent(dc(), cfg)
else:
Expand Down
4 changes: 2 additions & 2 deletions fairseq/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing import Union
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import populate_dataclass, merge_with_parent
from fairseq.dataclass.utils import merge_with_parent
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig

Expand Down Expand Up @@ -45,7 +45,7 @@ def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs)
else:
choice = getattr(cfg, registry_name, None)
if choice in DATACLASS_REGISTRY:
cfg = populate_dataclass(DATACLASS_REGISTRY[choice](), cfg)
cfg = DATACLASS_REGISTRY[choice].from_namespace(cfg)

if choice is None:
if required:
Expand Down
4 changes: 2 additions & 2 deletions fairseq/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import merge_with_parent, populate_dataclass
from fairseq.dataclass.utils import merge_with_parent
from hydra.core.config_store import ConfigStore

from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa
Expand All @@ -30,7 +30,7 @@ def setup_task(cfg: FairseqDataclass, **kwargs):
task = TASK_REGISTRY[task_name]
if task_name in TASK_DATACLASS_REGISTRY:
dc = TASK_DATACLASS_REGISTRY[task_name]
cfg = populate_dataclass(dc(), cfg)
cfg = dc.from_namespace(cfg)
else:
task_name = getattr(cfg, "_name", None)

Expand Down

0 comments on commit 7ebdc24

Please sign in to comment.