Merge pull request stanford-crfm#170 from stanford-crfm/gsm
Grade School Math with 8.5K Examples (GSM8K)
ezelikman authored Mar 22, 2022
2 parents 486dcf0 + bdb64cc commit 60516df
Showing 5 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ sqlitedict==1.7.0
transformers==4.13.0
tornado==6.1
zstandard==0.17.0
jsonlines==3.0.0

# For development
black==19.10b0
1 change: 1 addition & 0 deletions src/benchmark/__init__.py
@@ -11,6 +11,7 @@
from . import lpm_scenario # noqa
from . import copyright_scenario # noqa
from . import boolq_scenario # noqa
from . import gsm_scenario # noqa
from . import natural_qa_scenario # noqa
from . import quac_scenario # noqa
from . import babi_qa_scenario # noqa
16 changes: 16 additions & 0 deletions src/benchmark/basic_metrics.py
@@ -31,6 +31,21 @@ def exact_match(gold: str, pred: str) -> float:
return 1 if gold == pred else 0


def exact_match_indicator(gold: str, pred: str) -> float:
"""
Exact match, allowing for some preceding context.
For example, the following two answers are considered matching:
- Because of x and y, the answer is ## <answer>
- Given reasons y and z, the answer is ## <answer>
While the following is considered different from the earlier two:
- Given reasons x and a, the answer is ## <other answer>
"""
indicator: str = "#"
pred = pred.split(indicator)[-1].strip()
gold = gold.split(indicator)[-1].strip()
return exact_match(gold, pred)


def get_num_bytes(tokens: List[Token]) -> int:
"""
Compute the byte length of the input tokens. For a UTF-8 string token, we use byte() to convert
@@ -189,6 +204,7 @@ def compute_metrics_helper(name: MetricName, score_func: Callable[[str, str], fl
# maps each string metric name to its associated function
metric_fn_mapping = {
"exact_match": exact_match,
"exact_match_indicator": exact_match_indicator,
"exact_set_match": exact_set_match,
"iou_set_match": iou_set_match,
"f1_score": f1_score,
47 changes: 47 additions & 0 deletions src/benchmark/gsm_scenario.py
@@ -0,0 +1,47 @@
import jsonlines
import os
from typing import List

from common.general import ensure_file_downloaded
from .scenario import Scenario, Instance, Reference, CORRECT_TAG, TRAIN_SPLIT, TEST_SPLIT


class GSM8KScenario(Scenario):
"""Task from "Training Verifiers to Solve Math Word Problems" (Cobbe et al. 2021): https://arxiv.org/abs/2110.14168
Evaluates the capacity of a model to solve grade school math problems when prompted to include reasoning.
Encourages the model to work through the problem in a step-by-step way.
Example from dataset (line breaks added for readability):
"question":
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May.
How many clips did Natalia sell altogether in April and May?",
"answer":
"Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n
#### 72"
"""

name = "gsm"
description = "Grade school math dataset with 8.5K examples (GSM8K)."
tags = ["reasoning", "math"]

def __init__(self):
pass

def get_instances(self) -> List[Instance]:
splits = {"train": TRAIN_SPLIT, "test": TEST_SPLIT}
base_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/"
instances: List[Instance] = []
for split, split_tag in splits.items(): # Iterate over the splits
source_url: str = f"{base_url}{split}.jsonl"  # base_url already ends with a slash
data_path: str = os.path.join(self.output_path, f"gsm_data_{split}")
ensure_file_downloaded(source_url=source_url, target_path=data_path)
with jsonlines.open(data_path) as reader:
for example in reader:  # Each example is a dictionary with 'question' and 'answer' keys
instances.append(
Instance(
input=example["question"],
references=[Reference(output=example["answer"], tags=[CORRECT_TAG])],
split=split_tag, # Must assign split tag to instance.
),
)
return instances
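
To make the mapping concrete, here is a sketch of the Instance the loop above would build for the docstring's example record (assuming src/ is on the Python path so benchmark.scenario is importable; the record values are copied from the docstring, not fetched from the dataset):

from benchmark.scenario import Instance, Reference, CORRECT_TAG, TRAIN_SPLIT

example = {
    "question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
    "answer": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72",
}

# One Instance per record: the question is the input, and the full worked answer
# (including the final "#### 72" line) is the single correct reference.
instance = Instance(
    input=example["question"],
    references=[Reference(output=example["answer"], tags=[CORRECT_TAG])],
    split=TRAIN_SPLIT,
)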
26 changes: 26 additions & 0 deletions src/benchmark/run_specs.py
@@ -111,6 +111,8 @@ def construct_run_specs(spec: ObjectSpec) -> List[RunSpec]:
return [get_run_spec1()]
if name == "twitter_aae":
return [get_twitter_aae_spec(**args)]
if name == "gsm":
return [get_gsm_spec()]
if name == "natural_qa":
return [get_natural_qa_spec(**args)]
if name == "the_pile":
@@ -305,6 +307,30 @@ def get_real_toxicity_prompts_spec() -> RunSpec:
)


def get_gsm_spec() -> RunSpec:
scenario = ScenarioSpec(class_name="benchmark.gsm_scenario.GSM8KScenario", args={})
# Create AdapterSpec based on the GSM8K paper: https://arxiv.org/pdf/2110.14168.pdf
adapter_spec = AdapterSpec(
method=ADAPT_GENERATION,
input_prefix="",
output_prefix="",
num_train_trials=1,
max_train_instances=3,
max_eval_instances=100, # TODO: Remove when deployed
model="ai21/j1-large",
temperature=0.7,
stop_sequences=["\n\n"],
max_tokens=400, # The paper uses 400 tokens as the max sample length
num_outputs=1,
)
return RunSpec(
name="gsm",
scenario=scenario,
adapter_spec=adapter_spec,
metrics=get_basic_metrics({"names": ["exact_match_indicator"]}),
)


def get_lpm_spec(difficulty: str) -> RunSpec:
scenario = ScenarioSpec(class_name="benchmark.lpm_scenario.LPMScenario", args={"difficulty": difficulty})
