
Support using vllm for running GSM
yizhongw committed Sep 24, 2023
1 parent 2a13f1b commit 9fe582e
Showing 1 changed file with 44 additions and 17 deletions.
eval/gsm/run_eval.py: 61 changes (44 additions & 17 deletions)
@@ -4,6 +4,7 @@
 import json
 import random
 import torch
+import vllm
 import evaluate
 from eval.utils import (
     generate_completions,
@@ -77,23 +78,44 @@ def main(args):

     if args.model_name_or_path:
         print("Loading model and tokenizer...")
-        model, tokenizer = load_hf_lm_and_tokenizer(
-            model_name_or_path=args.model_name_or_path,
-            tokenizer_name_or_path=args.tokenizer_name_or_path,
-            load_in_8bit=args.load_in_8bit,
-            device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
-            gptq_model=args.gptq,
-            use_fast_tokenizer=not args.use_slow_tokenizer,
-        )
-        new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start.
-        outputs = generate_completions(
-            model=model,
-            tokenizer=tokenizer,
-            prompts=prompts,
-            max_new_tokens=512,
-            batch_size=args.eval_batch_size,
-            stop_id_sequences=[[new_line_token]]
-        )
+        if args.use_vllm:
+            model = vllm.LLM(
+                model=args.model_name_or_path,
+                tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
+                tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
+                tensor_parallel_size=torch.cuda.device_count(),
+                max_num_batched_tokens=4096,
+            )
+            sampling_params = vllm.SamplingParams(
+                temperature=0,
+                max_tokens=512,
+                stop=["\n"],
+            )
+            # We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
+            generations = model.generate(prompts, sampling_params)
+            prompt_to_output = {
+                g.prompt: g.outputs[0].text for g in generations
+            }
+            outputs = [prompt_to_output[prompt] if prompt in prompt_to_output else "" for prompt in prompts]
+        else:
+            model, tokenizer = load_hf_lm_and_tokenizer(
+                model_name_or_path=args.model_name_or_path,
+                tokenizer_name_or_path=args.tokenizer_name_or_path,
+                load_in_8bit=args.load_in_8bit,
+                device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
+                gptq_model=args.gptq,
+                use_fast_tokenizer=not args.use_slow_tokenizer,
+            )
+            new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start.
+            outputs = generate_completions(
+                model=model,
+                tokenizer=tokenizer,
+                prompts=prompts,
+                max_new_tokens=512,
+                batch_size=args.eval_batch_size,
+                stop_id_sequences=[[new_line_token]],
+                do_sample=False,
+            )
     else:
         instances = [{"id": prompt, "prompt": prompt} for _, prompt in enumerate(prompts)]
         results = query_openai_chat_model(
Expand Down Expand Up @@ -204,6 +226,11 @@ def main(args):
         action="store_true",
         help="If given, we're evaluating a 4-bit quantized GPTQ model."
     )
+    parser.add_argument(
+        "--use_vllm",
+        action="store_true",
+        help="If given, we will use the vllm library, which will likely increase the inference throughput."
+    )
     parser.add_argument(
         "--use_chat_format",
         action="store_true",
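
For reference, the sketch below pulls the vLLM code path added in this commit out of the eval harness into a self-contained snippet. It mirrors the diff's settings (greedy decoding with temperature 0, a 512-token cap, newline stop, and remapping outputs back to prompts); the model name and the prompt are placeholders rather than values from the commit, and it assumes at least one GPU is available.

# Minimal standalone sketch of the vLLM path added above (not part of the commit).
# The model name and prompt are placeholders for illustration only.
import torch
import vllm

prompts = [
    "Question: Janet has 3 apples and buys 2 more. How many apples does she have now?\nAnswer:",
]

model = vllm.LLM(
    model="facebook/opt-125m",  # placeholder; the script passes args.model_name_or_path here
    tensor_parallel_size=torch.cuda.device_count(),
)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=512, stop=["\n"])

generations = model.generate(prompts, sampling_params)
# Remap outputs back to prompts, since vllm may not return an output for every prompt (e.g., if it is too long).
prompt_to_output = {g.prompt: g.outputs[0].text for g in generations}
outputs = [prompt_to_output.get(prompt, "") for prompt in prompts]
print(outputs)

In the updated script, passing --use_vllm selects this branch; without it, run_eval.py falls back to the Hugging Face generate_completions path shown in the else block.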
