update evaluate
mst272 committed Sep 18, 2024
1 parent 2ef23e9 commit 091a936
Showing 5 changed files with 60 additions and 63 deletions.
1 change: 0 additions & 1 deletion evalueate/args.py
@@ -16,5 +16,4 @@ class EvaluateArgs:
temperature: int = 1
model_name_or_path: str = './'
output_path: str = './'
temp_dir: str = 'tmp'
data_file: str = ''
60 changes: 39 additions & 21 deletions evalueate/evaluation.py
@@ -3,9 +3,10 @@
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor

import numpy as np
from tqdm import tqdm

from .execution import check_correctness
from execution import check_correctness

IMPORT_HELPER = {
"python": [
@@ -41,28 +42,24 @@ def process_humaneval_test(sample):
return test_string


def stream_jsonl_all(filename: str) -> Iterable[Dict]:
def stream_jsonl_all(filename: str):
"""
Streams a JSONL file.
"""
results = []
if filename.endswith(".gz"):
fp = gzip.open(open(filename, "rb"), "rt")
else:
fp = open(filename, "r")
fp = open(filename, "r")
for line in fp:
if any(not x.isspace() for x in line):
results.append(json.loads(line))
fp.close()

return results


def evaluate_functional_correctness(
input_file: str = None,
tmp_dir: str = "./",
n_workers: int = 32,
timeout: int = 3
timeout: float = 3.0,
k: int = 1
):
"""
Evaluates the functional correctness of a model.
@@ -78,30 +75,51 @@ def evaluate_functional_correctness(
print("Reading samples...")
for sample in tqdm(sample_jsonl):
task_id = sample["task_id"]
tmp_dir_ = os.path.join(tmp_dir, "evaluation")
sample["task_id"] = task_id
sample["test_code"] = process_humaneval_test(sample)
if sample["test_code"] is None:
continue
args = (sample['test_code'],timeout)
args = (sample['test_code'], task_id, timeout)
future = executor.submit(check_correctness, *args)
futures.append(future)
n_samples += 1

print("Running test suites...")
for future in tqdm(as_completed(futures), total=len(futures)):
result = future.result()
results[result["task_id"]].append(result)
# Calculate pass@k.
total, correct = [], []
for result in results.values():
passed = [r["passed"] for r in result]
total.append(len(passed))
correct.append(sum(passed))
total = np.array(total)
correct = np.array(correct)

pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()}
print(pass_at_k)

return pass_at_k


def evaluate_score(args):
gs, (c, i, o), mode = args
def estimate_pass_at_k(
num_samples,
num_correct,
k: int
) -> np.ndarray:
"""
Estimates pass@k and returns them in an array.
"""

def estimator(n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)

execution_results = []
for g in gs:
code_to_execute = f"{c}\nassert {o} == {g}"
execution_results.append(check_correctness(code_to_execute, 3))
if True not in execution_results:
execution_results = [False] * len(gs)
return execution_results
return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
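
The function moved into evaluation.py is the standard unbiased pass@k estimator: pass@k = 1 - C(n - c, k) / C(n, k), computed through the equivalent product form 1 - prod_{i = n-c+1..n} (1 - k/i) so no large binomial coefficients are formed. For example, n = 5 generations with c = 2 passing gives pass@1 = 1 - C(3,1)/C(5,1) = 0.4. A quick sanity check of the function as added above (assuming it is run from inside the evalueate directory, so the module imports the same way the commit's own `from execution import ...` does):

import numpy as np
from evaluation import estimate_pass_at_k

total = np.array([5, 5])    # n: generations per task
correct = np.array([2, 0])  # c: generations that passed per task
print(estimate_pass_at_k(total, correct, k=1))  # -> [0.4 0. ]
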
15 changes: 6 additions & 9 deletions evalueate/execution.py
@@ -56,14 +56,7 @@ class redirect_stdin(contextlib._RedirectStream):  # type: ignore
_stream = "stdin"


def check_correctness(check_program, timeout=3):
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.
:param completion_id: an optional completion ID so we can match
the results later even if execution finishes asynchronously.
"""
def check_correctness(check_program, task_id, timeout=3):
manager = multiprocessing.Manager()
result = manager.list()

@@ -76,7 +69,11 @@ def check_correctness(check_program, timeout=3):
if not result:
result.append("timed out")

return result[0] == "passed"
return {
"task_id": task_id,
"passed": result[0] == "passed",
"result": result[0]
}


def unsafe_execute(check_program, result, timeout):
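
With this change, check_correctness threads the task_id through and returns a small dict instead of a bare boolean, which is what lets evaluate_functional_correctness group outcomes per task before computing pass@k. A minimal illustration of the new contract (the program and task id below are invented for the example; result strings other than "passed" and "timed out" depend on unsafe_execute, whose body is truncated here):

from execution import check_correctness

program = "def add(a, b):\n    return a + b\nassert add(1, 2) == 3"
outcome = check_correctness(program, task_id="HumanEval/0", timeout=3.0)
print(outcome)  # e.g. {'task_id': 'HumanEval/0', 'passed': True, 'result': 'passed'}
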
26 changes: 0 additions & 26 deletions evalueate/generate.py
@@ -1,31 +1,5 @@
from typing import Union, List
import itertools

import numpy as np


def estimate_pass_at_k(
num_samples: Union[int, List[int], np.ndarray],
num_correct: Union[List[int], np.ndarray],
k: int
) -> np.ndarray:
"""
Estimates pass@k of each problem and returns them in an array.
"""

def estimator(n: int, c: int, k: int) -> float:
"""
Calculates 1 - comb(n - c, k) / comb(n, k).
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

if isinstance(num_samples, int):
num_samples_it = itertools.repeat(num_samples, len(num_correct))
else:
assert len(num_samples) == len(num_correct)
num_samples_it = iter(num_samples)

return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])

21 changes: 15 additions & 6 deletions evalueate/task/humaneval/humaneval.py
@@ -2,10 +2,11 @@

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser
import os
from evalueate.utils import language_settings, extract_generation_code
from evalueate.evaluation import evaluate_functional_correctness
from evalueate.args import EvaluateArgs


def build_instruction(languge: str, question: str):
@@ -26,7 +27,8 @@ def generate_one(example, lang, tokenizer, model, args):
add_generation_prompt=True
).to(model.device)

stop_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.convert_tokens_to_ids("<|EOT|>")
stop_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.convert_tokens_to_ids(
"<|EOT|>")
assert isinstance(stop_id, int), "Invalid tokenizer, EOT id not found"

outputs = model.generate(
@@ -78,11 +80,18 @@ def generate_main(args):

result = evaluate_functional_correctness(
input_file=saved_path,
tmp_dir=temp_dir,
n_workers=8,
timeout=3.0,
problem_file=problem_file,
language=lang
k=1
)
print(lang, result, model_name_or_path)
print(result, model_name_or_path)


def evaluate_only(args):
pass
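
evaluate_only is added only as a stub in this commit. Presumably it will score an already-generated JSONL file without reloading the model; a possible sketch under that assumption, reusing nothing beyond the evaluate_functional_correctness signature introduced above:

def evaluate_only(args):
    # Hypothetical body: evaluate an existing generations file pointed to by args.data_file.
    result = evaluate_functional_correctness(
        input_file=args.data_file,
        n_workers=8,
        timeout=3.0,
        k=1
    )
    print(result)
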


if __name__ == '__main__':
parser = HfArgumentParser((EvaluateArgs,))
args = parser.parse_args_into_dataclasses()[0]
generate_main(args)
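
Since the new __main__ block builds EvaluateArgs through HfArgumentParser, every dataclass field becomes a command-line flag. An illustrative smoke test of that plumbing, run from the repository root and using only the fields visible in the args.py hunk above (paths are placeholders; any fields hidden above that hunk may also need values if they lack defaults):

from transformers import HfArgumentParser
from evalueate.args import EvaluateArgs

parser = HfArgumentParser((EvaluateArgs,))
args = parser.parse_args_into_dataclasses(args=[
    "--model_name_or_path", "/path/to/model",
    "--data_file", "/path/to/humaneval.jsonl",
    "--output_path", "./outputs",
])[0]
print(args)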
