update evaluate
mst272 committed Sep 17, 2024
1 parent d8a7a2c commit 2ef23e9
Showing 5 changed files with 157 additions and 53 deletions.
8 changes: 7 additions & 1 deletion evalueate/args.py
@@ -1,14 +1,20 @@
from dataclasses import dataclass, field
from typing import Optional

torch_dtype = ['bf16', 'fp16']


@dataclass
class EvaluateArgs:
    """
    Configuration arguments for evaluation.
    """
    max_new_tokens: int = 100
    max_length: int = 256
    torch_dtype: str = 'fp16'
    do_sample: bool = False
    top_p: float = 0.95
    temperature: float = 1.0
    model_name_or_path: str = './'
    output_path: str = './'
    temp_dir: str = 'tmp'
    data_file: str = ''
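
This commit does not show how EvaluateArgs gets populated. A minimal sketch of one common wiring, assuming transformers' HfArgumentParser and an importable evalueate package (both assumptions, not part of this commit):

from transformers import HfArgumentParser

from evalueate.args import EvaluateArgs

# Sketch only: parse CLI flags such as
#   python eval.py --torch_dtype bf16 --data_file humaneval-python.jsonl
# into an EvaluateArgs instance.
parser = HfArgumentParser(EvaluateArgs)
(args,) = parser.parse_args_into_dataclasses()
assert args.torch_dtype in ['bf16', 'fp16']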
107 changes: 107 additions & 0 deletions evalueate/evaluation.py
@@ -0,0 +1,107 @@
import gzip
import json
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Iterable

from tqdm import tqdm

from .execution import check_correctness

IMPORT_HELPER = {
"python": [
"import math",
"import re",
"import sys",
"import copy",
"import datetime",
"import itertools",
"import collections",
"import heapq",
"import functools",
"import hashlib",
"import numpy",
"import numpy as np",
"import string",
"from typing import *",
"from collections import *",
]
}


def process_humaneval_test(sample):
    """
    Builds the executable test program for a sample: the import header,
    then the generated code, then the dataset's test suite.
    return: the test code as a str
    """
    test = sample["test"]
    code = sample["generation"]
    test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
    test_string = test_setup + code + "\n" + test + "\n"
    return test_string


def stream_jsonl_all(filename: str) -> Iterable[Dict]:
    """
    Reads all records from a JSONL file (optionally gzip-compressed)
    and returns them as a list.
    """
    results = []
    if filename.endswith(".gz"):
        fp = gzip.open(open(filename, "rb"), "rt")
    else:
        fp = open(filename, "r")
    for line in fp:
        if any(not x.isspace() for x in line):
            results.append(json.loads(line))
    fp.close()

    return results


def evaluate_functional_correctness(
        input_file: str = None,
        tmp_dir: str = "./",
        n_workers: int = 32,
        timeout: int = 3
):
    """
    Evaluates the functional correctness of generated samples and
    returns the pass@1 score across tasks.
    """
    sample_jsonl = stream_jsonl_all(input_file)

    with ThreadPoolExecutor(max_workers=n_workers) as executor:

        futures = []
        future_to_task = {}
        n_samples = 0
        results = defaultdict(list)

        print("Reading samples...")
        for sample in tqdm(sample_jsonl):
            task_id = sample["task_id"]
            sample["test_code"] = process_humaneval_test(sample)
            if sample["test_code"] is None:
                continue
            future = executor.submit(check_correctness, sample["test_code"], timeout)
            futures.append(future)
            future_to_task[future] = task_id
            n_samples += 1

        print("Running test suites...")
        for future in tqdm(as_completed(futures), total=len(futures)):
            # check_correctness is truthy when the sample's test suite passes
            # (see its use in evaluate_score below).
            results[future_to_task[future]].append(bool(future.result()))

    # One generation per task here, so "any sample passed" is plain pass@1.
    pass_at_k = {"pass@1": sum(any(r) for r in results.values()) / max(len(results), 1)}
    return pass_at_k


def evaluate_score(args):
    gs, (c, i, o), mode = args

    # Execute each candidate answer g by asserting it against the
    # expected output o in the context of code c.
    execution_results = []
    for g in gs:
        code_to_execute = f"{c}\nassert {o} == {g}"
        execution_results.append(check_correctness(code_to_execute, 3))
    if True not in execution_results:
        execution_results = [False] * len(gs)
    return execution_results
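
A minimal end-to-end sketch of driving this module (file name and sample values invented; the fields are the ones process_humaneval_test reads):

import json

from evalueate.evaluation import evaluate_functional_correctness

# One toy sample in the expected result format.
sample = {
    "task_id": "HumanEval/0",
    "generation": "def add(a, b):\n    return a + b\n",
    "test": "assert add(1, 2) == 3",
}
with open("samples.jsonl", "w") as f:
    f.write(json.dumps(sample) + "\n")

# Runs each sample's test suite in a thread pool and reports pass@1.
print(evaluate_functional_correctness("samples.jsonl", n_workers=4, timeout=3))
# -> {'pass@1': 1.0}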
1 change: 1 addition & 0 deletions evalueate/generate.py
@@ -28,3 +28,4 @@ def estimator(n: int, c: int, k: int) -> float:
num_samples_it = iter(num_samples)

return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])

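For reference, estimator computes the unbiased pass@k estimate from the Codex paper: with n samples per task of which c pass, pass@k = 1 - C(n-c, k) / C(n, k). Its body is collapsed in this diff, so the following is an equivalent reference implementation with a numeric check, not the repo's code:

from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    # Probability that at least one of k samples drawn without
    # replacement from n lands among the c correct ones.
    if n - c < k:
        return 1.0  # every size-k draw must contain a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(5, 2, 1))  # 0.4 = 1 - C(3,1)/C(5,1)
print(pass_at_k(5, 2, 5))  # 1.0
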
36 changes: 22 additions & 14 deletions evalueate/task/humaneval/humaneval.py
@@ -1,5 +1,13 @@
import json

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from evalueate.utils import language_settings, extract_generation_code
from evalueate.evaluation import evaluate_functional_correctness


def build_instruction(languge: str, question: str):
    return '''
Please continue to complete the function. You are not allowed to modify the given code; do the completion only. Please return the completed function in a code block. Here is the given code to complete:
@@ -9,23 +17,24 @@ def build_instruction(languge: str, question: str):
'''.strip().format(languge.lower(), question.strip())


def generate_one(example, lang, tokenizer, model):
prompt = build_instruction(languge_settings[lang]['full_name'], example['prompt'])
def generate_one(example, lang, tokenizer, model, args):
prompt = build_instruction(language_settings[lang]['full_name'], example['prompt'])
    # Apply the model's chat template
inputs = tokenizer.apply_chat_template(
[{'role': 'user', 'content': prompt}],
return_tensors="pt",
add_generation_prompt=True
).to(model.device)

stop_id = tokenizer.convert_tokens_to_ids("<|EOT|>")
stop_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.convert_tokens_to_ids("<|EOT|>")
assert isinstance(stop_id, int), "Invalid tokenizer, EOT id not found"

outputs = model.generate(
inputs,
max_new_tokens=1024,
do_sample=False,
# top_p=0.95,
# temperature=temperature,
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
top_p=args.top_p,
temperature=args.temperature,
pad_token_id=stop_id,
eos_token_id=stop_id
)
@@ -37,29 +46,28 @@ def generate_one(example, lang, tokenizer, model):


def generate_main(args):
model_name_or_path = args.model
lang = args.language
model_name_or_path = args.model_name_or_path
saved_path = args.output_path
temp_dir = args.temp_dir
os.makedirs(temp_dir, exist_ok=True)
problem_file = os.path.join(data_abs_dir, f"humaneval-{lang}.jsonl")

print("model", model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
print("load tokenizer {} from {} over.".format(tokenizer.__class__, model_name_or_path))
torch_dtype = torch.bfloat16 if args.torch_dtype == 'bf16' else torch.float16
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=torch.bfloat16,
torch_dtype=torch_dtype,
device_map="auto",
# use_flash_attention_2=True
trust_remote_code=True,
)
model.eval()
examples = [json.loads(x) for x in open(problem_file) if x.strip()]
examples = [json.loads(x) for x in open(args.data_file) if x.strip()]
print("Read {} examples for evaluation over.".format(len(examples)))

generated_examples = []
for ex in tqdm(examples, desc='Generating'):
gen_example = generate_one(ex, args.language, tokenizer, model)
gen_example = generate_one(ex, args.language, tokenizer, model, args)
generated_examples.append(gen_example)

print("Generate all over!!!")
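The rest of generate_main is collapsed in this diff. A hypothetical sketch of the usual tail for scripts like this, writing generated_examples to saved_path and scoring the file (names come from the visible code above; the actual implementation may differ):

    # Hypothetical continuation of generate_main; behavior assumed, not shown.
    with open(saved_path, "w", encoding="utf-8") as f:
        for ex in generated_examples:
            f.write(json.dumps(ex) + "\n")

    result = evaluate_functional_correctness(input_file=saved_path, tmp_dir=temp_dir)
    print("Evaluation result:", result)
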
58 changes: 20 additions & 38 deletions evalueate/utils.py
@@ -1,49 +1,17 @@
import re


def extract_generation_code(example: str, lang_code: str, verbose: bool = False):
def extract_generation_code(example, lang_code):
task_id = example['task_id']
output = example.get('output', example.get("gpt_completion"))
question = example["prompt"].strip()
setting = languge_settings[lang_code]
output = example.get('output')
prompt = example["prompt"].strip()
setting = language_settings[lang_code]
lang = setting['full_name']
indent = setting['indent']

try:
code_block: str = re.findall(f'```{lang.lower()}\n(.*?)```', output, re.DOTALL | re.IGNORECASE)[0]
if verbose:
print(">>> Task: {}\n{}".format(task_id, code_block))

# Remove main
if setting.get('main', None) and setting['main'] in code_block:
main_start = code_block.index(setting['main'])
code_block = code_block[:main_start]

func_name, func_prefix = get_function_name(question, lang)

try:
start = code_block.lower().index(func_name.lower())
indent = 0
while start - indent >= 0 and code_block[start - indent - 1] == ' ':
indent += 1

try:
end = code_block.rindex('\n' + ' ' * indent + '}')
except:
end = len(code_block)
except:
start = 0
try:
end = code_block.rindex('\n' + ' ' * indent + '}')
except:
end = len(code_block)

body = code_block[start:end]

if lang_code.lower() in ['php', 'ts', 'js']:
body += '\n' + ' ' * indent + '}'

generation = func_prefix + '\n' + body + '\n'
assert code_block.startswith(prompt)
generation = code_block[len(prompt):]
example['generation'] = generation

except Exception as ex:
@@ -53,3 +21,17 @@ def extract_generation_code(example: str, lang_code: str, verbose: bool = False)
example['generation'] = example['prompt'] + '\n' + output

return example


language_settings = {
'python': {
'full_name': 'Python'
},
'cpp': {
'full_name': 'cpp'
},
'java': {
'full_name': 'Java'
}
}
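
To illustrate the new extraction logic in extract_generation_code above: a fenced block is pulled out with a regex and must begin with the original prompt; the generation is whatever follows the prompt. A self-contained miniature with invented values:

import re

output = "```python\ndef add(a, b):\n    return a + b\n```"
prompt = "def add(a, b):"

# Same pattern extract_generation_code uses to grab the fenced code block.
code_block = re.findall(r'```python\n(.*?)```', output, re.DOTALL | re.IGNORECASE)[0]

# New logic in this commit: strip the prompt prefix, keep the remainder.
assert code_block.startswith(prompt)
generation = code_block[len(prompt):]
print(repr(generation))  # '\n    return a + b\n'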
