Skip to content

Commit

Permalink
Merge pull request THUDM#76 from roG0d/main
Browse files Browse the repository at this point in the history
ADDED: humaneval-x rust integration
  • Loading branch information
Stanislas0 authored Mar 6, 2023
2 parents 394df41 + f66205b commit 0abbe13
Show file tree
Hide file tree
Showing 9 changed files with 549 additions and 7 deletions.
239 changes: 239 additions & 0 deletions codegeex/benchmark/evaluate_humaneval_x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import os
import sys
import fire
import json
import gzip
import regex
import numpy as np

from typing import *
from tqdm.auto import tqdm
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

from codegeex.benchmark.utils import read_dataset, IMPORT_HELPER
from codegeex.benchmark.metric import estimate_pass_at_k
from codegeex.benchmark.execution import check_correctness

LANGUAGE_NAME = {
"cpp" : "CPP",
"go" : "Go",
"java" : "Java",
"js" : "JavaScript",
"python": "Python",
}


def process_humaneval_test(sample, problems, example_test=False):
task_id = sample["task_id"]
language = task_id.split("/")[0].lower()

prompt = sample["prompt"]
if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
test = problems[task_id]["example_test"]
else:
test = problems[task_id]["test"]
code = sample["generation"]

# Pre-process for different languages
if language == "python":
code_ = []
for line in code.split("\n"):
if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'):
break
code_.append(line)
code = "\n".join(code_)
test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
test_string = test_setup + prompt + code + "\n" + test + "\n"
elif language == "cpp":
test_set_up = ""
for s in IMPORT_HELPER["cpp"]:
if s not in prompt:
test_set_up += s + "\n"
test_string = test_set_up + "\n" + prompt + code + "\n" + test
elif language == "java":
test_string = prompt + code + "\n" + test
elif language == "js" or language == "javascript":
test_string = prompt + code + "\n" + test
elif language == "go":
import_string = problems[task_id]["import"]
prompt = prompt.replace(import_string, "")
if example_test and "example_test" in problems[task_id]:
test = problems[task_id]["example_test"]
else:
test = problems[task_id]["test"]
test_setup = problems[task_id]["test_setup"]
other_pkgs = []
for pkg in IMPORT_HELPER["go"]:
if pkg not in test_setup:
p = pkg.split("/")[-1]
if p + "." in code:
other_pkgs.append(f"\"{pkg}\"")
if other_pkgs:
import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
else:
test_string = test_setup + "\n" + prompt + code + "\n" + test
elif language == "rust":
main = "\nfn main(){ \n } \n"
declaration = problems[task_id]["declaration"]
test_string = main + declaration + prompt + code + test

return test_string


def stream_jsonl_all(filename: str) -> Iterable[Dict]:
results = []
if filename.endswith(".gz"):
fp = gzip.open(open(filename, "rb"), "rt")
else:
fp = open(filename, "r")
for line in fp:
if any(not x.isspace() for x in line):
results.append(json.loads(line))
fp.close()

return results


def evaluate_functional_correctness(
input_file: str = None,
tmp_dir: str = "./",
n_workers: int = 32,
timeout: float = 500.0,
problem_file: str = "../data/humaneval_python.jsonl.gz",
out_dir: str = None,
k: List[int] = [1, 10, 100],
test_groundtruth: bool = False,
example_test: bool = False,
):
if example_test:
print("Example test...")

problems = read_dataset(problem_file,
dataset_type="humaneval")
sample_jsonl = stream_jsonl_all(input_file)

if example_test:
suffix = "_example_test.jsonl"
else:
suffix = "_results.jsonl"
if out_dir is not None:
if not os.path.exists(out_dir):
os.makedirs(out_dir)
out_file = os.path.join(out_dir, input_file.split('/')[-1].replace(".jsonl", suffix))
else:
out_file = os.path.join(input_file.replace(".jsonl", suffix))

if "/codegeex/benchmark/humaneval-x/" in input_file:
test_groundtruth = True

if "-to-" in input_file:
translation_mode = True
else:
translation_mode = False

with ThreadPoolExecutor(max_workers=n_workers) as executor:

futures = []
completion_id = Counter()
n_samples = 0
results = defaultdict(list)

if test_groundtruth:
print("Testing ground truth...")
for sample in tqdm(problems.values()):
task_id = sample["task_id"]
lang = task_id.split("/")[0].lower()
if lang == "javascript":
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
sample["generation"] = sample["canonical_solution"]
sample["test_code"] = process_humaneval_test(sample, problems, example_test)
if sample["test_code"] is None:
continue
args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id])
future = executor.submit(check_correctness, *args)
futures.append(future)
completion_id[task_id] += 1
n_samples += 1
else:
print("Reading samples...")
for sample in tqdm(sample_jsonl):
task_id = sample["task_id"]
lang = task_id.split("/")[0].lower()
if translation_mode:
task_id = sample["task_id"].split("/")[-1]
lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-")
for l in LANGUAGE_NAME:
if l in lang:
lang = l
break
task_id = f"{LANGUAGE_NAME[lang]}/{task_id}"
if lang == "javascript":
lang = "js"
tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
sample["task_id"] = task_id
sample["test_code"] = process_humaneval_test(sample, problems, example_test)
if sample["test_code"] is None:
continue
if "completion_id" in sample:
completion_id_ = sample["completion_id"]
else:
completion_id_ = completion_id[task_id]
args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
future = executor.submit(check_correctness, *args)
futures.append(future)
completion_id[task_id] += 1
n_samples += 1

print(completion_id)
if len(completion_id) == len(problems):
evaluate_pass_at_k = True
else:
evaluate_pass_at_k = False

print("Running test suites...")
for future in tqdm(as_completed(futures), total=len(futures)):
result = future.result()
results[result["task_id"]].append((result["completion_id"], result))

# Calculate pass@k.
total, correct = [], []
for result in results.values():
passed = [r[1]["passed"] for r in result]
total.append(len(passed))
correct.append(sum(passed))
total = np.array(total)
correct = np.array(correct)
if evaluate_pass_at_k:
ks = k
pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
for k in ks if (total >= k).all()}
print(pass_at_k)
else:
print("Total:", np.sum(total))
print("Correct:", np.sum(correct))

print("Writing to: ", out_file)
if out_file.endswith(".gz"):
fp = gzip.GzipFile(fileobj=open(out_file, "wb"), mode="wb")
for res in results.values():
for r in res:
fp.write((json.dumps(r[1]) + "\n").encode("utf-8"))
else:
fp = open(out_file, 'w')
for res in results.values():
for r in res:
fp.write(json.dumps(r[1]) + "\n")
fp.close()

print("Evaluation finished.")


def main():
fire.Fire(evaluate_functional_correctness)


if __name__ == "__main__":
sys.exit(main())
107 changes: 101 additions & 6 deletions codegeex/benchmark/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,38 @@
import random
import subprocess
import tempfile
import gzip
import json
from typing import *

def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> None:
"""
Method saves list of dicts into jsonl file.
:param data: (list) list of dicts to be stored,
:param filename: (str) path to the output file. If suffix .jsonl is not given then methods appends
.jsonl suffix into the file.
:param compress: (bool) should file be compressed into a gzip archive?
"""
sjsonl = '.jsonl'
sgz = '.gz'
# Check filename
if not filename.endswith(sjsonl):
filename = filename + sjsonl
# Save data

if compress:
filename = filename + sgz
with gzip.open(filename, 'w') as compressed:
for ddict in data_list:
jout = json.dumps(ddict) + '\n'
jout = jout.encode('utf-8')
compressed.write(jout)
else:
with open(filename, 'w') as out:
for ddict in data_list:
jout = json.dumps(ddict) + '\n'
out.write(jout)


def check_correctness(
task_id: str,
Expand Down Expand Up @@ -52,8 +82,8 @@ def unsafe_execute(tmp_dir):
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
# exec(sample["test_code"], exec_globals)
result.append("passed")
exec(sample["test_code"], exec_globals)
result.append("passed")
except TimeoutException:
result.append("timed out")
except AssertionError as e:
Expand Down Expand Up @@ -92,7 +122,7 @@ def unsafe_execute(tmp_dir):
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
# exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)

if exec_result.returncode == 0:
result.append("passed")
Expand Down Expand Up @@ -137,7 +167,7 @@ def unsafe_execute(tmp_dir):
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
# exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)

if exec_result.stderr.decode():
err = exec_result.stderr.decode()
Expand Down Expand Up @@ -190,7 +220,7 @@ def unsafe_execute(tmp_dir):
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
# exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)

if exec_result.returncode == 0:
result.append("passed")
Expand All @@ -210,6 +240,71 @@ def unsafe_execute(tmp_dir):
result.append("timed out")

shutil.rmtree(tmp_dir)
elif "rust" in language_type.lower():
import os

WD: str = os.path.dirname(os.path.abspath(__file__))
RUST_DIR: str = os.path.join(WD, "rust")
RUST_SRC: str = os.path.join(RUST_DIR, "src")
RUST_BIN: str = os.path.join(RUST_SRC, "bin")
RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
RUST_EXT: str = ".rs"

# Create mandatory tmp directories
os.makedirs(RUST_TMP_DIR, exist_ok=True)
os.makedirs(RUST_LOGS, exist_ok=True)
os.makedirs(RUST_SRC, exist_ok=True)
os.makedirs(RUST_BIN, exist_ok=True)

with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
#temporal file name
file_prefix = sample["task_id"].lower().replace("/", "_")
file_name:str = file_prefix +RUST_EXT

os.rename(f.name, os.path.join(RUST_BIN, file_name))

# Sample to pure Rust function
rust_code: str = sample["test_code"]

# dump the rust source code in the target temporal file
f.write(rust_code.encode('utf-8'))

# Proceed towards Rust binaries compilation. Therefore move to Rust module root dir.
os.chdir(RUST_DIR)

# Two possible outcomes
# Pass OR Fail compilation
log_filename: str = file_prefix + ".jsonl"
log_path: str = os.path.join(RUST_LOGS, log_filename)
cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path
# Compilation build status
returned_val_compilation: int

# Overwrite file content
if os.path.exists(log_path):
if(file_size := os.path.getsize(log_path)) >= 0:
os.remove(log_path)
returned_val_compilation = os.system(cargo_check)

else:
returned_val_compilation = os.system(cargo_check)

# 0 means success
if returned_val_compilation == 0:

#Execution pipeline
cargo_test: str = "cargo test --bin " +file_prefix+ " --message-format json >> " + log_path
returned_val_execution = os.system(cargo_test)

if returned_val_execution == 0:
result.append("passed")
else:
result.append(f"failed: execution error")

else:
result.append(f"failed: compilation error")


elif "java" in language_type.lower():
assert tmp_dir is not None, "Java should be evaluated in a temporary dir."
Expand Down Expand Up @@ -264,7 +359,7 @@ def unsafe_execute(tmp_dir):
result.append(res)

shutil.rmtree(tmp_dir)

manager = multiprocessing.Manager()
result = manager.list()

Expand Down
Loading

0 comments on commit 0abbe13

Please sign in to comment.