adding a lightweight configurator that may be a terrible mistake lol. also adding configs to evaluate the baseline GPT2 versions released by OpenAI on OWT. we have some ways to go to match those numbers atm
karpathy committed Dec 28, 2022
1 parent c9fe00c commit 5d2b480
Showing 6 changed files with 96 additions and 2 deletions.
23 changes: 23 additions & 0 deletions README.md
@@ -34,3 +34,26 @@ Once some checkpoints are written to the output directory `out`, we're ready to
$ python sample.py
```

Training on 1 GPU overnight currently gets loss down to ~3.74. For reference, random chance at init (a uniform guess over the 50,257-token vocabulary) gives -ln(1/50257) ≈ 10.82. Which brings us to baselines.
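As a quick sanity check, that 10.82 figure is just the cross-entropy of the uniform guess, which a one-liner can confirm:

```
# cross-entropy of a uniform guess over the 50257-token vocab
import math
print(-math.log(1 / 50257))  # ~10.8249
```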

## baselines

The OpenAI GPT-2 checkpoints allow us to put some baselines in place for openwebtext. We can get these numbers as follows:

```
$ python train.py eval_gpt2
$ python train.py eval_gpt2_medium
$ python train.py eval_gpt2_large
$ python train.py eval_gpt2_xl
```

and observe the following losses on train and val:

| model | params | train loss | val loss |
| ------| ------ | ---------- | -------- |
| gpt2 | 124M | 3.11 | 3.12 |
| gpt2-medium | 350M | 2.85 | 2.84 |
| gpt2-large | 774M | 2.66 | 2.67 |
| gpt2-xl | 1558M | 2.56 | 2.54 |

I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic improvements, suggesting that OWT is not much different from WT (WebText) in terms of the data distribution, but this needs a more thorough attempt once the code is in a better place.
8 changes: 8 additions & 0 deletions config/eval_gpt2.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=12, n_head=12, n_embd=768
# 124M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2'
8 changes: 8 additions & 0 deletions config/eval_gpt2_large.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2-large
# n_layer=36, n_head=20, n_embd=1280
# 774M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-large'
8 changes: 8 additions & 0 deletions config/eval_gpt2_medium.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2-medium
# n_layer=24, n_head=16, n_embd=1024
# 350M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-medium'
8 changes: 8 additions & 0 deletions config/eval_gpt2_xl.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2-xl
# n_layer=48, n_head=25, n_embd=1600
# 1558M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get a good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-xl'
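For a quick cross-check of the parameter counts quoted in these config comments, here is a rough back-of-the-envelope estimate. This is a sketch, not model.py's exact accounting: it assumes GPT-2's standard vocab size of 50257 and block size of 1024 and ignores biases and layernorms, so it lands within a percent or two of the quoted figures:

```
# rough GPT-2 parameter count from n_layer/n_embd alone; assumes the standard
# vocab_size=50257 and block_size=1024, ignores biases and layernorm params
def approx_params(n_layer, n_embd, vocab_size=50257, block_size=1024):
    embeddings = (vocab_size + block_size) * n_embd  # token + position embeddings
    per_block = 12 * n_embd ** 2  # attention (4*d^2) + MLP (8*d^2)
    return embeddings + n_layer * per_block

for name, (n_layer, n_embd) in {
    'gpt2':        (12,  768),
    'gpt2-medium': (24, 1024),
    'gpt2-large':  (36, 1280),
    'gpt2-xl':     (48, 1600),
}.items():
    print(f"{name}: ~{approx_params(n_layer, n_embd) / 1e6:.0f}M")
```

(gpt2-medium comes out near 355M here; the 350M in its config comment is the commonly used rounded figure.)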
43 changes: 41 additions & 2 deletions train.py
@@ -4,20 +4,25 @@
"""

import os
import sys
import time
import math
from ast import literal_eval

import numpy as np
import torch
import wandb

from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
-# settings, todo argparse or something
+# default config values
# I/O
out_dir = 'out'
eval_interval = 500
log_interval = 1
eval_iters = 50
eval_only = False # if True, script exits right after the first eval
# wandb logging
wandb_log = False # disabled by default
wandb_entity = 'karpathy'
@@ -45,6 +50,38 @@
lr_decay_iters = 320000 # how many steps to decay the learning rate for
min_lr = 1e-5 # minimum learning rate
# -----------------------------------------------------------------------------
# poor man's Configurator. Potentially a bad idea. Example usage:
# python train.py override_file --batch_size=32
# this will first run config/override_file.py, then override batch_size to 32
for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = os.path.join('config', arg + '.py')
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool or number)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # literal_eval rejects bare strings (e.g. gpt2-xl), so fall back to the raw string
                attempt = val
            # ensure the types match ok
            assert type(attempt) == type(globals()[key])
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
# -----------------------------------------------------------------------------

os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337)
Expand Down Expand Up @@ -88,7 +125,7 @@ def get_batch(split):
model.to(device)

@torch.no_grad()
-def estimate_loss(eval_iters=50):
+def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
@@ -166,6 +203,8 @@ def get_lr(iter):
            'iter_num': iter_num,
        }
        torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    X, Y = get_batch('train')
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
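Since the new configurator leans entirely on `literal_eval` for type coercion, here is a minimal standalone sketch of the override logic above (the defaults dict and the three example flags are made up for illustration; the real code writes into `globals()` instead of a dict, which is exactly the "potentially a bad idea" part):

```
# standalone sketch of train.py's --key=value override logic
from ast import literal_eval

defaults = {'batch_size': 8, 'eval_only': False, 'init_from': 'gpt2'}

for arg in ['--batch_size=32', '--eval_only=True', '--init_from=gpt2-xl']:
    key, val = arg[2:].split('=')
    try:
        attempt = literal_eval(val)   # '32' -> 32 (int), 'True' -> True (bool)
    except (SyntaxError, ValueError):
        attempt = val                 # bare strings like 'gpt2-xl' stay strings
    assert type(attempt) == type(defaults[key])  # same type guard as train.py
    defaults[key] = attempt

print(defaults)  # {'batch_size': 32, 'eval_only': True, 'init_from': 'gpt2-xl'}
```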
