Skip to content

Commit

Permalink
Merge pull request #49 from peidaqi/main
Browse files Browse the repository at this point in the history
Fixed Out of Memory issue when processing large datasets
  • Loading branch information
Muennighoff authored Feb 12, 2025
2 parents 2c0661c + 792ae3d commit 678550f
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions data/collect_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

from decontaminate_util import *

# Value passed as ``writer_batch_size`` to ``ds.map(...)`` for the large
# datasets (>1GB): LiveCodeBench, MATH, USACO.
# 1000 is the default; use a smaller value (e.g. 200) to avoid OOM when
# processing these datasets.
LARGE_DATASET_WRITER_BATCH_SIZE = 1000

BAD_OMNIMATH_SAMPLES = [
{"question": "Let $\\mathbb{R}$ be the set of real numbers . Determine all functions $f\u00a0: \\mathbb{R} \\rightarrow \\mathbb{R}$ such that\n \nfor all pairs of real numbers $x$ and $y$ ."},
{"question": "Find the sum of the ages of everyone who wrote a problem for this year's HMMT November contest. If your answer is $X$ and the actual value is $Y$, your score will be $\\max (0,20-|X-Y|)$"},
Expand Down Expand Up @@ -113,7 +117,8 @@ def load_generic(name, split, question_field="question", solution_field="solutio

def load_math():
    """Load the ``simplescaling/openaimath`` train split in the common schema.

    Each example is rewritten to carry ``question``, ``solution``,
    ``cot_type``, ``source_type`` and ``metadata`` fields; every column not
    listed in ``DS_COLUMNS`` is then dropped.

    Returns:
        The mapped dataset restricted to the columns in ``DS_COLUMNS``.
    """
    def _to_record(x):
        # Order matters: "problem" and "solution" are popped *before*
        # ``str(x)`` is taken, so the metadata string holds only the
        # remaining fields of the example.
        question = x.pop("problem")
        solution = x.pop("solution")
        return {
            "question": question,
            "solution": solution,
            "cot_type": "math",
            "source_type": "simplescaling/openaimath/" + x['subject'],
            "metadata": str(x),
        }

    ds = datasets.load_dataset("simplescaling/openaimath", trust_remote_code=True)["train"]
    # A single map call; the writer batch size is configurable to avoid OOM
    # on this large dataset.
    ds = ds.map(_to_record, writer_batch_size=LARGE_DATASET_WRITER_BATCH_SIZE)
    ds = ds.remove_columns([c for c in ds.column_names if c not in DS_COLUMNS])
    return ds

Expand Down Expand Up @@ -262,7 +267,8 @@ def load_xword():

def load_usaco():
    """Load the ``codegenning/usacobench_formatted`` test split in the common schema.

    Each example is rewritten to carry ``question`` (whitespace-stripped),
    ``solution`` (always ``None``), ``cot_type``, ``source_type`` and
    ``metadata`` fields; every column not listed in ``DS_COLUMNS`` is then
    dropped.

    Returns:
        The mapped dataset restricted to the columns in ``DS_COLUMNS``.
    """
    def _to_record(x):
        # Order matters: "question" is popped *before* ``str(x)`` is taken,
        # so the metadata string holds only the remaining fields.
        question = x.pop("question").strip()
        return {
            "question": question,
            "solution": None,
            "cot_type": "coding",
            "source_type": "codegenning/usacobench_formatted",
            "metadata": str(x),
        }

    ds = datasets.load_dataset("codegenning/usacobench_formatted")['test']
    # A single map call; the writer batch size is configurable to avoid OOM
    # on this large dataset.
    ds = ds.map(_to_record, writer_batch_size=LARGE_DATASET_WRITER_BATCH_SIZE)
    ds = ds.remove_columns([c for c in ds.column_names if c not in DS_COLUMNS])
    return ds

Expand All @@ -283,7 +289,7 @@ def load_livecodebench():
"cot_type": "coding",
"source_type": f"LiveCodeBench/{version}",
"metadata": str(x)
})
}, writer_batch_size=LARGE_DATASET_WRITER_BATCH_SIZE)
# filter only the difficult questions
ds = ds.filter(lambda x: x["difficulty"] == "hard")
ds = ds.remove_columns([c for c in ds.column_names if c not in DS_COLUMNS])
Expand Down

0 comments on commit 678550f

Please sign in to comment.