
Commit

Support using vllm for running BBH
yizhongw committed Sep 24, 2023
1 parent d6d439b commit 2a13f1b
Showing 1 changed file with 108 additions and 130 deletions.
238 changes: 108 additions & 130 deletions eval/bbh/run_eval.py
@@ -6,6 +6,7 @@
import glob
import torch
import random
import vllm
import evaluate
from eval.utils import (
load_hf_lm_and_tokenizer,
@@ -16,111 +17,6 @@


exact_match = evaluate.load("exact_match")
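For context, here is a small illustrative usage of the evaluate library's exact_match metric with the same options this script passes when scoring predictions; the toy predictions and references below are made up purely for illustration.

import evaluate

exact_match = evaluate.load("exact_match")

# Toy inputs (illustrative only): scoring ignores case and punctuation,
# matching the options used in this script.
score = exact_match.compute(
    predictions=["(b)", "valid"],
    references=["(B)", "valid."],
    ignore_case=True,
    ignore_punctuation=True,
)["exact_match"]
print(score)  # expected: 1.0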


@torch.no_grad()
def eval_hf_model(args, model, tokenizer, examples, task_prompt, save_path=None):
targets = [example["target"] for example in examples]
if save_path:
fout = open(save_path, "w")

prompts = []
chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
for example in examples:
prompt = task_prompt.strip() + "\n\nQ: " + example["input"]
if args.use_chat_format:
messages = [{"role": "user", "content": prompt}]
prompt = chat_formatting_function(messages, add_bos=False)
if prompt[-1] in ["\n", " "]:
prompt += "A:"
else:
prompt += " A:"
else:
prompt += "\nA:"
prompts.append(prompt)

if args.no_cot:
stop_sequence = tokenizer.encode("\n\n", add_special_tokens=False)[-2:] # take the last two tokens because the tokenizer may add space tokens at the start.
else:
# let's not use the stop sequence for CoT for now, since it's too inefficient when the generation is long.
# instead, we'll do some post-processing to extract the answer.
stop_sequence = None

outputs = generate_completions(
model=model,
tokenizer=tokenizer,
prompts=prompts,
max_new_tokens=512,
batch_size=args.eval_batch_size if args.eval_batch_size else 1,
stop_id_sequences=[stop_sequence] if stop_sequence else None
)

predictions = []
for example, output in zip(examples, outputs):
example["raw_output"] = output

# extract the first answer after `So the answer is` and before the next period.
# if there is no such answer, we will just use the raw output.
results = re.search(r"So the answer is (.*?)\.", output)
if results:
prediction = results.group(1).strip()
else:
# only keep the first part of the output - this is mainly for vanilla language models.
output = output.strip().split("\n\n")[0].strip()
prediction = output.strip()

example["prediction"] = prediction
predictions.append(prediction)
if save_path:
fout.write(json.dumps(example) + "\n")

assert len(predictions) == len(targets), "numbers of predictions and targets are not the same."
return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]


def eval_openai_chat_engine(args, examples, task_prompt, save_path=None):
targets = [example["target"] for example in examples]
instances = []
for i, example in enumerate(examples):
prompt = task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:"
instances.append({
"id": example["id"] if "id" in example else i,
"prompt": prompt,
})

if save_path:
openai_result_save_path = os.path.join(os.path.dirname(save_path), os.path.basename(save_path).split(".")[0] + "_openai_results.jsonl")

results = query_openai_chat_model(
engine=args.openai_engine,
instances=instances,
batch_size=args.eval_batch_size if args.eval_batch_size else 10,
output_path=openai_result_save_path if save_path else None,
)

outputs = [result["output"] for result in results]
assert len(outputs) == len(targets), "numbers of outputs and targets are not the same."

if save_path:
fout = open(save_path, "w")

predictions = []
for example, output in zip(examples, outputs):
example["raw_output"] = output
# extract the first answer after `So the answer is` and before the next period.
# if there is no such answer, we will just use the raw output.
results = re.search(r"So the answer is (.*?)\.", output)
if results:
prediction = results.group(1).strip()
else:
prediction = output.strip()
example["prediction"] = prediction
predictions.append(prediction)
if save_path:
fout.write(json.dumps(example) + "\n")

assert len(predictions) == len(targets), "numbers of predictions and targets are not the same."
return exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]
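Both code paths above share the same post-processing step for chain-of-thought outputs. A minimal sketch of that extraction logic follows; the sample output text is invented for illustration.

import re

# Invented chain-of-thought style output, for illustration only.
output = "Let's count the objects step by step. There are 4 of them. So the answer is 4.\n\nQ: ..."

match = re.search(r"So the answer is (.*?)\.", output)
if match:
    prediction = match.group(1).strip()  # -> "4"
else:
    # fall back to the first paragraph of the raw output
    prediction = output.strip().split("\n\n")[0].strip()
print(prediction)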


def main(args):
@@ -161,40 +57,117 @@ def main(args):
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(os.path.join(args.save_dir, "predictions"), exist_ok=True)

# Load model if not using OpenAI API
if args.model_name_or_path:
print("Loading model and tokenizer...")
model, tokenizer = load_hf_lm_and_tokenizer(
model_name_or_path=args.model_name_or_path,
tokenizer_name_or_path=args.tokenizer_name_or_path,
load_in_8bit=args.load_in_8bit,
device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
gptq_model=args.gptq,
use_fast_tokenizer=not args.use_slow_tokenizer,
)
if args.use_vllm:
print("Loading vllm model...")
model = vllm.LLM(
model=args.model_name_or_path,
tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
tensor_parallel_size=torch.cuda.device_count(),
max_num_batched_tokens=4096,
)
else:
print("Loading model and tokenizer with huggingface...")
model, tokenizer = load_hf_lm_and_tokenizer(
model_name_or_path=args.model_name_or_path,
tokenizer_name_or_path=args.tokenizer_name_or_path,
load_in_8bit=args.load_in_8bit,
device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
gptq_model=args.gptq,
use_fast_tokenizer=not args.use_slow_tokenizer,
)

performance = {}
for task_name in tqdm.tqdm(all_tasks.keys(), desc="Evaluating"):
task_examples = all_tasks[task_name]
prompt = all_prompts[task_name]
task_prompt = all_prompts[task_name]
if args.model_name_or_path:
task_perf = eval_hf_model(
args,
model,
tokenizer,
task_examples,
prompt,
save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl")
)
# prepare prompts
prompts = []
chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
for example in task_examples:
prompt = task_prompt.strip() + "\n\nQ: " + example["input"]
if args.use_chat_format:
messages = [{"role": "user", "content": prompt}]
prompt = chat_formatting_function(messages, add_bos=False)
if prompt[-1] in ["\n", " "]:
prompt += "A:"
else:
prompt += " A:"
else:
prompt += "\nA:"
prompts.append(prompt)
# generate with vllm
if args.use_vllm:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=512,
stop=["\n\n"],
)
# We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
generations = model.generate(prompts, sampling_params)
prompt_to_output = {
g.prompt: g.outputs[0].text for g in generations
}
outputs = [prompt_to_output[prompt] if prompt in prompt_to_output else "" for prompt in prompts]
# generate with hf model
else:
stop_sequence = tokenizer.encode("\n\n", add_special_tokens=False)[-2:] # take the last two tokens because the tokenizer may add space tokens at the start.
outputs = generate_completions(
model=model,
tokenizer=tokenizer,
prompts=prompts,
max_new_tokens=512,
temperature=0,
batch_size=args.eval_batch_size if args.eval_batch_size else 1,
stop_id_sequences=[stop_sequence]
)
else:
task_perf = eval_openai_chat_engine(
args,
task_examples,
prompt,
save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl")
instances = []
for i, example in enumerate(task_examples):
prompt = task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:"
instances.append({
"id": example["id"] if "id" in example else i,
"prompt": prompt,
})
results = query_openai_chat_model(
engine=args.openai_engine,
instances=instances,
batch_size=args.eval_batch_size if args.eval_batch_size else 10,
output_path=os.path.join(args.save_dir, "predictions", f"{task_name}_openai_prediction_cache.jsonl"),
)
performance[task_name] = task_perf
print(f"Task {task_name} - EM: {task_perf}")
outputs = [result["output"] for result in results]

targets = [example["target"] for example in task_examples]
predictions = []
for example, output in zip(task_examples, outputs):
example["raw_output"] = output

# extract the first answer after `So the answer is` and before the next period.
# if there is no such answer, we will just use the raw output.
extracted_answer = re.search(r"So the answer is (.*?)\.", output)
if extracted_answer:
prediction = extracted_answer.group(1).strip()
else:
# only keep the first part of the output - this is mainly for vanilla language models.
output = output.strip().split("\n\n")[0].strip()
prediction = output.strip()

example["prediction"] = prediction
predictions.append(prediction)

with open(os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl"), "w") as fout:
for example in task_examples:
fout.write(json.dumps(example) + "\n")

assert len(predictions) == len(targets), "numbers of predictions and targets are not the same."
performance[task_name] = exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]

print(f"Task {task_name} - EM: {performance[task_name]}")

# save the performance
with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
performance["average_exact_match"] = sum(performance.values()) / len(performance)
print(f"Average EM: {performance['average_exact_match']}")
@@ -263,6 +236,11 @@ def main(args):
action="store_true",
help="If given, we're evaluating a 4-bit quantized GPTQ model."
)
parser.add_argument(
"--use_vllm",
action="store_true",
help="If given, we will use the vllm library, which will likely increase the inference throughput."
)
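A minimal sketch of the vllm path that this flag enables, assuming a locally available model checkpoint (the model path and prompt below are placeholders). It mirrors the settings used above: greedy decoding, up to 512 new tokens, stopping at a blank line, and remapping outputs to prompts because vllm may not return an output for every prompt.

import vllm

# Placeholder model path -- substitute a real checkpoint.
llm = vllm.LLM(model="/path/to/model")
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=512, stop=["\n\n"])

prompts = ["Q: What is 2 + 2?\nA:"]  # illustrative prompt
generations = llm.generate(prompts, sampling_params)

# Remap outputs to prompts; vllm may skip prompts it cannot serve
# (e.g., ones that exceed the context window).
prompt_to_output = {g.prompt: g.outputs[0].text for g in generations}
outputs = [prompt_to_output.get(p, "") for p in prompts]
print(outputs[0])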
parser.add_argument(
"--use_chat_format",
action="store_true",
