From bb45970b518cf8bff169bb9d1374dd840f8a943a Mon Sep 17 00:00:00 2001
From: yizhongw
Date: Sun, 1 Oct 2023 12:53:19 -0700
Subject: [PATCH] Refine the prompting and parsing for GSM.

---
 eval/gsm/run_eval.py | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/eval/gsm/run_eval.py b/eval/gsm/run_eval.py
index 1370ead0e..734025625 100644
--- a/eval/gsm/run_eval.py
+++ b/eval/gsm/run_eval.py
@@ -61,20 +61,16 @@ def main(args):
     else:
         prompt_prefix = "Answer the following question.\n\n"

-    prompts = []
-    chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
-    for example in test_data:
-        prompt = prompt_prefix + "Question: " + example["question"].strip()
-        if args.use_chat_format:
-            messages = [{"role": "user", "content": prompt}]
+    if args.use_chat_format:
+        prompts = []
+        chat_formatting_function = dynamic_import_function(args.chat_formatting_function)
+        for example in test_data:
+            messages = [{"role": "user", "content": prompt_prefix + "Question: " + example["question"].strip()}]
             prompt = chat_formatting_function(messages, add_bos=False)
-            if prompt[-1] in ["\n", " "]:
-                prompt += "Answer:"
-            else:
-                prompt += " Answer:"
-        else:
-            prompt += "\nAnswer:"
-        prompts.append(prompt)
+            prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:"
+            prompts.append(prompt)
+    else:
+        prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data]

     if args.model_name_or_path:
         print("Loading model and tokenizer...")
@@ -89,7 +85,7 @@ def main(args):
         sampling_params = vllm.SamplingParams(
             temperature=0,
             max_tokens=512,
-            stop=["\n"],
+            stop=["\n"] if not args.use_chat_format else None,  # we only use the stop token for the non-chat format (usually applied to vanilla pretrained language models); for the chat format, we rely on the model knowing when to stop.
         )
         # We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
         generations = model.generate(prompts, sampling_params)
@@ -113,7 +109,7 @@ def main(args):
             prompts=prompts,
             max_new_tokens=512,
             batch_size=args.eval_batch_size,
-            stop_id_sequences=[[new_line_token]],
+            stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None,  # we only use the stop token for the non-chat format (usually applied to vanilla pretrained language models); for the chat format, we rely on the model knowing when to stop.
             do_sample=False,
         )
     else:
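
For reference, a minimal, self-contained sketch of the prompt construction this patch introduces. The chat_formatting_function below is a hypothetical stand-in (in the actual script it is loaded via dynamic_import_function(args.chat_formatting_function)), and use_chat_format / test_data are inlined here purely for illustration:

    # Hypothetical stand-in for the chat formatter that run_eval.py loads via
    # dynamic_import_function(args.chat_formatting_function).
    def chat_formatting_function(messages, add_bos=False):
        return "".join("<|user|>\n" + m["content"] + "\n<|assistant|>\n" for m in messages)

    prompt_prefix = "Answer the following question.\n\n"
    test_data = [{"question": "What is 7 times 8?"}]  # illustrative question, not a GSM item
    use_chat_format = True  # stands in for args.use_chat_format

    if use_chat_format:
        prompts = []
        for example in test_data:
            messages = [{"role": "user", "content": prompt_prefix + "Question: " + example["question"].strip()}]
            prompt = chat_formatting_function(messages, add_bos=False)
            # If the rendered template already ends in whitespace, append the answer cue
            # directly; otherwise separate it with a space.
            prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:"
            prompts.append(prompt)
    else:
        prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data]

    print(prompts[0])

With this stand-in template, the chat-format prompt ends in "<|assistant|>\nAnswer:" while the non-chat branch ends in "Question: ...\nAnswer:". The chat branch relies on the model's own end-of-turn behavior, which is why the patch only applies stop=["\n"] and stop_id_sequences when --use_chat_format is not set.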