Skip to content

Commit

Permalink
Hierarchical Configs
Browse files Browse the repository at this point in the history
Summary:
This is a precursor to D29232595

The current behaviour to convert a dataclass to a namespace is that all the fields from all DCs in the field hierarchy are flattened at the top. This is also the legacy behaviour with `add_args`.

This is kind of cumbersome to build reusable Dataclasses as we need to make sure that each field has a unique  name. In the case of Transformer for instance, we have a Decoder and Encoder config that share a large part of their fields (embed_dim, layers, etc.). We can build a single dataclass for this that can be reused and extended in other implementations. To be then able to have  a flat namespace, instead of adding all subfields as is to the root namespace, we introduce the name of the field as prefix to the arg in the namespace.

So:
`model.decoder.embed_dim` becomes `decoder_embed_dim` and `model.encoder.embed_dim` becomes `encoder_embed_dim`.

Reviewed By: myleott, dianaml0

Differential Revision: D29521386

fbshipit-source-id: f4bef036f0eeb620c6d8709ce97f96ae288848ef
  • Loading branch information
Mortimerp9 authored and facebook-github-bot committed Jul 16, 2021
1 parent 7ebdc24 commit bc1504d
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
31 changes: 26 additions & 5 deletions fairseq/dataclass/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,27 @@ def gen_parser_from_dataclass(
parser: ArgumentParser,
dataclass_instance: FairseqDataclass,
delete_default: bool = False,
with_prefix: Optional[str] = None,
) -> None:
"""convert a dataclass instance to tailing parser arguments"""
"""
convert a dataclass instance to tailing parser arguments.
If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
building a flat namespace from a structured dataclass (see transformer_config.py for example).
"""

def argparse_name(name: str):
if name == "data":
# normally data is positional args
if name == "data" and (with_prefix is None or with_prefix == ''):
# normally data is positional args, so we don't add the -- nor the prefix
return name
if name == "_name":
# private member, skip
return None
return "--" + name.replace("_", "-")
full_name = "--" + name.replace("_", "-")
if with_prefix is not None and with_prefix != '':
# if a prefix is specified, construct the prefixed arg name
full_name = with_prefix + "-" + full_name[2:] # strip -- when composing
return full_name

def get_kwargs_from_dc(
dataclass_instance: FairseqDataclass, k: str
Expand Down Expand Up @@ -132,6 +142,10 @@ def get_kwargs_from_dc(
if field_default is not MISSING:
kwargs["default"] = field_default

# build the help with the hierarchical prefix
if with_prefix is not None and with_prefix != '' and field_help is not None:
field_help = with_prefix[2:] + ': ' + field_help

kwargs["help"] = field_help
if field_const is not None:
kwargs["const"] = field_const
Expand All @@ -145,7 +159,14 @@ def get_kwargs_from_dc(
if field_name is None:
continue
elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
gen_parser_from_dataclass(parser, field_type(), delete_default)
# for fields that are of type FairseqDataclass, we can recursively
# add their fields to the namespace (so we add the args from model, task, etc. to the root namespace)
prefix = None
if with_prefix is not None:
# if a prefix is specified, then we don't want to copy the subfields directly to the root namespace
# but we prefix them with the name of the current field.
prefix = field_name
gen_parser_from_dataclass(parser, field_type(), delete_default, prefix)
continue

kwargs = get_kwargs_from_dc(dataclass_instance, k)
Expand Down
87 changes: 87 additions & 0 deletions tests/test_dataclass_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from argparse import ArgumentParser
from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass


@dataclass
class A(FairseqDataclass):
data: str = field(default="test", metadata={"help": "the data input"})
num_layers: int = field(default=200, metadata={"help": "more layers is better?"})


@dataclass
class B(FairseqDataclass):
bar: A = field(default=A())
foo: int = field(default=0, metadata={"help": "not a bar"})


@dataclass
class D(FairseqDataclass):
arch: A = field(default=A())
foo: int = field(default=0, metadata={"help": "not a bar"})


@dataclass
class C(FairseqDataclass):
data: str = field(default="test", metadata={"help": "root level data input"})
encoder: D = field(default=D())
decoder: A = field(default=A())
lr: int = field(default=0, metadata={"help": "learning rate"})


class TestDataclassUtils(unittest.TestCase):
def test_argparse_convert_basic(self):
parser = ArgumentParser()
gen_parser_from_dataclass(parser, A(), True)
args = parser.parse_args(["--num-layers", '10', "the/data/path"])
self.assertEqual(args.num_layers, 10)
self.assertEqual(args.data, "the/data/path")

def test_argparse_recursive(self):
parser = ArgumentParser()
gen_parser_from_dataclass(parser, B(), True)
args = parser.parse_args(["--num-layers", "10", "--foo", "10", "the/data/path"])
self.assertEqual(args.num_layers, 10)
self.assertEqual(args.foo, 10)
self.assertEqual(args.data, "the/data/path")

def test_argparse_recursive_prefixing(self):
self.maxDiff = None
parser = ArgumentParser()
gen_parser_from_dataclass(parser, C(), True, "")
args = parser.parse_args(
[
"--encoder-arch-data",
"ENCODER_ARCH_DATA",
"--encoder-arch-num-layers",
"10",
"--encoder-foo",
"10",
"--decoder-data",
"DECODER_DATA",
"--decoder-num-layers",
"10",
"--lr",
"10",
"the/data/path",
]
)
self.assertEqual(args.encoder_arch_data, "ENCODER_ARCH_DATA")
self.assertEqual(args.encoder_arch_num_layers, 10)
self.assertEqual(args.encoder_foo, 10)
self.assertEqual(args.decoder_data, "DECODER_DATA")
self.assertEqual(args.decoder_num_layers, 10)
self.assertEqual(args.lr, 10)
self.assertEqual(args.data, "the/data/path")


if __name__ == "__main__":
unittest.main()

0 comments on commit bc1504d

Please sign in to comment.