forked from karpathy/nanoGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add support for character-level language models, a new character-leve…
…l shakespeare dataset, a new config file that shows how to train a character-level baby GPT on it, and adjust the sample function to figure out if it should decode with characters or GPT2 bpe tokens. The current implementation is a bit hacky and basically assumes just these two possibilities. In the future we may want to support more general encoders or decoders.
- Loading branch information
Showing
5 changed files
with
137 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# train a miniature character-level shakespeare model | ||
# good for debugging and playing on macbooks and such | ||
|
||
out_dir = 'out-shakespeare-char' | ||
eval_interval = 250 # keep frequent because we'll overfit | ||
eval_iters = 200 | ||
log_interval = 10 # don't print too too often | ||
|
||
# we expect to overfit on this small dataset, so only save when val improves | ||
always_save_checkpoint = True | ||
|
||
wandb_log = False # override via command line if you like | ||
wandb_project = 'shakespeare-char' | ||
wandb_run_name = 'mini-gpt' | ||
|
||
dataset = 'shakespeare_char' | ||
batch_size = 64 | ||
block_size = 128 # context of up to 128 previous characters | ||
|
||
# baby GPT model :) | ||
n_layer = 4 | ||
n_head = 4 | ||
n_embd = 128 | ||
dropout = 0.0 | ||
|
||
learning_rate = 1e-3 # with baby networks can afford to go a bit higher | ||
max_iters = 5000 | ||
lr_decay_iters = 5000 # make equal to max_iters usually | ||
min_lr = 1e-4 # learning_rate / 10 usually | ||
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small | ||
|
||
warmup_iters = 100 # not super necessary potentially | ||
|
||
# on macbook also add | ||
# device = 'cpu' # run on cpu only | ||
# compile = False # do not torch compile the model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" | ||
Prepare the Shakespeare dataset for character-level language modeling. | ||
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. | ||
Will save train.bin, val.bin containing the ids, and meta.pkl containing the | ||
encoder and decoder and some other related info. | ||
""" | ||
import os | ||
import pickle | ||
import requests | ||
import numpy as np | ||
|
||
# download the tiny shakespeare dataset | ||
if not os.path.exists('input.txt'): | ||
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' | ||
with open('input.txt', 'w') as f: | ||
f.write(requests.get(data_url).text) | ||
|
||
with open('input.txt', 'r') as f: | ||
data = f.read() | ||
print("length of dataset in characters: ", len(data)) | ||
|
||
# get all the unique characters that occur in this text | ||
chars = sorted(list(set(data))) | ||
vocab_size = len(chars) | ||
print("all the unique characters:", ''.join(chars)) | ||
print("vocab size:", vocab_size) | ||
|
||
# create a mapping from characters to integers | ||
stoi = { ch:i for i,ch in enumerate(chars) } | ||
itos = { i:ch for i,ch in enumerate(chars) } | ||
def encode(s): | ||
return [stoi[c] for c in s] # encoder: take a string, output a list of integers | ||
def decode(l): | ||
''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string | ||
|
||
# create the train and test splits | ||
n = len(data) | ||
train_data = data[:int(n*0.9)] | ||
val_data = data[int(n*0.9):] | ||
|
||
# encode both to integers | ||
train_ids = encode(train_data) | ||
val_ids = encode(val_data) | ||
print(f"train has {len(train_ids)} tokens") | ||
print(f"val has {len(val_ids)} tokens") | ||
|
||
# export to bin files | ||
train_ids = np.array(train_ids, dtype=np.uint16) | ||
val_ids = np.array(val_ids, dtype=np.uint16) | ||
train_ids.tofile('train.bin') | ||
val_ids.tofile('val.bin') | ||
|
||
# save the meta information as well, to help us encode/decode later | ||
meta = { | ||
'vocab_size': vocab_size, | ||
'itos': itos, | ||
'stoi': stoi, | ||
} | ||
with open('meta.pkl', 'wb') as f: | ||
pickle.dump(meta, f) | ||
|
||
# length of dataset in characters: 1115394 | ||
# all the unique characters: | ||
# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz | ||
# vocab size: 65 | ||
# train has 1003854 tokens | ||
# val has 111540 tokens |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
|
||
# tiny shakespeare, character-level | ||
|
||
Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. | ||
|
||
After running `prepare.py`: | ||
|
||
- train.bin has 1,003,854 tokens | ||
- val.bin has 111,540 tokens |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters