Support batch eval for truthfulqa mc setup.
yizhongw committed Oct 19, 2023
1 parent 677cbd0 commit 6b038fc
Showing 1 changed file with 3 additions and 3 deletions.
eval/truthfulqa/run_eval.py
@@ -192,7 +192,7 @@ def run_hf_model(questions, model, tokenizer, tag, preset="qa", batch_size=1, ma
return questions


-def run_hf_model_mc(questions, model, tokenizer, tag, preset='qa'):
+def run_hf_model_mc(questions, model, tokenizer, tag, batch_size=1, preset='qa'):
"""Runs multiple-choice metrics for autoregressive HuggingFace models (GPT-2, GPT-Neo)"""

set_columns(tag, questions)
@@ -224,7 +224,7 @@ def run_hf_model_mc(questions, model, tokenizer, tag, preset='qa'):
# candidate completions
examples.append({"prompt": prompt, "completions": ref_true + ref_false})

-all_scores = score_completions(model, tokenizer, examples, batch_size=args.eval_batch_size, aggregation="sum")
+all_scores = score_completions(model, tokenizer, examples, batch_size=batch_size, aggregation="sum")
assert len(all_scores) == len(examples)

for idx, example in zip(questions.index, examples):
@@ -274,7 +274,7 @@ def main(args):
)
if "mc" in args.metrics:
print("Running multiple-choice classification!")
-run_hf_model_mc(questions, model, tokenizer, tag=args.model_name_or_path, preset=args.preset)
+run_hf_model_mc(questions, model, tokenizer, tag=args.model_name_or_path, batch_size=args.eval_batch_size, preset=args.preset)
elif args.openai_engine:
# gpt-3 language models
if args.openai_engine in ['ada', 'babbage', 'curie', 'davinci', 'text-davinci-003', 'text-davinci-002', 'code-davinci-002']:
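The change threads batch_size through explicitly: the old call inside run_hf_model_mc read args.eval_batch_size from the enclosing module scope, which only worked because the script's args happened to be visible there; the new signature passes it as a parameter, matching what run_hf_model already does. For illustration only, here is a minimal sketch of what a batched score_completions could look like for a causal HuggingFace model. The real implementation lives elsewhere in this repo's eval utilities; the body below (right-padding assumption, log-prob bookkeeping, prompt/completion split) is an assumption, not the actual code.

import torch

def score_completions(model, tokenizer, examples, batch_size=1, aggregation="sum"):
    # Sketch only, not the repo's implementation.
    # `examples` is a list of {"prompt": str, "completions": [str, ...]},
    # as built in run_hf_model_mc above; returns {prompt: {completion: score}}.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style tokenizers have no pad token
    pairs = [(ex["prompt"], c) for ex in examples for c in ex["completions"]]
    scores = {ex["prompt"]: {} for ex in examples}
    for i in range(0, len(pairs), batch_size):
        batch = pairs[i:i + batch_size]
        enc = tokenizer([p + c for p, c in batch], return_tensors="pt", padding=True).to(model.device)
        with torch.no_grad():
            logits = model(**enc).logits  # [batch, seq_len, vocab]
        # Logits at position t predict the token at position t + 1.
        log_probs = torch.log_softmax(logits[:, :-1], dim=-1)
        token_scores = log_probs.gather(-1, enc.input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        for j, (prompt, completion) in enumerate(batch):
            # Keep only the completion tokens (assumes right padding and that
            # tokenizing the prompt alone gives a prefix of prompt + completion).
            prompt_len = len(tokenizer(prompt).input_ids)
            real = enc.attention_mask[j, 1:].bool()
            comp = token_scores[j][real][prompt_len - 1:]
            scores[prompt][completion] = (comp.sum() if aggregation == "sum" else comp.mean()).item()
    return scores

Assuming the script's argparse flags mirror the args.* names used in the diff, a batched MC run would then be invoked along the lines of:

python -m eval.truthfulqa.run_eval --model_name_or_path <model> --metrics mc --eval_batch_size 8 --preset qa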
