Skip to content

Commit

Permalink
watch model directly
Browse files Browse the repository at this point in the history
  • Loading branch information
TITC committed May 17, 2022
1 parent 0aefdbf commit 5cbbcb9
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions pix2tex/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,10 @@ def get_model(args, training=False):
decoder = nn.DataParallel(decoder)
encoder.to(args.device)
decoder.to(args.device)
model = Model(encoder, decoder, args)
if args.wandb:
import wandb
en_attn_layers = encoder
if args.encoder_structure.lower() == 'vit':
en_attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
wandb.watch((en_attn_layers, de_attn_layers))
model = Model(encoder, decoder, args)
wandb.watch(model)
if training:
# check if largest batch can be handled by system
batchsize = args.batchsize if args.get(
Expand Down

0 comments on commit 5cbbcb9

Please sign in to comment.