Refine the prompting and parsing for GSM.
yizhongw committed Oct 19, 2023
1 parent 720159a commit bb45970
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions eval/gsm/run_eval.py
```diff
@@ -61,20 +61,16 @@ def main(args):
     else:
         prompt_prefix = "Answer the following question.\n\n"
 
-    prompts = []
-    chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
-    for example in test_data:
-        prompt = prompt_prefix + "Question: " + example["question"].strip()
-        if args.use_chat_format:
-            messages = [{"role": "user", "content": prompt}]
+    if args.use_chat_format:
+        prompts = []
+        chat_formatting_function = dynamic_import_function(args.chat_formatting_function)
+        for example in test_data:
+            messages = [{"role": "user", "content": prompt_prefix + "Question: " + example["question"].strip()}]
             prompt = chat_formatting_function(messages, add_bos=False)
-            if prompt[-1] in ["\n", " "]:
-                prompt += "Answer:"
-            else:
-                prompt += " Answer:"
-        else:
-            prompt += "\nAnswer:"
-        prompts.append(prompt)
+            prompt += "Answer:" if prompt[-1] in ["\n", " "] else " Answer:"
+            prompts.append(prompt)
+    else:
+        prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data]
 
     if args.model_name_or_path:
         print("Loading model and tokenizer...")
```
```diff
@@ -89,7 +85,7 @@ def main(args):
         sampling_params = vllm.SamplingParams(
             temperature=0,
             max_tokens=512,
-            stop=["\n"],
+            stop=["\n"] if not args.use_chat_format else None,  # we only use the stop token for the non-chat format (usually applied to vanilla pretrained language models); for the chat format, we rely on the model knowing when to stop
         )
         # We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
         generations = model.generate(prompts, sampling_params)
```
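As a hedged sketch of how this conditional stop plugs into vLLM's sampling parameters in isolation (the flag is again a stand-in for `args.use_chat_format`):

```python
# Sketch: greedy decoding where "\n" terminates generation only in the
# non-chat (few-shot completion) setting. Assumes vllm is installed.
import vllm

use_chat_format = False  # stand-in for args.use_chat_format

sampling_params = vllm.SamplingParams(
    temperature=0,   # greedy decoding for a deterministic eval
    max_tokens=512,
    # Stopping at the first newline suggests the few-shot format keeps each
    # answer on a single line; chat models are instead expected to emit
    # their own end-of-turn marker.
    stop=["\n"] if not use_chat_format else None,
)
```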
```diff
@@ -113,7 +109,7 @@ def main(args):
             prompts=prompts,
             max_new_tokens=512,
             batch_size=args.eval_batch_size,
-            stop_id_sequences=[[new_line_token]],
+            stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None,  # we only use the stop token for the non-chat format (usually applied to vanilla pretrained language models); for the chat format, we rely on the model knowing when to stop
             do_sample=False,
         )
     else:
```
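`new_line_token` is computed earlier in the script and is outside this diff. A hedged sketch of one common way to derive such a stop-token id with a Hugging Face tokenizer, and how the conditional feeds `stop_id_sequences` (the model choice and derivation here are assumptions, not the script's code):

```python
# Sketch: deriving a newline token id for stop_id_sequences. Assumes a
# Hugging Face tokenizer; the script's actual derivation may differ.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice

# Take the last id in case the tokenizer splits "\n" into multiple pieces.
new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]

use_chat_format = False  # stand-in for args.use_chat_format
stop_id_sequences = [[new_line_token]] if not use_chat_format else None
print(stop_id_sequences)  # e.g. [[198]] for the GPT-2 tokenizer
```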
