Skip to content

Commit

Permalink
Refine the prompting and parsing for TydiQA.
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhongw committed Oct 19, 2023
1 parent bb45970 commit b36f5b6
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions eval/tydiqa/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,7 @@ def main(args):
if args.use_chat_format:
messages = [{"role": "user", "content": prompt}]
prompt = chat_formatting_function(messages, add_bos=False)
if prompt[-1] in ["\n", " "]:
prompt += a_template
else:
prompt += " " + a_template
prompt += a_template if prompt[-1] in ["\n", " "] else " " + a_template
else:
prompt += a_template
prompts.append(prompt)
Expand All @@ -184,7 +181,7 @@ def main(args):
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=50,
stop=["\n"],
stop=["\n"] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop.
)
# We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
generations = model.generate(prompts, sampling_params)
Expand All @@ -200,7 +197,7 @@ def main(args):
prompts=prompts,
max_new_tokens=50,
batch_size=args.eval_batch_size,
stop_id_sequences=[[new_line_token]]
stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop.
)
# remove unnecessary space
outputs = [output.strip() for output in outputs]
Expand Down

0 comments on commit b36f5b6

Please sign in to comment.