Skip to content

Commit

Permalink
Add custom dataset ability for transformer (PaddlePaddle#711)
Browse files (browse the repository at this point in the history)
* add custom dataset ability for transformer

* update

* list support

* help update

Co-authored-by: Zeyu Chen <[email protected]>
  • Loading branch information
FrostML and ZeyuChen authored Jul 10, 2021
1 parent 95bb34c commit 39d6cf6
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def parse_args():
"--profile",
action="store_true",
help="Whether to profile the performance using newstest2014 dataset. ")
parser.add_argument(
"--test_file",
default=None,
type=str,
help="The file for testing. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to process testing."
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -178,6 +184,7 @@ def do_predict(args):
args.benchmark = False
if ARGS.batch_size:
args.infer_batch_size = ARGS.batch_size
args.test_file = ARGS.test_file
pprint(args)

do_predict(args)
9 changes: 8 additions & 1 deletion examples/machine_translation/transformer/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def parse_args():
action="store_true",
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
)
parser.add_argument(
"--test_file",
default=None,
type=str,
help="The file for testing. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to process testing."
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -106,7 +112,8 @@ def do_predict(args):
yaml_file = ARGS.config
with open(yaml_file, 'rt') as f:
args = AttrDict(yaml.safe_load(f))
pprint(args)
args.benchmark = ARGS.benchmark
args.test_file = ARGS.test_file
pprint(args)

do_predict(args)
20 changes: 18 additions & 2 deletions examples/machine_translation/transformer/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,18 @@ def min_max_filer(data, max_len, min_len=0):


def create_data_loader(args, places=None):
datasets = load_dataset('wmt14ende', splits=('train', 'dev'))
if args.train_file is not None and args.dev_file is not None:
datasets = load_dataset(
'wmt14ende',
datafiles={'train': args.train_file,
'dev': args.dev_file},
splits=('train', 'dev'))
elif args.train_file is None and args.dev_file is None:
datasets = load_dataset('wmt14ende', splits=('train', 'dev'))
else:
raise ValueError(
"--train_file and --dev_file must be both or neither set. ")

if not args.benchmark:
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["bpe"])
else:
Expand Down Expand Up @@ -92,7 +103,12 @@ def convert_samples(sample):


def create_infer_loader(args):
dataset = load_dataset('wmt14ende', splits=('test'))
if args.test_file is not None:
dataset = load_dataset(
'wmt14ende', datafiles={'test': args.test_file}, splits=('test'))
else:
dataset = load_dataset('wmt14ende', splits=('test'))

if not args.benchmark:
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
else:
Expand Down
18 changes: 17 additions & 1 deletion examples/machine_translation/transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ def parse_args():
default=None,
type=int,
help="The maximum iteration for training. ")
parser.add_argument(
"--train_file",
nargs='+',
default=None,
type=str,
help="The files for training, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to train. "
)
parser.add_argument(
"--dev_file",
nargs='+',
default=None,
type=str,
help="The files for validation, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to do validation. "
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -247,9 +261,11 @@ def do_train(args):
yaml_file = ARGS.config
with open(yaml_file, 'rt') as f:
args = AttrDict(yaml.safe_load(f))
pprint(args)
args.benchmark = ARGS.benchmark
if ARGS.max_iter:
args.max_iter = ARGS.max_iter
args.train_file = ARGS.train_file
args.dev_file = ARGS.dev_file
pprint(args)

do_train(args)

0 comments on commit 39d6cf6

Please sign in to comment.