Skip to content

Commit

Permalink
better memory check
Browse files Browse the repository at this point in the history
also more consistent validation frequency
  • Loading branch information
lukas-blecher committed May 16, 2022
1 parent 720978d commit 24342ec
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions pix2tex/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,14 @@ def embed_layer(**x):
model = Model(encoder, decoder, args)
if training:
# check if largest batch can be handled by system
batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
decoder(seq, context=encoder(im)).sum().backward()
try:
batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
for _ in range(5):
im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
decoder(seq, context=encoder(im)).sum().backward()
except RuntimeError:
raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize."%(batchsize, args.max_height, args.max_width))
model.zero_grad()
torch.cuda.empty_cache()
del im, seq
Expand Down
2 changes: 1 addition & 1 deletion pix2tex/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def save_models(e):
dset.set_description('Loss: %.4f' % total_loss)
if args.wandb:
wandb.log({'train/loss': total_loss})
if (i+1) % args.sample_freq == 0:
if (i+1+len(dataloader)*e) % args.sample_freq == 0:
evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if (e+1) % args.save_freq == 0:
save_models(e)
Expand Down

0 comments on commit 24342ec

Please sign in to comment.