Support using vllm for running codex humaneval.
yizhongw committed Sep 24, 2023
1 parent 87911fd commit d6d439b
Showing 1 changed file with 67 additions and 37 deletions.
eval/codex_humaneval/run_eval.py — 104 changes: 67 additions & 37 deletions
@@ -3,6 +3,7 @@
 import json
 import random
 import torch
+import vllm
 from eval.utils import (
     generate_completions,
     load_hf_lm_and_tokenizer,
@@ -39,43 +40,64 @@ def main(args):
     prompts = [example["prompt"] for example in test_data]
 
     if args.model_name_or_path:
-        print("Loading model and tokenizer...")
-        model, tokenizer = load_hf_lm_and_tokenizer(
-            model_name_or_path=args.model_name_or_path,
-            tokenizer_name_or_path=args.tokenizer_name_or_path,
-            load_in_8bit=args.load_in_8bit,
-            # device map is determined by the number of gpus available.
-            device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
-            gptq_model=args.gptq,
-            use_fast_tokenizer=not args.use_slow_tokenizer,
-        )
-
-        # these stop sequences are those mentioned in the codex paper.
-        stop_sequences = ["\nclass", "\ndef", "\n#", "\nif", "\nprint"]
-        # Because many tokenizers treat a word that follows a space differently from the word alone,
-        # to be consistent, we add a space before tokenization and remove it after tokenization.
-        stop_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences]
-        outputs_per_sampling_iter = []
-        for sampling_iter in range(args.unbiased_sampling_size_n):
-            print(f"Sampling iter: {sampling_iter} / {args.unbiased_sampling_size_n}")
-            sampling_outputs = generate_completions(
-                model=model,
-                tokenizer=tokenizer,
-                prompts=prompts,
-                max_new_tokens=512,
-                batch_size=args.eval_batch_size,
-                stop_id_sequences=stop_sequences,
-                num_return_sequences=1,  # we don't use the hf num_return_sequences, because otherwise the real batch size will be multiplied by it and often cause oom.
-                do_sample=True,  # if only pass@1 is evaluated, we do greedy decoding.
+        if args.use_vllm:
+            model = vllm.LLM(
+                model=args.model_name_or_path,
+                tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
+                tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
+                tensor_parallel_size=torch.cuda.device_count(),
+            )
+            sampling_params = vllm.SamplingParams(
+                n=args.unbiased_sampling_size_n,
+                temperature=args.temperature,
                 top_p=0.95,
-                temperature=args.temperature,
+                max_tokens=512,
+                stop=["\nclass", "\ndef", "\n#", "\nif", "\nprint"]
             )
-            outputs_per_sampling_iter.append(sampling_outputs)
-        # regroup the outputs to match the number of test data.
-        outputs = []
-        for i in range(len(prompts)):
-            for j in range(args.unbiased_sampling_size_n):
-                outputs.append(outputs_per_sampling_iter[j][i])
+            generations = model.generate(prompts, sampling_params)
+            outputs = [output.text for it in generations for output in it.outputs]
+            # Note: vllm will ignore the first space in the generation, because of how the tokenizer handles the leading-space ("▁") token.
+            # This is not a problem for chat, but for codex we need to keep the first space.
+            # So, we manually add a space at the beginning.
+            outputs = [" " + output for output in outputs]
+        else:
+            print("Loading model and tokenizer...")
+            model, tokenizer = load_hf_lm_and_tokenizer(
+                model_name_or_path=args.model_name_or_path,
+                tokenizer_name_or_path=args.tokenizer_name_or_path,
+                load_in_8bit=args.load_in_8bit,
+                # device map is determined by the number of gpus available.
+                device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
+                gptq_model=args.gptq,
+                use_fast_tokenizer=not args.use_slow_tokenizer,
+            )
+
+            # these stop sequences are those mentioned in the codex paper.
+            stop_sequences = ["\nclass", "\ndef", "\n#", "\nif", "\nprint"]
+            # Because many tokenizers treat a word that follows a space differently from the word alone,
+            # to be consistent, we add a space before tokenization and remove it after tokenization.
+            stop_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences]
+            outputs_per_sampling_iter = []
+            for sampling_iter in range(args.unbiased_sampling_size_n):
+                print(f"Sampling iter: {sampling_iter} / {args.unbiased_sampling_size_n}")
+                sampling_outputs = generate_completions(
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompts=prompts,
+                    max_new_tokens=512,
+                    batch_size=args.eval_batch_size,
+                    stop_id_sequences=stop_sequences,
+                    num_return_sequences=1,  # we don't use the hf num_return_sequences, because otherwise the real batch size will be multiplied by it and often cause oom.
+                    do_sample=True,  # if only pass@1 is evaluated, we do greedy decoding.
+                    top_p=0.95,
+                    temperature=args.temperature,
+                )
+                outputs_per_sampling_iter.append(sampling_outputs)
+            # regroup the outputs to match the number of test data.
+            outputs = []
+            for i in range(len(prompts)):
+                for j in range(args.unbiased_sampling_size_n):
+                    outputs.append(outputs_per_sampling_iter[j][i])
     else:
         instances = [{
             "id": examle["task_id"],
Expand Down Expand Up @@ -108,6 +130,7 @@ def main(args):
sample_file=prediction_save_path,
k=args.eval_pass_at_ks,
problems={example["task_id"]: example for example in test_data},
n_workers=64
)

print(pass_at_k_results)
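
Editor's note: the call receiving n_workers=64 above computes pass@k over the sampled completions. For context, a sketch of the unbiased pass@k estimator from the Codex paper (Chen et al., 2021) that such evaluations use; this is illustrative, not this repo's exact implementation:

    import math

    def pass_at_k(n: int, c: int, k: int) -> float:
        # Unbiased pass@k from the Codex paper: 1 - C(n-c, k) / C(n, k),
        # where n completions were sampled and c of them pass the tests.
        if n - c < k:
            return 1.0
        return 1.0 - math.comb(n - c, k) / math.comb(n, k)

    # With 20 samples per problem and 5 passing:
    print(pass_at_k(n=20, c=5, k=1))   # 0.25
    print(pass_at_k(n=20, c=5, k=10))  # ~0.984
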
Expand Down Expand Up @@ -187,11 +210,18 @@ def main(args):
parser.add_argument(
"--load_in_8bit",
action="store_true",
help="Load model in 8bit mode, which will reduce memory and speed up inference.")
help="Load model in 8bit mode, which will reduce memory and speed up inference."
)
parser.add_argument(
"--gptq",
action="store_true",
help="If given, we're evaluating a 4-bit quantized GPTQ model.")
help="If given, we're evaluating a 4-bit quantized GPTQ model."
)
parser.add_argument(
"--use_vllm",
action="store_true",
help="If given, we will use the vllm library, which will likely increase the inference throughput."
)
parser.add_argument(
"--use_chat_format",
action="store_true",
