
Commit

Merge pull request karpathy#106 from YassineYousfi/master
use the ``enabled`` arg in GradScaler
karpathy authored Feb 3, 2023
2 parents 1e87509 + 40f4d6f commit 7d44bdf
Showing 1 changed file with 7 additions and 13 deletions.
train.py: 20 changes (7 additions, 13 deletions)
@@ -173,11 +173,8 @@ def get_batch(split):
 model.crop_block_size(block_size)
 model.to(device)

-# initialize a GradScaler if data type is float16
-scaler = None
-if dtype == 'float16':
-    print(f"Initializing Gradient Scaler to account for dtype: {dtype}")
-    scaler = torch.cuda.amp.GradScaler()
+# initialize a GradScaler. If enabled=False scaler is a no-op
+scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

 # optimizer
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))

@@ -283,17 +280,14 @@ def get_lr(it):
 with ctx:
     logits, loss = model(X, Y)
 # backward pass, with gradient scaling if training in fp16
-scaler.scale(loss).backward() if scaler else loss.backward()
+scaler.scale(loss).backward()
 # clip the gradient
 if grad_clip != 0.0:
-    scaler.unscale_(optimizer) if scaler else None
+    scaler.unscale_(optimizer)
     torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-# step the optimizer
-if scaler:
-    scaler.step(optimizer)
-    scaler.update()
-else:
-    optimizer.step()
+# step the optimizer and scaler if training in fp16
+scaler.step(optimizer)
+scaler.update()
 # flush the gradients as soon as we can, no need for this memory anymore
 optimizer.zero_grad(set_to_none=True)

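The point of the change is that ``torch.cuda.amp.GradScaler(enabled=False)`` turns ``scale``, ``unscale_``, ``step`` and ``update`` into pass-throughs (``scale`` returns the loss unchanged and ``step`` simply calls ``optimizer.step()``), so the training loop can call the scaler unconditionally instead of branching on ``if scaler:``. Below is a minimal standalone sketch of the same pattern; the model, optimizer and data are illustrative placeholders (not code from train.py), and a CUDA device is assumed. The ``ptdtype``/``ctx`` setup mirrors how train.py builds its autocast context.

import torch

# illustrative stand-ins for the real model, optimizer and batches
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dtype = 'float16'   # 'bfloat16' also works; only fp16 needs the scaler
grad_clip = 1.0

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)

# enabled=False makes every scaler call below a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

for _ in range(10):
    X = torch.randn(8, 16, device='cuda')
    Y = torch.randn(8, 1, device='cuda')
    with ctx:
        loss = torch.nn.functional.mse_loss(model(X), Y)
    scaler.scale(loss).backward()       # plain loss.backward() when disabled
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)      # no-op when disabled
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)              # just optimizer.step() when disabled
    scaler.update()                     # no-op when disabled
    optimizer.zero_grad(set_to_none=True)

With ``dtype`` set to 'bfloat16' or 'float32' the scaler stays disabled, which is fine because only float16's narrow exponent range needs loss scaling; the loop itself does not change.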
