Skip to content

Commit

Permalink
Merge pull request Codium-ai#18 from Cassini-chris/patch-2
Browse files Browse the repository at this point in the history
Optimized and commented code
  • Loading branch information
mrT23 authored Jan 25, 2024
2 parents 786f5fa + e71d100 commit 288f971
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions alpha_codium/evaluate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,31 @@

logger = get_logger(__name__)


def evaluate_dataset_solution(dataset_name='valid_and_test_processed',
split_name='test',
solution_path_database='valid_database_solution.json'):
split_name = split_name
dataset_name = dataset_name
data_provider = CodeContestDataProvider(dataset_location=dataset_name)

ds = data_provider.dataset[split_name]

solution_path_database = solution_path_database
split_name='test',
solution_path_database='valid_database_solution.json'):
"""
Evaluate the performance of dataset solutions.
Args:
dataset_name (str, optional): The name of the dataset. Defaults to 'valid_and_test_processed'.
split_name (str, optional): The name of the split. Defaults to 'test'.
solution_path_database (str, optional): The path to the solution database file. Defaults to 'valid_database_solution.json'.
"""

# Load the dataset and solution database
data_provider = CodeContestDataProvider(dataset_location=dataset_name)
ds = data_provider.dataset[split_name]
with open(solution_path_database, 'r') as f:
database_solutions = json.load(f)
database_solutions[split_name] = OrderedDict(
sorted(database_solutions[split_name].items(), key=lambda x: int(x[0])))

# Initialize counters for passed and failed problems
total_passed = 0
total_failed = 0

# Iterate over the solutions in the database
for sol in database_solutions[split_name]:
try:
key_str = sol
Expand Down Expand Up @@ -68,9 +74,12 @@ def evaluate_dataset_solution(dataset_name='valid_and_test_processed',
print(f"Error: {e}")
pass

# Print the total number of passed and failed problems
print(f"total_passed: {total_passed}, total_failed: {total_failed}")
print(f"pass rate: {total_passed/(total_passed+total_failed)}")

# Calculate the pass rate
pass_rate = total_passed / (total_passed + total_failed)
print(f"pass rate: {pass_rate}")

parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="valid_and_test_processed")
Expand Down

0 comments on commit 288f971

Please sign in to comment.