update humaneval evaluate
mst272 committed Sep 19, 2024
1 parent 93885d2 commit f37a955
Showing 8 changed files with 165 additions and 257 deletions.
1 change: 1 addition & 0 deletions evalueate/args.py
@@ -4,6 +4,7 @@
torch_dtype = ['bf16', 'fp16']
task_name = ['humaneval']


@dataclass
class EvaluateArgs:
"""
43 changes: 43 additions & 0 deletions evalueate/base_utils.py
@@ -0,0 +1,43 @@

class TaskUtils:
def __init__(self):
self.IMPORT_HELPER = {
"python": [
"import math",
"import re",
"import sys",
"import copy",
"import datetime",
"import itertools",
"import collections",
"import heapq",
"import functools",
"import hashlib",
"import numpy",
"import numpy as np",
"import string",
"from typing import *",
"from collections import *",
]
}

@staticmethod
def build_instruction(example):
"""
        Build the appropriate instruction for the model.
"""
return example['prompt']

@staticmethod
def generation_code_process(example):
"""
        Extract the function part from the generated code and set up imports, etc.
"""
pass

@staticmethod
def return_test_code(sample):
"""
        Return the final code that will actually be run against the tests.
"""
pass
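
The two stubs above are filled in per task. Below is a hypothetical sketch of the HumanEval subclass, mirroring the extraction and test-assembly logic removed from evaluation.py further down in this diff; the real class presumably lives in the task module imported by main.py and is not shown in this view, so the class body here is an assumption.

from base_utils import TaskUtils


class HumanEval(TaskUtils):
    """Hypothetical sketch only; see the task module for the real class."""

    @staticmethod
    def generation_code_process(example):
        # Keep top-level def/class lines and their indented bodies; drop other
        # top-level statements and everything from `if __name__ == "__main__":`
        # onward (mirrors the removed generation_process() in evaluation.py).
        kept = []
        for line in example["generation"].split("\n"):
            if line.strip().startswith("if __name__"):
                break
            if line and line[0] not in (" ", "\t") and not line.startswith(("def ", "class")):
                continue
            kept.append(line)
        example["generation"] = "\n".join(kept)
        return example

    @staticmethod
    def return_test_code(sample):
        # Prepend the shared imports, then append the unit tests to the extracted
        # generation (mirrors the removed process_humaneval_test()).
        test_setup = "\n".join(TaskUtils().IMPORT_HELPER["python"]) + "\n"
        return test_setup + sample["generation"] + "\n" + sample["test"] + "\n"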
81 changes: 15 additions & 66 deletions evalueate/evaluation.py
@@ -1,74 +1,12 @@
import json
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
from tqdm import tqdm

from execution import check_correctness

IMPORT_HELPER = {
"python": [
"import math",
"import re",
"import sys",
"import copy",
"import datetime",
"import itertools",
"import collections",
"import heapq",
"import functools",
"import hashlib",
"import numpy",
"import numpy as np",
"import string",
"from typing import *",
"from collections import *",
]
}


def generation_process(code, test):
code_ = []
    skip_rest = False  # new flag: skip the if __name__ == "__main__" block and everything after it
for line in code.split("\n"):
if skip_rest:
            continue  # if __name__ == "__main__" was already seen, skip this line and everything after it
if any(sub in line for sub in ["if __name__ == \"__main__\":", "if __name__ == '__main__':"]):
            skip_rest = True  # set the flag so all following lines are skipped
continue
if "def " in line and line[0] != ' ' and line[0] != '\t':
code_.append("def " + line.split("def ")[1])
continue
if "class" in line and line.strip().endswith(":"):
code_.append(line)
continue
if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
continue
code_.append(line)
code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
return test_setup + code

# if code.startswith("def"):
# test_string = test_setup + code + "\n\n" + test + "\n"
# else:
# test_string = test_setup + prompt + code + "\n" + test


def process_humaneval_test(sample):
"""
Processes a sample for evaluation.
return: the str of test code
"""
test = sample["test"]
code = sample["generation"]
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
test_string = test_setup + code + "\n" + test + "\n"
return test_string


def stream_jsonl_all(filename: str):
"""
@@ -87,7 +25,8 @@ def evaluate_functional_correctness(
input_file: str = None,
n_workers: int = 32,
timeout: float = 3.0,
k: int = 1
k: int = 1,
save_logs_path = './logs.jsonl'
):
"""
Evaluates the functional correctness of a model.
@@ -116,16 +55,22 @@ def evaluate_functional_correctness(
result = future.result()
results[result["task_id"]].append(result)
# Calculate pass@k.
total, correct = [], []
total, correct, logs = [], [], []
for result in results.values():
passed = [r["passed"] for r in result]
res = [{r['task_id']:r["result"]} for r in result]
logs.append(res)
total.append(len(passed))
correct.append(sum(passed))
total = np.array(total)
correct = np.array(correct)

pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()}
print(pass_at_k)

with open(save_logs_path, 'w', encoding='utf-8') as fw:
for ex in logs:
fw.write(json.dumps(ex) + '\n')
print(f"execute logs were saved at {save_logs_path}")

return pass_at_k

@@ -151,3 +96,7 @@ def estimator(n: int, c: int, k: int) -> float:
num_samples_it = iter(num_samples)

return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])


def evaluate_only(args):
pass
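
The collapsed estimator body above presumably implements the standard unbiased pass@k estimate from the Codex paper, pass@k = 1 - C(n - c, k) / C(n, k) averaged over tasks. A minimal sketch with illustrative values (the function name and the numbers below are made up):

import numpy as np


def single_task_pass_at_k(n: int, c: int, k: int) -> float:
    # Unbiased estimator 1 - C(n - c, k) / C(n, k), evaluated as a stable product.
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


# 10 samples generated for one task, 3 of them passed the unit tests:
print(single_task_pass_at_k(10, 3, 1))  # 0.3    -- pass@1 equals the passing fraction
print(single_task_pass_at_k(10, 3, 5))  # ~0.917 -- chance that at least one of 5 draws passes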
56 changes: 40 additions & 16 deletions evalueate/generate.py
@@ -2,27 +2,13 @@
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser
import os
from utils import language_settings, extract_generation_code
from evaluation import evaluate_functional_correctness
from args import EvaluateArgs


def build_instruction(prompt: str):
"""
    Build the appropriate instruction for the model.
"""
return prompt
def generate_one(example, prompt, tokenizer, model, args):


def generate_one(example, lang, tokenizer, model, args):
# prompt = build_instruction(language_settings[lang]['full_name'], example['prompt'])
    # Adapt the prompt to the model's chat template
# inputs = tokenizer.apply_chat_template(
# [{'role': 'user', 'content': prompt}],
# return_tensors="pt",
# add_generation_prompt=True
# ).to(model.device)
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

stop_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.convert_tokens_to_ids(
@@ -42,4 +28,42 @@ def generate_one(example, lang, tokenizer, model, args):
output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
example['output'] = output

return extract_generation_code(example, lang_code=lang)
return extract_generation_code(example)


def generate_main(args):
model_name_or_path = args.model_name_or_path
saved_path = args.output_path

print("model", model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
print("load tokenizer {} from {} over.".format(tokenizer.__class__, model_name_or_path))
torch_dtype = torch.bfloat16 if args.torch_dtype == 'bf16' else torch.float16
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=torch_dtype,
device_map="auto",
trust_remote_code=True,
)
model.eval()
examples = [json.loads(x) for x in open(args.data_file) if x.strip()]
print("Read {} examples for evaluation over.".format(len(examples)))

generated_examples = []
for ex in tqdm(examples, desc='Generating'):
gen_example = generate_one(ex, args.language, tokenizer, model, args)
generated_examples.append(gen_example)

print("Generate all over!!!")
with open(saved_path, 'w', encoding='utf-8') as fw:
for ex in generated_examples:
fw.write(json.dumps(ex) + '\n')
print("Save {} processed examples into {} over!".format(len(generated_examples), saved_path))

result = evaluate_functional_correctness(
input_file=saved_path,
n_workers=8,
timeout=3.0,
k=1
)
print(result, model_name_or_path)
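
The JSONL file written by generate_main is re-read by evaluate_functional_correctness, so each record has to carry the fields touched across this diff: task_id, prompt, output, generation and test. A hypothetical record with made-up values:

record = {
    "task_id": "HumanEval/0",                            # groups execution results per task
    "prompt": "def add(a, b):\n",                        # original problem prompt fed to the model
    "output": "    return a + b\n",                      # raw decoded completion set in generate_one
    "generation": "def add(a, b):\n    return a + b\n",  # code kept by extract_generation_code
    "test": "assert add(1, 2) == 3",                     # unit tests appended before execution
}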
96 changes: 0 additions & 96 deletions evalueate/humaneval.py

This file was deleted.

17 changes: 17 additions & 0 deletions evalueate/main.py
@@ -0,0 +1,17 @@
from transformers import HfArgumentParser
import os
from args import EvaluateArgs
from task import humaneval
from generate import generate_main

parser = HfArgumentParser((EvaluateArgs,))
args = parser.parse_args_into_dataclasses()[0]
os.environ["TOKENIZERS_PARALLELISM"] = "false"

TASKS = {
"humaneval": humaneval.HumanEval
}

task = TASKS[args.task_name]

if not args.generate_only:
generate_main(args)
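
For reference, a hypothetical programmatic invocation: the flag names follow the EvaluateArgs fields referenced across this diff, the values are made up, and any remaining EvaluateArgs fields are assumed to have defaults.

from transformers import HfArgumentParser
from args import EvaluateArgs

parser = HfArgumentParser((EvaluateArgs,))
(args,) = parser.parse_args_into_dataclasses(args=[
    "--model_name_or_path", "deepseek-ai/deepseek-coder-1.3b-instruct",  # hypothetical model
    "--data_file", "humaneval.jsonl",    # HumanEval problems in JSONL form
    "--output_path", "output.jsonl",     # where generated examples are saved
    "--torch_dtype", "bf16",             # one of ['bf16', 'fp16'] per args.py
    "--task_name", "humaneval",          # one of ['humaneval'] per args.py
])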
