From 6b038fc7ce3beee327696b22660f69879d10055b Mon Sep 17 00:00:00 2001
From: yizhongw
Date: Tue, 3 Oct 2023 02:48:38 -0700
Subject: [PATCH] Support batch eval for truthfulqa mc setup.

---
 eval/truthfulqa/run_eval.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/eval/truthfulqa/run_eval.py b/eval/truthfulqa/run_eval.py
index 926023faa..203291a38 100644
--- a/eval/truthfulqa/run_eval.py
+++ b/eval/truthfulqa/run_eval.py
@@ -192,7 +192,7 @@ def run_hf_model(questions, model, tokenizer, tag, preset="qa", batch_size=1, ma
     return questions
 
 
-def run_hf_model_mc(questions, model, tokenizer, tag, preset='qa'):
+def run_hf_model_mc(questions, model, tokenizer, tag, batch_size=1, preset='qa'):
     """Runs multiple-choice metrics for autoregressive HuggingFace models (GPT-2, GPT-Neo)"""
 
     set_columns(tag, questions)
@@ -224,7 +224,7 @@ def run_hf_model_mc(questions, model, tokenizer, tag, preset='qa'):
         # candidate completions
         examples.append({"prompt": prompt, "completions": ref_true + ref_false})
 
-    all_scores = score_completions(model, tokenizer, examples, batch_size=args.eval_batch_size, aggregation="sum")
+    all_scores = score_completions(model, tokenizer, examples, batch_size=batch_size, aggregation="sum")
     assert len(all_scores) == len(examples)
 
     for idx, example in zip(questions.index, examples):
@@ -274,7 +274,7 @@ def main(args):
             )
         if "mc" in args.metrics:
            print("Running multiple-choice classification!")
-            run_hf_model_mc(questions, model, tokenizer, tag=args.model_name_or_path, preset=args.preset)
+            run_hf_model_mc(questions, model, tokenizer, tag=args.model_name_or_path, batch_size=args.eval_batch_size, preset=args.preset)
         elif args.openai_engine:
             # gpt-3 language models
             if args.openai_engine in ['ada', 'babbage', 'curie', 'davinci', 'text-davinci-003', 'text-davinci-002', 'code-davinci-002']:
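
For context on what the patched call path relies on, below is a minimal standalone sketch of batched multiple-choice scoring: summing token log-probabilities of each candidate completion given the prompt, processed in batches. This is an illustration only, not the repository's score_completions implementation (whose internals are not shown in this patch); the helper name, model choice, and example prompt are assumptions.

# Sketch of batched completion scoring by summed log-likelihood (assumed helper,
# not the repo's score_completions). Assumes a right-padding tokenizer with a pad
# token set, e.g. GPT-2 with tokenizer.pad_token = tokenizer.eos_token.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def score_completions_sketch(model, tokenizer, prompt, completions, batch_size=8):
    """Return the summed log-probability of each completion, conditioned on the prompt."""
    device = next(model.parameters()).device
    prompt_len = len(tokenizer(prompt)["input_ids"])
    scores = []
    for start in range(0, len(completions), batch_size):
        batch = completions[start:start + batch_size]
        enc = tokenizer([prompt + c for c in batch], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            log_probs = torch.log_softmax(model(**enc).logits, dim=-1)
        for i in range(len(batch)):
            seq_len = int(enc["attention_mask"][i].sum())  # real (non-pad) length with right padding
            total = 0.0
            for t in range(prompt_len, seq_len):
                # log P(token_t | tokens_<t): the prediction for position t lives at position t-1
                total += log_probs[i, t - 1, enc["input_ids"][i, t]].item()
            scores.append(total)
    return scores

# Example usage with hypothetical inputs:
# tokenizer = AutoTokenizer.from_pretrained("gpt2"); tokenizer.pad_token = tokenizer.eos_token
# model = AutoModelForCausalLM.from_pretrained("gpt2")
# score_completions_sketch(model, tokenizer, "Q: Is the sky blue?\nA:", [" Yes.", " No."])

Because scoring is length-insensitive up to memory limits, batch_size here (like eval_batch_size in the patch) trades GPU memory for throughput without changing the resulting scores.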