
Commit

add support for DDP training. the scaling timings right now do not look good by default, have to dig more into
karpathy committed Dec 29, 2022
1 parent ee6459f commit dea1507
Showing 2 changed files with 68 additions and 39 deletions.
26 changes: 16 additions & 10 deletions README.md
@@ -1,34 +1,40 @@

# nanoGPT

The cleanest, fastest repository for training/finetuning medium-sized GPTs. Still under active development, currently trying to reproduce GPT-2 on OpenWebText dataset. The code itself is tiny, plain and readable. At the moment `train.py` is a ~200-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can also load the GPT-2 weights from OpenAI.
The cleanest, smallest, fastest repository for training/finetuning medium-sized GPTs. Still under active development, currently working to reproduce GPT-2 on the OpenWebText dataset. The code itself aims by design to be plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it.

## install

We need a few dependencies:
Dependencies:

- [pytorch](https://pytorch.org), of course
- numpy
- `pip install datasets` for huggingface datasets
- `pip install tiktoken` for OpenAI's fast bpe code
- `pip install wandb` for optional logging
- [pytorch](https://pytorch.org) <3
- numpy <3
- `pip install datasets` for huggingface datasets <3
- `pip install tiktoken` for OpenAI's fast bpe code <3
- `pip install wandb` for optional logging <3

## usage

To render a dataset we first tokenize some documents into one giant array of indices. E.g. for OpenWebText see:
To render a dataset we first tokenize some documents into one simple long 1D array of indices. E.g. for OpenWebText see:

```
$ cd data/openwebtext
$ python prepare.py
```

To download and tokenize the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. The training script currently by default tries to reproduce the smallest GPT-2 released by OpenAI, i.e. the 124M version of GPT-2. We can train as follows, though I encourage you to read the code and see all of the settings and paths up top in the file:
To download and tokenize the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. This will create a `train.bin` and `val.bin` which hold the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. The training script currently by default tries to reproduce the smallest GPT-2 released by OpenAI, i.e. the 124M version of GPT-2. We can demo train as follows on a single device, though I encourage you to read the code and see all of the settings and paths up top in the file:

```
$ python train.py
```
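
As a quick sanity check on the data format: `train.bin` and `val.bin` are nothing more than flat arrays of uint16 token ids, so a plain numpy memmap is enough to poke at them. A minimal sketch (it mirrors the memmap-based loading inside `train.py`, and the decode step assumes the GPT-2 `tiktoken` encoding that the OpenWebText `prepare.py` uses):

```
import numpy as np
import tiktoken

# memory-map the tokenized training set; nothing is read until it is indexed
train_data = np.memmap('data/openwebtext/train.bin', dtype=np.uint16, mode='r')
print(f"train.bin holds {len(train_data):,} tokens")

# decode the first few tokens back into text as a sanity check
enc = tiktoken.get_encoding("gpt2")
print(enc.decode(train_data[:32].tolist()))
```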

Once some checkpoints are written to the output directory `out`, we can sample from the model:
To train using PyTorch Distributed Data Parallel (DDP), run the script with torchrun. For example, to train on a node with 4 GPUs run:

```
$ torchrun --standalone --nproc_per_node=4 train.py
```
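
Under the hood, torchrun launches one process per GPU and hands each process its identity through environment variables; `train.py` keys off `LOCAL_RANK` to decide whether it is running under DDP and which GPU to drive. Roughly, a sketch of that pattern (the actual setup lives near the top of `train.py`):

```
import os
from torch.distributed import init_process_group

# torchrun sets LOCAL_RANK for every process it spawns; plain `python train.py` does not
ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1
if ddp:
    init_process_group(backend='nccl')  # rendezvous details also come from torchrun's env vars
    gpu_id = int(os.environ['LOCAL_RANK'])
    device = f"cuda:{gpu_id}"  # each process drives its own GPU
```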

Once some checkpoints are written to the output directory (e.g. `./out` by default), we can sample from the model:

```
$ python sample.py
```
81 changes: 52 additions & 29 deletions train.py
@@ -1,6 +1,12 @@
"""
Train a GPT model on a dataset of text. One GPU version.
The text is assumed to be pre-tokenized and inside files train.pt and val.pt
This training script can be run both on a single gpu in debug mode,
and also in a larger training run with distributed data parallel (ddp).
To run in debug mode, for example:
$ python train.py --batch_size=32 --other=args
To run DDP on 4 gpus on one node, for example:
$ torchrun --standalone --nproc_per_node=4 train.py
"""

import os
@@ -9,9 +15,11 @@
import math
from ast import literal_eval

import wandb
import numpy as np
import torch
import wandb
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from model import GPTConfig, GPT

@@ -49,9 +57,11 @@
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 320000 # how many steps to decay the learning rate for
min_lr = 1e-5 # minimum learning rate
# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# -----------------------------------------------------------------------------
# poor man's Configurator. Potentially a bad idea. Example usage:
# python train.py override_file --batch_size=32
# $ python train.py override_file --batch_size=32
# this will first run config/override_file.py, then override batch_size to 32
for arg in sys.argv[1:]:
if '=' not in arg:
@@ -71,7 +81,7 @@
try:
# attempt to eval it (e.g. if it's a bool, number, etc.)
attempt = literal_eval(val)
except SyntaxError:
except (SyntaxError, ValueError):
# if that goes wrong, just use the string
attempt = val
# ensure the types match ok
@@ -82,13 +92,21 @@
else:
raise ValueError(f"Unknown config key: {key}")
# -----------------------------------------------------------------------------

os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337)
ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1 # is this a ddp run?
if ddp:
init_process_group(backend=backend)
gpu_id = int(os.environ["LOCAL_RANK"])
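# torchrun gives each process a unique LOCAL_RANK; use it to pick this process's GPU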
device = f"cuda:{gpu_id}"
else:
gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically

if gpu_id == 0:
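# only the master process creates the output dir; it also handles eval logging and checkpointing below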
os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

# poor man's data loader, TODO use real DataLoader...
# poor man's data loader, TODO evaluate need for actual DataLoader
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
@@ -101,16 +119,16 @@ def get_batch(split):
return x, y

# model init
# TODO I don't love this whole part/API yet
model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout)
if init_from == 'scratch':
# init a new model from scratch
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
elif init_from == 'resume':
# resume training from a checkpoint. TODO: do we resume iter_num etc too? (yes...)
# resume training from a checkpoint.
# TODO: should we also resume iter_num and best_val_loss?
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint_model_args = checkpoint['model_args']
for k, v in model_args.items():
assert checkpoint_model_args[k] == v, "for now"
@@ -120,10 +138,20 @@ def get_batch(split):
elif init_from.startswith('gpt2'):
# initialize from OpenAI GPT-2 weights
model = GPT.from_pretrained(init_from)
if block_size < model.block_size:
model.crop_block_size(block_size)
# crop down the model block size if desired
if block_size < model.block_size:
model.crop_block_size(block_size)
model.to(device)

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, betas)
if init_from == 'resume':
optimizer.load_state_dict(checkpoint['optimizer'])

# wrap model into DDP container
if ddp:
model = DDP(model, device_ids=[gpu_id])
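# note: DDP broadcasts the rank 0 parameters at construction and all-reduces gradients during backward(), keeping the replicas in sync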

@torch.no_grad()
def estimate_loss():
out = {}
@@ -139,11 +167,6 @@ def estimate_loss():
model.train()
return out

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, betas)
if init_from == 'resume':
optimizer.load_state_dict(checkpoint['optimizer'])

# learning rate decay scheduler (cosine with warmup)
def get_lr(iter):
# 1) linear warmup for warmup_iters steps
@@ -155,11 +178,11 @@ def get_lr(iter):
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (iter - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # ranges 0..1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)
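# sanity check on the schedule: at iter == warmup_iters this returns exactly learning_rate, and at lr_decay_iters it has decayed down to min_lr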

# logging
if wandb_log:
if wandb_log and gpu_id == 0:
wandb.init(project=wandb_project, entity=wandb_entity, name=wandb_run_name)
wandb.config = {
"batch_size": batch_size,
@@ -169,7 +192,6 @@ def get_lr(iter):

# training loop
iter_num = 0
num_tokens = 0
best_val_loss = 1e9
t0 = time.time()
while True:
@@ -182,25 +204,26 @@ def get_lr(iter):
else:
lr = learning_rate

if iter_num % eval_interval == 0:
if iter_num % eval_interval == 0 and gpu_id == 0:
losses = estimate_loss()
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
if wandb_log:
wandb.log({
"iter": iter_num,
"num_tokens": num_tokens,
"train/loss": losses['train'],
"val/loss": losses['val'],
"lr": lr,
})
if losses['val'] < best_val_loss:
best_val_loss = losses['val']
if iter_num > 0: # don't save checkpoints on very first iteration...
raw_model = model.module if ddp else model
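# DDP wraps the model, so the underlying GPT lives at .module; save that so the checkpoint loads without a DDP wrapper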
if iter_num > 0:
checkpoint = {
'model': model.state_dict(),
'model': raw_model.state_dict(),
'optimizer': optimizer.state_dict(),
'model_args': model_args,
'iter_num': iter_num,
'best_val_loss': best_val_loss,
}
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
if iter_num == 0 and eval_only:
@@ -212,19 +235,19 @@ def get_lr(iter):

optimizer.zero_grad(set_to_none=True)
loss.backward()
# TODO: gradient clipping
# TODO: gradient clipping; evaluate need for it
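# (one option, if/when it is needed: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0))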
optimizer.step()

t1 = time.time()
dt = t1 - t0
t0 = t1
if iter_num % log_interval == 0:
if iter_num % log_interval == 0 and gpu_id == 0:
lossf = loss.item() # loss as float. TODO CPU-GPU sync: profile, make sure not slow af
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
iter_num += 1
num_tokens += X.numel()

# termination conditions
if iter_num >= max_iters:
break

destroy_process_group()
