forked from mst272/LLM-Dojo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
157 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,20 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Optional | ||
|
||
# dtype shorthand strings. NOTE(review): presumably the values accepted for
# EvaluateArgs.torch_dtype (whose default is 'fp16') -- confirm against the consumer.
torch_dtype = ['bf16', 'fp16'] | ||
|
||
|
||
@dataclass
class EvaluateArgs:
    """
    Configuration parameters for an evaluation run.

    Every field has a default, so ``EvaluateArgs()`` is a valid configuration.
    """
    max_new_tokens: int = 100
    max_length: int = 256
    torch_dtype: str = 'fp16'
    do_sample: bool = False
    top_p: float = 0.95
    # Annotated as float (the default is still numerically 1) so fractional
    # values such as 0.7 are valid for type checkers and annotation-driven
    # argument parsers.
    temperature: float = 1.0
    model_name_or_path: str = './'
    output_path: str = './'
    temp_dir: str = 'tmp'
    data_file: str = ''
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import gzip
import json
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
from typing import Dict, Iterable

from tqdm import tqdm

from .execution import check_correctness
|
||
# Import statements, keyed by language name. process_humaneval_test joins the
# "python" list with newlines and prepends it to every generated test program.
IMPORT_HELPER = { | ||
"python": [ | ||
"import math", | ||
"import re", | ||
"import sys", | ||
"import copy", | ||
"import datetime", | ||
"import itertools", | ||
"import collections", | ||
"import heapq", | ||
"import functools", | ||
"import hashlib", | ||
"import numpy", | ||
"import numpy as np", | ||
"import string", | ||
"from typing import *", | ||
"from collections import *", | ||
] | ||
} | ||
|
||
|
||
def process_humaneval_test(sample):
    """
    Build the source text of the test program for one sample.

    The program is the ``IMPORT_HELPER["python"]`` lines, then the sample's
    ``"generation"`` value, then its ``"test"`` value, each part terminated
    by a newline.

    return: the assembled test program as a str
    """
    test_body = sample["test"]
    generated = sample["generation"]
    # Shared import prelude, one statement per line, newline-terminated.
    prelude = "\n".join(IMPORT_HELPER["python"]) + "\n"
    return prelude + generated + "\n" + test_body + "\n"
|
||
|
||
def stream_jsonl_all(filename: str) -> Iterable[Dict]:
    """
    Read every record of a JSONL file into a list.

    Files whose name ends with ``.gz`` are read through ``gzip`` in text
    mode; all other files are opened as plain text. Lines that are empty or
    contain only whitespace are skipped; every other line is parsed with
    ``json.loads``.

    filename: path of the JSONL (or gzip-compressed JSONL) file.
    return: list of the parsed records, in file order.
    raises: OSError if the file cannot be opened; json.JSONDecodeError if a
        non-blank line is not valid JSON.
    """
    # gzip.open takes the path itself, so no separately opened (and never
    # closed) binary handle is needed.
    opener = gzip.open if filename.endswith(".gz") else open
    # The context manager closes the file even when a line fails to parse.
    with opener(filename, "rt") as fp:
        return [json.loads(line) for line in fp if line.strip()]
|
||
|
||
def evaluate_functional_correctness(
        input_file: str = None,
        tmp_dir: str = "./",
        n_workers: int = 32,
        timeout: int = 3
):
    """
    Run the test program of every sample in ``input_file`` and score the results.

    Each record loaded by ``stream_jsonl_all`` is turned into a test program
    with ``process_humaneval_test`` and submitted to a thread pool that runs
    ``check_correctness(test_code, timeout)``. Outcomes are grouped by the
    record's ``"task_id"``.

    input_file: path of the JSONL file holding the samples.
    tmp_dir: accepted for interface compatibility; not used by this function.
    n_workers: maximum number of worker threads.
    timeout: timeout value forwarded to ``check_correctness``.
    return: ``{"pass@1": score}`` where ``score`` is the mean, over tasks, of
        the fraction of that task's samples that passed (0.0 when there are
        no samples).
    """
    sample_jsonl = stream_jsonl_all(input_file)

    # Outcomes grouped by task id, so several generations for the same task
    # are scored together.
    results = defaultdict(list)

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # Remember which task each future belongs to; as_completed yields
        # futures in completion order, not submission order.
        future_to_task = {}

        print("Reading samples...")
        for sample in tqdm(sample_jsonl):
            task_id = sample["task_id"]
            sample["test_code"] = process_humaneval_test(sample)
            if sample["test_code"] is None:
                continue
            future = executor.submit(check_correctness, sample["test_code"], timeout)
            future_to_task[future] = task_id

        print("Running test suites...")
        for future in tqdm(as_completed(future_to_task), total=len(future_to_task)):
            # NOTE(review): assumes check_correctness returns a value that is
            # truthy only when the tests pass (evaluate_score in this file
            # compares its results against True) -- confirm.
            results[future_to_task[future]].append(bool(future.result()))

    if not results:
        return {"pass@1": 0.0}

    # For k = 1 the pass@k estimate of a task is simply passed / total.
    per_task = [sum(outcomes) / len(outcomes) for outcomes in results.values()]
    pass_at_k = {"pass@1": sum(per_task) / len(per_task)}
    return pass_at_k
|
||
|
||
def evaluate_score(args):
    """
    Check each candidate in ``gs`` against the expected expression ``o``.

    ``args`` unpacks as ``(gs, (c, i, o), mode)``; ``i`` and ``mode`` are not
    used here. For every candidate ``g`` the program ``c`` followed by
    ``assert o == g`` is passed to ``check_correctness`` with a timeout of 3.

    return: the list of ``check_correctness`` results, one per candidate, or
        a list of ``False`` of the same length when no result equals ``True``.
    """
    gs, (c, i, o), mode = args

    outcomes = [check_correctness(f"{c}\nassert {o} == {g}", 3) for g in gs]
    if True in outcomes:
        return outcomes
    # No candidate passed: report a uniform all-False list.
    return [False] * len(gs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters