
Commit: display
MaxThFe committed Sep 1, 2022
1 parent aad119d commit 919fa25
Showing 3 changed files with 37 additions and 32 deletions.
Binary file modified data/__pycache__/cycles.cpython-310.pyc
7 changes: 4 additions & 3 deletions data/cycles.py
@@ -5,6 +5,7 @@
 import random
 from torch.utils.data import Dataset
 from decoder.utils import convert_cycle
+from tqdm import tqdm

 def get_previous(i, v_max):
     if i == 0:
@@ -176,7 +177,7 @@ def rollout_and_examine(self, model, num_samples):
print("Cycle ratio", self.cycle_ratio)
print('Valdi ratio', self.valid_ratio)
print("-"*25)

def write_summary(self):

def _format_value(v):
@@ -215,11 +216,11 @@ def __init__(self, num_epochs, num_batches):
         self.num_batches = num_batches
         self.batch_count = 0

-    def update(self, epoch, metrics):
+    def update(self, epoch, metrics, pbar):
         self.batch_count = (self.batch_count) % self.num_batches + 1

         msg = 'epoch {:d}/{:d}, batch {:d}/{:d}'.format(epoch, self.num_epochs,
                                                         self.batch_count, self.num_batches)
         for key, value in metrics.items():
             msg += ', {}: {:4f}'.format(key, value)
-        print(msg)
+        pbar.write(msg)
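
The cycles.py change threads a tqdm handle into the printer's update() so that log lines go through pbar.write, which prints above a live progress bar instead of interleaving with it and mangling the display. Below is a minimal runnable sketch of that pattern; the Printer class and demo loop here are simplified stand-ins, not the repo's actual API:

from tqdm.auto import tqdm
import time

class Printer:
    # Simplified stand-in: formats metrics and writes them via the bar.
    def __init__(self, num_epochs, num_batches):
        self.num_epochs = num_epochs
        self.num_batches = num_batches
        self.batch_count = 0

    def update(self, epoch, metrics, pbar):
        self.batch_count = self.batch_count % self.num_batches + 1
        msg = 'epoch {:d}/{:d}, batch {:d}/{:d}'.format(
            epoch, self.num_epochs, self.batch_count, self.num_batches)
        for key, value in metrics.items():
            msg += ', {}: {:.4f}'.format(key, value)
        pbar.write(msg)  # prints above the bar; plain print() breaks the bar line

printer = Printer(num_epochs=1, num_batches=5)
with tqdm(total=5) as pbar:
    for i in range(5):
        time.sleep(0.1)  # stand-in for real work
        pbar.update()
        printer.update(1, {'loss': 0.5 / (i + 1)}, pbar)
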
62 changes: 33 additions & 29 deletions main.py
@@ -2,7 +2,7 @@
 import os
 import datetime
 import time
-from tqdm import tqdm
+from tqdm.auto import tqdm

 import torch
 from torch.optim import Adam
@@ -52,34 +52,38 @@ def main(opts):
         batch_loss_kl = 0
         optimizer.zero_grad()

-        for i, data in enumerate(tqdm(data_loader, position=0, leave=True)):
-            x, edge_index = convert_cycle(data)
-            loss, loss_rec, loss_kl = model(x, edge_index, actions=data) # train on data
-            loss.backward() # backpropagate
-
-            batch_loss += loss.item()
-            batch_loss_rec += -loss_rec.item()
-            batch_loss_kl += loss_kl.item()
-            #batch_prob += prob_averaged.item()
-            batch_count += 1
-
-            if batch_count % opts['batch_size'] == 0:
-                print('\n')
-                printer.update(epoch + 1, {'averaged_loss': batch_loss/opts['batch_size'], \
-                               'reconstruction_loss': batch_loss_rec/opts['batch_size'], \
-                               'kl_loss': batch_loss_kl/opts['batch_size'],})
-
-
-                if opts['clip_grad']:
-                    clip_grad_norm_(model.parameters(), opts['clip_bound'])
-
-                optimizer.step()
-
-                batch_loss = 0
-                #
-                optimizer.zero_grad()
-        model.eval()
-        evaluator.rollout_and_examine(model, opts['num_generated_samples'])
+        with tqdm(total=len(data_loader), position=0, leave=False) as pbar:
+            for data in data_loader:
+                pbar.update()
+
+                #for i, data in enumerate(tqdm(data_loader, position=0, leave=True)):
+                x, edge_index = convert_cycle(data)
+                loss, loss_rec, loss_kl = model(x, edge_index, actions=data) # train on data
+                loss.backward() # backpropagate
+
+                batch_loss += loss.item()
+                batch_loss_rec += -loss_rec.item()
+                batch_loss_kl += loss_kl.item()
+                #batch_prob += prob_averaged.item()
+                batch_count += 1
+
+                if batch_count % opts['batch_size'] == 0:
+                    #print('\n')
+                    printer.update(epoch + 1, {'averaged_loss': batch_loss/opts['batch_size'], \
+                                   'reconstruction_loss': batch_loss_rec/opts['batch_size'], \
+                                   'kl_loss': batch_loss_kl/opts['batch_size'],}, pbar)
+
+
+                    if opts['clip_grad']:
+                        clip_grad_norm_(model.parameters(), opts['clip_bound'])
+
+                    optimizer.step()
+
+                    batch_loss = 0
+                    #
+                    optimizer.zero_grad()
+        model.eval()
+        evaluator.rollout_and_examine(model, opts['num_generated_samples'])

         t3 = time.time()

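In main.py the import switches to tqdm.auto, which selects the notebook widget when running under Jupyter and falls back to the plain terminal bar otherwise, and the training loop now owns a single bar whose handle is passed down to the printer. Below is a minimal sketch of the resulting loop shape; the linear model, random data, and log_every interval are dummy stand-ins for the repo's model, data_loader, and printer, not its actual objects:

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm  # notebook widget under Jupyter, text bar in a terminal

model = nn.Linear(4, 1)  # dummy model
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
data_loader = DataLoader(dataset, batch_size=8)
optimizer = Adam(model.parameters())
loss_fn = nn.MSELoss()

log_every = 2
running_loss = 0.0
with tqdm(total=len(data_loader), position=0, leave=False) as pbar:
    for step, (x, y) in enumerate(data_loader, start=1):
        pbar.update()
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if step % log_every == 0:
            # route logs through the bar so they stack above it
            pbar.write('step {:d}: loss {:.4f}'.format(step, running_loss / log_every))
            running_loss = 0.0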
