forked from facebookresearch/fairseq
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: This is a precursor to D29232595. The current behaviour when converting a dataclass to a namespace is that all the fields from all dataclasses in the field hierarchy are flattened at the top level. This is also the legacy behaviour with `add_args`. This makes it cumbersome to build reusable dataclasses, as we need to make sure that each field has a unique name. In the case of Transformer, for instance, we have a Decoder and an Encoder config that share a large part of their fields (embed_dim, layers, etc.). We can build a single dataclass for this that can be reused and extended in other implementations. To still be able to have a flat namespace, instead of adding all subfields as-is to the root namespace, we use the name of the field as a prefix for the arg in the namespace. So: `model.decoder.embed_dim` becomes `decoder_embed_dim` and `model.encoder.embed_dim` becomes `encoder_embed_dim`. Reviewed By: myleott, dianaml0 Differential Revision: D29521386 fbshipit-source-id: f4bef036f0eeb620c6d8709ce97f96ae288848ef
- Loading branch information
1 parent
7ebdc24
commit bc1504d
Showing
2 changed files
with
113 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from argparse import ArgumentParser | ||
from dataclasses import dataclass, field | ||
|
||
from fairseq.dataclass import FairseqDataclass | ||
from fairseq.dataclass.utils import gen_parser_from_dataclass | ||
|
||
|
||
@dataclass
class A(FairseqDataclass):
    """Leaf test config with one string field and one integer field."""

    data: str = field(
        default="test",
        metadata={"help": "the data input"},
    )
    num_layers: int = field(
        default=200,
        metadata={"help": "more layers is better?"},
    )
|
||
|
||
@dataclass
class B(FairseqDataclass):
    """Test config that nests an ``A`` under ``bar`` next to a plain int field."""

    # Use default_factory rather than ``default=A()``: a regular (eq=True,
    # non-frozen) dataclass instance is unhashable, and Python 3.11+ rejects
    # unhashable field defaults with a ValueError when the class is defined.
    # A factory also gives every B its own A instead of one shared instance.
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
|
||
|
||
@dataclass
class D(FairseqDataclass):
    """Test config that nests an ``A`` under ``arch`` next to a plain int field."""

    # Use default_factory rather than ``default=A()``: a regular (eq=True,
    # non-frozen) dataclass instance is unhashable, and Python 3.11+ rejects
    # unhashable field defaults with a ValueError when the class is defined.
    # A factory also gives every D its own A instead of one shared instance.
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
|
||
|
||
@dataclass
class C(FairseqDataclass):
    """Root test config: two nested configs plus its own ``data`` and ``lr``.

    ``encoder`` nests a ``D`` (which itself nests an ``A``) and ``decoder``
    nests an ``A`` directly, giving a two-level hierarchy for the prefixing
    test.
    """

    data: str = field(default="test", metadata={"help": "root level data input"})
    # Use default_factory rather than ``default=D()`` / ``default=A()``: regular
    # (eq=True, non-frozen) dataclass instances are unhashable, and Python 3.11+
    # rejects unhashable field defaults with a ValueError at class definition.
    # Factories also avoid sharing one nested instance across all C objects.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})
|
||
|
||
class TestDataclassUtils(unittest.TestCase):
    """Exercise ``gen_parser_from_dataclass`` on flat and nested dataclasses."""

    def test_argparse_convert_basic(self):
        """Fields of a flat dataclass are parsed back from the command line."""
        arg_parser = ArgumentParser()
        gen_parser_from_dataclass(arg_parser, A(), True)

        argv = ["--num-layers", "10", "the/data/path"]
        parsed = arg_parser.parse_args(argv)

        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.data, "the/data/path")

    def test_argparse_recursive(self):
        """Without a prefix, nested fields are reachable by their bare names."""
        arg_parser = ArgumentParser()
        gen_parser_from_dataclass(arg_parser, B(), True)

        argv = ["--num-layers", "10", "--foo", "10", "the/data/path"]
        parsed = arg_parser.parse_args(argv)

        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.foo, 10)
        self.assertEqual(parsed.data, "the/data/path")

    def test_argparse_recursive_prefixing(self):
        """With a prefix, nested fields are named after their parent field path."""
        self.maxDiff = None
        arg_parser = ArgumentParser()
        gen_parser_from_dataclass(arg_parser, C(), True, "")

        # One flag/value pair per line; the root-level ``data`` comes last as
        # a bare positional value.
        argv = [
            "--encoder-arch-data", "ENCODER_ARCH_DATA",
            "--encoder-arch-num-layers", "10",
            "--encoder-foo", "10",
            "--decoder-data", "DECODER_DATA",
            "--decoder-num-layers", "10",
            "--lr", "10",
            "the/data/path",
        ]
        parsed = arg_parser.parse_args(argv)

        # Checked in order, so the first mismatch fails the test.
        expected = (
            ("encoder_arch_data", "ENCODER_ARCH_DATA"),
            ("encoder_arch_num_layers", 10),
            ("encoder_foo", 10),
            ("decoder_data", "DECODER_DATA"),
            ("decoder_num_layers", 10),
            ("lr", 10),
            ("data", "the/data/path"),
        )
        for attr_name, want in expected:
            self.assertEqual(getattr(parsed, attr_name), want)
|
||
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()