forked from facebookresearch/fairseq
Changelog:
- 97b58b4: add Transformer model from Vaswani et al. (2017)
- b2374e5: faster Transformer inference with improved caching
- 2d27ae0: simulate large mini-batch training with delayed updates (`--update-freq`); see the sketch below
- 7ee1d28: add FP16 training support (`--fp16`)
- 2a84f46: faster inference by removing completed sentences from the batch
- 663fd80: batched interactive generation
- 4c2ef2d: add language modeling / gated convolutional model from Dauphin et al. (2017)
- b59815b: add Hierarchical Neural Story Generation model from Fan et al. (2018)
- ff68a9e: add FairseqTask to modularize task definitions (e.g., translation, language modeling)
Showing 74 changed files with 5,297 additions and 1,692 deletions.
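The delayed-update feature noted in the changelog (`--update-freq`) is essentially gradient accumulation: gradients from several mini-batches are summed before a single optimizer step, simulating a larger batch without extra memory. A minimal sketch of the idea in plain PyTorch, not fairseq's trainer; `model`, `optimizer`, `criterion`, and `batches` are assumed to be ordinary PyTorch objects:

```python
def train_with_delayed_updates(model, optimizer, criterion, batches, update_freq=4):
    """Accumulate gradients over `update_freq` mini-batches before each optimizer
    step, approximating training with a mini-batch `update_freq` times larger."""
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(batches):
        loss = criterion(model(inputs), targets)
        # Scale the loss so the accumulated gradient averages over the delayed batches.
        (loss / update_freq).backward()
        if (i + 1) % update_freq == 0:
            optimizer.step()
            optimizer.zero_grad()
```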
@@ -0,0 +1,78 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import numpy as np
import torch

from fairseq import data, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


def main(args):
    assert args.path is not None, '--path required for evaluation!'

    if args.tokens_per_sample is None:
        args.tokens_per_sample = 1024
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_sentences=args.max_sentences or 4,
        max_positions=model.max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel()
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
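The perplexity printed at the end of the script is just the exponential of the average negative log-likelihood per token. A standalone sketch of the same arithmetic, with made-up scores standing in for the scorer's output:

```python
import numpy as np

# Hypothetical per-token log-probabilities (natural log), standing in for
# the 'positional_scores' produced by the scorer above.
positional_scores = np.array([-2.3, -0.7, -4.1, -1.2])

avg_nll_loss = -positional_scores.mean()   # average negative log-likelihood per token
perplexity = np.exp(avg_nll_loss)          # same formula the script prints
print('Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, perplexity))
```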
@@ -0,0 +1,34 @@
Sample data processing scripts for the FAIR Sequence-to-Sequence Toolkit

These scripts provide an example of pre-processing data for the Language Modeling task.

# prepare-wikitext-103.sh

Provides an example of pre-processing for the [WikiText-103 language modeling task](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset):

Example usage:
```
$ cd examples/language_model/
$ bash prepare-wikitext-103.sh
$ cd ../..

# Binarize the dataset:
$ TEXT=examples/language_model/wikitext-103
$ python preprocess.py --only-source \
  --trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
  --destdir data-bin/wikitext-103

# Train the model:
# If it runs out of memory, try to reduce max-tokens and max-target-positions
$ mkdir -p checkpoints/wikitext-103
$ python train.py --task language_modeling data-bin/wikitext-103 \
  --save-dir checkpoints/wikitext-103 \
  --max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
  --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
  --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
  --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024

# Evaluate:
$ python eval_lm.py data-bin/wikitext-103 --path 'checkpoints/wikitext-103/checkpoint_best.pt'
```
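The `--adaptive-softmax-cutoff 10000,20000,200000` option above enables the adaptive softmax of Grave et al. (2017), which splits a frequency-sorted vocabulary into a small head of frequent words and progressively larger clusters of rarer words, so most of the softmax computation happens over the head. A rough illustration of how such cutoffs partition frequency ranks; this is not fairseq's implementation, and the vocabulary size here is hypothetical:

```python
vocab_size = 260000                            # hypothetical; the real size comes from the dictionary
cutoffs = [10000, 20000, 200000, vocab_size]   # cluster boundaries over frequency ranks


def cluster_of(word_rank):
    """Return the index of the cluster that holds a word of this frequency rank."""
    for cluster, upper in enumerate(cutoffs):
        if word_rank < upper:
            return cluster
    raise ValueError('rank outside the vocabulary')


for rank in (3, 12000, 150000, 250000):
    print('rank {:>6} -> cluster {}'.format(rank, cluster_of(rank)))
```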
@@ -0,0 +1,33 @@
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh

URLS=(
    "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
)
FILES=(
    "wikitext-103-v1.zip"
)

for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit -1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        elif [ ${file: -4} == ".zip" ]; then
            unzip $file
        fi
    fi
done
cd ..
@@ -0,0 +1,30 @@
FAIR Sequence-to-Sequence Toolkit for Story Generation

The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.

The dataset can be downloaded like this:

```
curl https://s3.amazonaws.com/fairseq-py/data/writingPrompts.tar.gz | tar xvzf -
```

and contains a train, test, and valid split. The dataset is described in https://arxiv.org/abs/1805.04833; only the first 1000 words of each story are modeled.

Example usage:
```
# Binarize the dataset:
$ TEXT=examples/stories/writingPrompts
$ python preprocess.py --source-lang wp_source --target-lang wp_target \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/writingPrompts --thresholdtgt 10 --thresholdsrc 10

# Train the model:
$ python train.py data-bin/writingPrompts -a fconv_self_att_wp \
  --lr 0.25 --clip-norm 0.1 --max-tokens 1500 --lr-scheduler reduce_lr_on_plateau \
  --decoder-attention True --encoder-attention False \
  --criterion label_smoothed_cross_entropy --weight-decay .0000001 --label-smoothing 0 \
  --source-lang wp_source --target-lang wp_target \
  --gated-attention True --self-attention True --project-input True --pretrained False

# Train a fusion model:
# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint

# Generate:
$ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt \
  --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1
```
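The generation command above uses stochastic decoding: `--sampling-topk 10` restricts sampling to the 10 highest-scoring next tokens, and `--sampling-temperature 0.8` sharpens the distribution before sampling. A rough, standalone illustration of top-k sampling with a temperature (not fairseq's sequence generator):

```python
import torch


def sample_top_k(logits, k=10, temperature=0.8):
    """Sample one token id from the k highest-scoring entries of `logits`,
    after rescaling by the temperature (lower temperature -> sharper distribution)."""
    scaled = logits / temperature
    top_values, top_indices = torch.topk(scaled, k)
    probs = torch.softmax(top_values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return top_indices[choice]


# Hypothetical next-token logits over a tiny vocabulary of 20 types.
logits = torch.randn(20)
print(sample_top_k(logits))
```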