Refine the prompting and parsing for BBH.
yizhongw committed Oct 19, 2023
1 parent 4ab022a commit 720159a
Showing 1 changed file with 18 additions and 24 deletions.
eval/bbh/run_eval.py (42 changes: 18 additions & 24 deletions)

@@ -84,27 +84,25 @@ def main(args):
         task_examples = all_tasks[task_name]
         task_prompt = all_prompts[task_name]
         if args.model_name_or_path:
-            # prepare prompts
-            prompts = []
-            chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
-            for example in task_examples:
-                prompt = task_prompt.strip() + "\n\nQ: " + example["input"]
-                if args.use_chat_format:
+            # prepare prompts
+            if args.use_chat_format:
+                prompts = []
+                chat_formatting_function = dynamic_import_function(args.chat_formatting_function)
+                for example in task_examples:
+                    prompt = task_prompt.strip() + "\n\nQ: " + example["input"]
                     messages = [{"role": "user", "content": prompt}]
                     prompt = chat_formatting_function(messages, add_bos=False)
-                    if prompt[-1] in ["\n", " "]:
-                        prompt += "A:"
-                    else:
-                        prompt += " A:"
-                else:
-                    prompt += "\nA:"
-                prompts.append(prompt)
+                    prompt += "A:" if prompt[-1] in ["\n", " "] else " A:"
+                    prompts.append(prompt)
+            else:
+                prompts = [task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:" for example in task_examples]

             # generate with vllm
             if args.use_vllm:
                 sampling_params = vllm.SamplingParams(
                     temperature=0,
                     max_tokens=512,
-                    stop=["\n\n"],
+                    stop=["\n\n"] if not args.use_chat_format else None,  # we only use the stop token for the non-chat format (usually applied to vanilla pretrained language models); for the chat format, we rely on the model knowing when to stop.
                 )
                 # We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
                 generations = model.generate(prompts, sampling_params)
@@ -122,7 +120,7 @@ def main(args):
                     max_new_tokens=512,
                     temperature=0,
                     batch_size=args.eval_batch_size if args.eval_batch_size else 1,
-                    stop_id_sequences=[[stop_sequence]]
+                    stop_id_sequences=[[stop_sequence]] if not args.use_chat_format else None,  # we only use the stop token for the non-chat format (usually applied to vanilla pretrained language models); for the chat format, we rely on the model knowing when to stop.
                 )
         else:
             instances = []
@@ -145,18 +143,14 @@ def main(args):
         for example, output in zip(task_examples, outputs):
             example["raw_output"] = output

-            # extract the first answer after `So the answer is` and before the next period.
+            # extract the first answer after `the answer is` and before the next period.
             # if there is no such answer, we will just use the raw output.
-            extracted_answer = re.search(r"So the answer is (.*?)\.", output)
+            extracted_answer = re.search(r"[t|T]he answer is (.*?)\.", output)
             if extracted_answer:
-                prediction = extracted_answer.group(1).strip()
+                example["prediction"] = extracted_answer.group(1).strip()
             else:
                 # only keep the first part of the output - this is mainly for vanilla language models.
                 output = output.strip().split("\n\n")[0].strip()
-                prediction = output.strip()
-
-            example["prediction"] = prediction
-            predictions.append(prediction)
+                example["prediction"] = output.strip()
+            predictions.append(example["prediction"])

         with open(os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl"), "w") as fout:
             for example in task_examples:
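
To make the intent of the first hunk easier to see outside the diff, here is a minimal standalone sketch of the prompt-building convention it settles on. The helper name build_prompts and the sample arguments are hypothetical; chat_formatting_function stands in for whatever args.chat_formatting_function resolves to.

    def build_prompts(task_prompt, task_examples, use_chat_format, chat_formatting_function=None):
        if use_chat_format:
            prompts = []
            for example in task_examples:
                prompt = task_prompt.strip() + "\n\nQ: " + example["input"]
                # Render the user turn with the chat template, then append the answer cue,
                # adding a space only if the template did not already end with one.
                prompt = chat_formatting_function([{"role": "user", "content": prompt}], add_bos=False)
                prompt += "A:" if prompt[-1] in ["\n", " "] else " A:"
                prompts.append(prompt)
            return prompts
        # Non-chat (vanilla LM) few-shot format: generation is later truncated at "\n\n"
        # via the stop sequence, since a base model will keep producing new Q/A pairs.
        return [task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:" for example in task_examples]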
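
Similarly, a minimal sketch of the refined answer parsing from the last hunk. The function name extract_prediction and the example outputs are made up for illustration, but the regex and the fallback mirror the code above.

    import re

    def extract_prediction(output: str) -> str:
        # The relaxed pattern matches "the answer is ..." or "The answer is ...",
        # so it also covers chain-of-thought outputs ending with "So the answer is ...".
        match = re.search(r"[t|T]he answer is (.*?)\.", output)
        if match:
            return match.group(1).strip()
        # Fallback: keep only the text before the first blank line, mainly for
        # vanilla language models that keep generating extra Q/A pairs.
        return output.strip().split("\n\n")[0].strip()

    extract_prediction("Let's think step by step. So the answer is (B).")  # -> "(B)"
    extract_prediction("The answer is valid.\n\nQ: next question")         # -> "valid"
    extract_prediction("No explicit answer marker here\n\nQ: ...")         # -> "No explicit answer marker here"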