Commit

shuttling the poor mans configurator aside into its own file and adding it to all of train,sample,bench. because i am leaving args in globals() so i can avoid having to prepend every single variable with an args., i have to exec the configurator and the optional configs. so we're left with something very gross by standard convention but also quite simple and functional. *ducks*
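
The point about `exec` versus a regular import can be sketched roughly as follows: code run with `exec` reads and writes the namespace it is given, so plain assignments replace the defaults already defined there. This is a minimal illustration only; `namespace` and `snippet` are hypothetical names standing in for a script's `globals()` and for the text of `configurator.py` or a config file.

```python
# Illustrative sketch only; `namespace` stands in for a script's globals().
namespace = {"batch_size": 8, "block_size": 1024}

# Stand-in for the text of configurator.py or an optional config file.
snippet = "batch_size = 32"

# exec runs the snippet inside `namespace`, so a plain assignment
# overwrites the default already stored under the same name.
exec(snippet, namespace)

print(namespace["batch_size"])  # value set by the snippet
print(namespace["block_size"])  # default left untouched
```

An imported module would instead get its own namespace, which is why the scripts would otherwise need an `args.` or `config.` prefix on every variable.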
karpathy committed Jan 5, 2023
1 parent ab04701 commit d562b3e
Showing 5 changed files with 59 additions and 41 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -71,7 +71,7 @@ I briefly tried finetuning gpt2 a bit more on our OWT and didn't notice dramatic
For an example of how to finetune a GPT on new text go to `data/shakespeare` and look at `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`. Unlike OpenWebText this will run in seconds. Finetuning takes very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like:

```
$ python train.py finetune_shakespeare
$ python train.py config/finetune_shakespeare.py
```

This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py` to generate infinite Shakespeare. Note that you'll have to edit it to point to the correct `out_dir`.
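
As a rough idea of what such an override file might contain, here is a hypothetical sketch. Only `out_dir` and `init_from` are named in the text above; the other variable names and every value are assumptions for illustration, not the repository's actual settings.

```python
# Hypothetical override file in the spirit of config/finetune_shakespeare.py.
# Every value here is an illustrative assumption, not the real config.
out_dir = 'out-shakespeare'  # checkpoint directory mentioned in the README
init_from = 'gpt2'           # assumed value: start from a pretrained GPT-2 checkpoint
learning_rate = 1e-5         # assumed name and value: small learning rate for finetuning
max_iters = 1000             # assumed name and value: far shorter than a pretraining run
```

Because the file is executed rather than imported, each assignment simply replaces the default of the same name in `train.py`.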
@@ -102,7 +102,6 @@ Features / APIs
- Add back fp16 support? (would need to also add back gradient scaler)
- Add CPU support
- Finetune the finetuning script, I think the hyperparams are not great
- Replace poor man's configurator, and make sample.py configurable...
- Report and track other metrics e.g. perplexity, num_tokens, MFU, ...
- Eval zero-shot perplexities on PTB, WikiText, other related benchmarks

13 changes: 8 additions & 5 deletions bench.py
@@ -7,15 +7,18 @@
import torch
from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
device = 'cuda'
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.manual_seed(1337)

batch_size = 8
block_size = 1024
dtype = torch.bfloat16
compile = True
exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

dtype = torch.bfloat16 # todo make configurable
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.manual_seed(1337)

# data loading init
real_data = True
47 changes: 47 additions & 0 deletions configurator.py
@@ -0,0 +1,47 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32
The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())
So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()
I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import sys
from ast import literal_eval

for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = arg
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it it (e.g. if bool, number, or etc)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            assert type(attempt) == type(globals()[key])
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
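
For experimenting with the `--key=value` handling outside of a training script, the same logic could be sketched as a function over a plain dictionary. This is not part of the commit: `apply_overrides` and `defaults` are hypothetical names, and the sketch deliberately differs from the code above by splitting on the first `=` only and raising exceptions instead of using `assert`.

```python
from ast import literal_eval

def apply_overrides(config, args):
    """Return a copy of `config` with --key=value overrides applied.

    Hypothetical helper mirroring the override rules of configurator.py,
    but working on a dict instead of globals().
    """
    result = dict(config)
    for arg in args:
        if not arg.startswith('--') or '=' not in arg:
            raise ValueError(f"Expected --key=value, got: {arg}")
        # partition splits on the first '=' only, so values may contain '='
        key, _, val = arg[2:].partition('=')
        if key not in result:
            raise ValueError(f"Unknown config key: {key}")
        try:
            # parse bools, numbers, and other Python literals
            attempt = literal_eval(val)
        except (SyntaxError, ValueError):
            # anything that is not a literal is kept as a string
            attempt = val
        # the override must have the same type as the default it replaces
        if type(attempt) != type(result[key]):
            raise TypeError(f"Type mismatch for {key}: {type(attempt).__name__}")
        result[key] = attempt
    return result

defaults = {"batch_size": 8, "compile": True, "out_dir": "out"}
overridden = apply_overrides(defaults, ["--batch_size=32", "--out_dir=out-shakespeare"])
print(overridden)
```

Returning a copy keeps the defaults intact, which makes the behavior easy to check in isolation, at the cost of the prefix-free variable access that motivated the `globals()` approach.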
2 changes: 1 addition & 1 deletion sample.py
@@ -7,7 +7,6 @@
from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
# todo make these overridable like in train.py
out_dir = 'out'
device = 'cuda:2'
compile = False
@@ -17,6 +16,7 @@
temperature = 0.8 # higher temperature (up to 1) is more random, lower (down to 0) means more greedy
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
35 changes: 2 additions & 33 deletions train.py
@@ -13,7 +13,6 @@
import sys
import time
import math
from ast import literal_eval

import wandb
import numpy as np
@@ -24,7 +23,7 @@
from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
# default config values
# default config values designed to train a gpt2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 2000
@@ -62,37 +61,7 @@
backend = 'nccl' # 'nccl', 'gloo', etc.
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
# poor man's Configurator. Potentially a bad idea. Example usage:
# $ python train.py override_file --batch_size=32
# this will first run config/override_file.py, then override batch_size to 32
for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = os.path.join('config', arg + '.py')
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it it (e.g. if bool, number, or etc)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            assert type(attempt) == type(globals()[key])
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------
ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1 # is this a ddp run?
if ddp: