forked from karpathy/llm.c
tinystories.py
"""
Downloads and tokenizes the TinyStories dataset.
- The download is from HuggingFace datasets.
- The tokenization is done with the GPT-2 tokenizer via tiktoken.
The output is written to a newly created tinystories/ folder.
The script prints:
Tokenizing val split...
Saved 19043638 tokens to tinystories/TinyStories_val.bin
Tokenizing train split...
Saved 925653391 tokens to tinystories/TinyStories_train.bin
The script runs in 1-2 minutes, depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
"""
import os
import glob
import json
import random
import requests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import tiktoken
import numpy as np
from data_common import download_file, write_datafile
# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinystories")
enc = tiktoken.get_encoding("gpt2")
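# encode_ordinary encodes plain text only and ignores special tokens,
# so <|endoftext|> is appended explicitly in process_shard below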
encode = lambda s: enc.encode_ordinary(s)
def download():
"""Downloads the TinyStories dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
# download the TinyStories dataset, unless it's already downloaded
data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
if not os.path.exists(data_filename):
print(f"Downloading {data_url} to {data_filename}...")
download_file(data_url, data_filename)
else:
print(f"{data_filename} already exists, skipping download...")
# unpack the tar.gz file into all the data shards (json files)
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
if not os.path.exists(data_dir):
os.makedirs(data_dir, exist_ok=True)
print(f"Unpacking {data_filename}...")
os.system(f"tar -xzf {data_filename} -C {data_dir}")
else:
print(f"{data_dir} already exists, skipping unpacking...")
# print a single example just for debugging and such
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
print("Download done.")
print(f"Number of shards: {len(shard_filenames)}")
# with open(shard_filenames[0], "r") as f:
# data = json.load(f)
# print(f"Example story:\n{data[0]}")
def process_shard(shard_index, shard_filename):
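    """Tokenizes one JSON shard, shuffling its stories and prefixing each with the <|endoftext|> token."""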
with open(shard_filename, "r") as f:
data = json.load(f)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
rng = random.Random(1337 + shard_index)
rng.shuffle(data)
all_tokens = []
for example in data:
text = example["story"]
text = text.strip() # get rid of leading/trailing whitespace
tokens = encode(text)
all_tokens.append(eot)
all_tokens.extend(tokens)
return all_tokens
def tokenize():
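    """Tokenizes all shards in parallel and writes one .bin file per split."""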
# shard 0 will be the val split, rest is train
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
val_shards = [shard_filenames[0]]
train_shards = shard_filenames[1:]
for split_name, split_shards in [("val", val_shards), ("train", train_shards)]:
print(f"Tokenizing {split_name} split...")
all_tokens = []
with ProcessPoolExecutor() as executor:
futures = [executor.submit(process_shard, shard_index, shard_filename)
for shard_index, shard_filename in enumerate(split_shards)]
for future in as_completed(futures):
all_tokens.extend(future.result())
split_filename = os.path.join(DATA_CACHE_DIR, f"TinyStories_{split_name}.bin")
write_datafile(split_filename, all_tokens)
if __name__ == "__main__":
download()
tokenize()
# Prints:
# Tokenizing val split...
# Saved 19043638 tokens to tinystories/TinyStories_val.bin
# Tokenizing train split...
# Saved 925653391 tokens to tinystories/TinyStories_train.bin
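# Quick sanity check of an output file (a sketch, not part of this script;
# assumes the raw int32 layout noted in the module docstring):
#   import numpy as np, tiktoken
#   enc = tiktoken.get_encoding("gpt2")
#   tokens = np.fromfile("tinystories/TinyStories_val.bin", dtype=np.int32)
#   print(enc.decode(tokens[:100].tolist()))  # starts with <|endoftext|> followed by the first story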