who needs a dataloader? overlap the prefetching of the next batch with GPU compute, hiding the data loading latency entirely. this saves about 1ms lol
karpathy committed Feb 4, 2023
1 parent 46428d3 commit 3fd4c0c
Showing 2 changed files with 11 additions and 7 deletions.
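The change in both files is the same two-part pattern: (1) pin the host tensors and copy them with non_blocking=True so the host-to-device transfer is asynchronous, and (2) fetch the very first batch before the loop and re-fetch inside the loop right after the forward pass has been launched, so CPU-side batch preparation overlaps with GPU compute. Below is a minimal, self-contained sketch of that pattern; the toy embedding model and random tokens are stand-ins for nanoGPT's GPT and the memmap'd train.bin, not the code in the diff itself.

import numpy as np
import torch
import torch.nn.functional as F

device = 'cuda'
block_size, batch_size, vocab_size = 64, 8, 50304
# stand-in for the uint16 memmap'd token file that bench.py/train.py read
data = np.random.randint(0, vocab_size, size=100_000, dtype=np.uint16)

def get_batch():
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    # pin_memory() gives page-locked host buffers, so the .to(..., non_blocking=True)
    # copy is queued on the CUDA stream and returns immediately instead of blocking
    return (x.pin_memory().to(device, non_blocking=True),
            y.pin_memory().to(device, non_blocking=True))

# toy stand-in model; the real scripts use GPT from model.py
model = torch.nn.Sequential(torch.nn.Embedding(vocab_size, 128),
                            torch.nn.Linear(128, vocab_size)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

X, Y = get_batch()     # fetch the very first batch before the loop
for step in range(20):
    logits = model(X)  # forward pass: kernels are queued asynchronously on the GPU
    loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
    X, Y = get_batch() # prefetch the next batch while the GPU is still busy with the forward pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()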
bench.py — 10 changes: 6 additions & 4 deletions
@@ -39,7 +39,7 @@ def get_batch(split):
         ix = torch.randint(len(data) - block_size, (batch_size,))
         x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
         y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
-        x, y = x.to(device), y.to(device)
+        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
         return x, y
 else:
     # alternatively, if fixed data is desired to not care about data loading
@@ -76,14 +76,15 @@ def get_batch(split):
         record_shapes=False,
         profile_memory=False,
         with_stack=False, # incurs an additional overhead, disable if not needed
-        with_flops=False,
+        with_flops=True,
         with_modules=False, # only for torchscript models atm
     ) as prof:

+        X, Y = get_batch('train')
         for k in range(num_steps):
-            X, Y = get_batch('train')
             with ctx:
                 logits, loss = model(X, Y)
+            X, Y = get_batch('train')
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
             optimizer.step()
@@ -98,10 +98,11 @@ def get_batch(split):
     torch.cuda.synchronize()
     for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
         t0 = time.time()
+        X, Y = get_batch('train')
         for k in range(num_steps):
-            X, Y = get_batch('train')
             with ctx:
                 logits, loss = model(X, Y)
+            X, Y = get_batch('train')
             optimizer.zero_grad(set_to_none=True)
             loss.backward()
             optimizer.step()
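The "about 1ms" figure in the commit message comes from the simple-benchmarking branch of bench.py above: synchronize, time a fixed number of iterations, synchronize again. Condensed, and reusing the script's own names (model, ctx, optimizer, get_batch, num_steps), the measurement is roughly the sketch below; run it with and without the reordered get_batch call to compare.

import time
import torch

torch.cuda.synchronize()      # drain any pending GPU work before starting the clock
t0 = time.time()
X, Y = get_batch('train')     # first batch, fetched once before the loop
for k in range(num_steps):
    with ctx:
        logits, loss = model(X, Y)
    X, Y = get_batch('train') # overlapped prefetch; move this back above the forward pass to compare
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()      # wait for all queued kernels before reading the clock
dt = time.time() - t0
print(f"{dt / num_steps * 1000:.2f} ms per iteration")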
train.py — 8 changes: 5 additions & 3 deletions
@@ -111,7 +111,8 @@ def get_batch(split):
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
-    x, y = x.to(device), y.to(device)
+    # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
+    x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
     return x, y

 # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -227,6 +228,7 @@ def get_lr(it):
     wandb.init(project=wandb_project, name=wandb_run_name, config=config)

 # training loop
+X, Y = get_batch('train') # fetch the very first batch
 t0 = time.time()
 while True:

@@ -269,8 +271,6 @@ def get_lr(it):
     # forward backward update, with optional gradient accumulation to simulate larger batch size
     # and using the GradScaler if data type is float16
     for micro_step in range(gradient_accumulation_steps):
-        # fetch a batch
-        X, Y = get_batch('train')
         if ddp:
             # in DDP training we only need to sync gradients at the last micro step.
             # the official way to do this is with model.no_sync() context manager, but
@@ -279,6 +279,8 @@ def get_lr(it):
             model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
         with ctx:
             logits, loss = model(X, Y)
+        # immediately async prefetch next batch while model is doing the forward pass on the GPU
+        X, Y = get_batch('train')
         # backward pass, with gradient scaling if training in fp16
         scaler.scale(loss).backward()
         # clip the gradient
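One detail worth noting about the train.py change above: non_blocking=True only helps when the source CPU tensor lives in pinned (page-locked) memory. With an ordinary pageable tensor the copy is effectively synchronous, so the prefetch would not actually overlap with GPU compute. A tiny illustration of the difference (assumes a CUDA device; not part of the diff):

import torch

t = torch.randn(1024, 1024)                       # ordinary pageable host memory
a = t.to('cuda', non_blocking=True)               # pageable source: the copy cannot truly overlap
b = t.pin_memory().to('cuda', non_blocking=True)  # page-locked source: the copy is enqueued and returns immediately
print(t.is_pinned(), t.pin_memory().is_pinned())  # False True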
