From dbe963717c69c69a5409a3b427544481ade5e70c Mon Sep 17 00:00:00 2001
From: Alexei Baevski
Date: Fri, 20 Jul 2018 13:18:31 -0700
Subject: [PATCH] option to print language model words and their log probs
 during evaluation

---
 eval_lm.py         | 16 ++++++++++++++++
 fairseq/options.py |  2 ++
 2 files changed, 18 insertions(+)

diff --git a/eval_lm.py b/eval_lm.py
index f6102b16aa..79246d230d 100644
--- a/eval_lm.py
+++ b/eval_lm.py
@@ -60,8 +60,10 @@ def main(args):
     if args.remove_bpe is not None:
         bpe_cont = args.remove_bpe.rstrip()
         bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
+        bpe_len = len(bpe_cont)
     else:
         bpe_toks = None
+        bpe_len = 0
 
     with progress_bar.build_progress_bar(args, itr) as t:
         results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
@@ -85,6 +87,20 @@ def main(args):
                 pos_scores = pos_scores[(~inf_scores).nonzero()]
             score_sum += pos_scores.sum()
             count += pos_scores.numel() - skipped_toks
+
+            if args.output_word_probs:
+                w = ''
+                word_prob = []
+                for i in range(len(hypo['tokens'])):
+                    w_ind = hypo['tokens'][i].item()
+                    w += task.dictionary[w_ind]
+                    if bpe_toks is not None and w_ind in bpe_toks:
+                        w = w[:-bpe_len]
+                    else:
+                        word_prob.append((w, pos_scores[i].item()))
+                        w = ''
+                print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))
+
             wps_meter.update(src_tokens.size(0))
             t.log({'wps': round(wps_meter.avg)})
 
diff --git a/fairseq/options.py b/fairseq/options.py
index ca76a787c9..a834e0b732 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -249,6 +249,8 @@ def add_common_eval_args(group):
 def add_eval_lm_args(parser):
     group = parser.add_argument_group('LM Evaluation')
     add_common_eval_args(group)
+    group.add_argument('--output-word-probs', action='store_true',
+                       help='if set, outputs words and their predicted log probabilities to standard output')
 
 
 def add_generation_args(parser):
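
For reference, a minimal standalone sketch of the word-reassembly logic this
patch adds to eval_lm.py: it merges BPE subword tokens (marked by a trailing
continuation string such as '@@') back into whole words and pairs each word
with a log probability, mirroring what --output-word-probs prints. The token
list, scores, and '@@' marker below are made-up illustrative inputs, not part
of the patch; the real code takes them from hypo['tokens'], pos_scores, and
args.remove_bpe, and looks up symbols in task.dictionary.

    # Standalone sketch of the --output-word-probs reassembly logic.
    # All inputs here are fabricated examples for illustration only.
    bpe_cont = '@@'                  # BPE continuation marker (subword-nmt style)
    bpe_len = len(bpe_cont)

    # Subword tokens with per-token log probabilities (made-up values).
    tokens = ['lang@@', 'uage', 'model@@', 'ing', 'works']
    pos_scores = [-1.31, -0.42, -2.05, -0.11, -0.87]

    w = ''
    word_prob = []
    for tok, score in zip(tokens, pos_scores):
        w += tok
        if tok.endswith(bpe_cont):
            w = w[:-bpe_len]         # strip the marker; the word continues
        else:
            # Final subword of a word: emit the word with that subword's
            # score (the patch pairs each word with its last subword's log
            # prob rather than summing scores across subwords).
            word_prob.append((w, score))
            w = ''

    print('\t'.join('{} [{:.2f}]'.format(word, p) for word, p in word_prob))
    # prints (tab-separated): language [-0.42]  modeling [-0.11]  works [-0.87]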