Skip to content

Commit

Permalink
display training
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxThFe committed Sep 1, 2022
1 parent 919fa25 commit 8b5a191
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,15 @@ def main(opts):
t2 = time.time()

# Training

for epoch in range(opts['nepochs']):
model.train()
batch_count = 0
batch_loss = 0
batch_loss_rec = 0
batch_loss_kl = 0
optimizer.zero_grad()

with tqdm(total=len(data_loader), position=0, leave=False) as pbar:
with tqdm(total=len(data_loader)*opts['nepochs'], position=0, leave=False) as pbar:
for epoch in range(opts['nepochs']):
model.train()
batch_count = 0
batch_loss = 0
batch_loss_rec = 0
batch_loss_kl = 0
optimizer.zero_grad()

for data in tqdm(data_loader, position=0, leave=False):
pbar.update()

Expand Down Expand Up @@ -135,7 +134,7 @@ def main(opts):
parser.add_argument('--clip-bound', type=float, default=0.25,
help='constraint of gradient norm for gradient clipping')
parser.add_argument('--reg', type=float, default=1, help='regularization for KL loss')
parser.add_argument('--nepochs', type=int, default=1, help='number of epochs for training')
parser.add_argument('--nepochs', type=int, default=3, help='number of epochs for training')
args = parser.parse_args()

from decoder.utils import setup
Expand Down

0 comments on commit 8b5a191

Please sign in to comment.