Skip to content

Commit

Permalink
COMMIT
Browse files Browse the repository at this point in the history
  • Loading branch information
MickyasTA committed Apr 25, 2024
1 parent f0aa743 commit bee9fb0
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions VAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda, Compose
from torch import nn
from torch.utils.tensorboard import SummaryWriter # Tensorboard support


# Network parameters
batch_size=64 # Higher bacth size
epoch=3
epoch=1000
learning_rate=1e-3


Expand Down Expand Up @@ -71,13 +72,22 @@ def forward(self, x):
# Mean square error is taken as loss function here as the pixel intensity is continuous
loss_func = nn.MSELoss() # One possible loss criterion


def get_num_correct(preds, labels):
return preds.argmax(dim=1).eq(labels).sum().item()
#-------------tensorboard --logdir=runs----------------
tb = SummaryWriter()
network = VAE()
images,lables=next(iter(training_dataloader))
grid = torchvision.utils.make_grid(images)
tb.add_image('images',grid)
tb.add_graph(network,images)
#-------------tensorboard --logdir=runs----------------
if model_path and os.path.exists(model_path):
model = VAE()
model.load_state_dict(torch.load(model_path))
print("Loaded PyTorch Model State from model.pth")
model.eval()

# run the inference on the test data and display the results
def unnormalize(img):
img = (img * 0.5) + 0.5 # Reverse normalization
Expand Down Expand Up @@ -110,15 +120,18 @@ def unnormalize(img):
else:
print("No model.pth found, training a new model...")
print("Training the model")




def traning_loop(training_dataloader,optimizer,loss_func,model):
size=len(training_dataloader.dataset)
model.train()

for index ,(actual_data,ground_truth) in enumerate(training_dataloader):
actual_data,ground_truth=actual_data.to(device),ground_truth.to(device)
#print(actual_data.shape, ground_truth.shape)

total_loss=0
total_correct=0
# Forward Propaget
reconstructed_data,_=model(actual_data)
#print(reconstructed_data.shape, actual_data.shape)
Expand All @@ -131,10 +144,25 @@ def traning_loop(training_dataloader,optimizer,loss_func,model):
optimizer.step()
optimizer.zero_grad()
if index %100==0:
loss,current=loss.item(),index*len(actual_data)+len(actual_data)
print(f"loss:{loss} current:{current}")


#loss,current=loss+loss.item(),index*len(actual_data)+len(actual_data)
#total_loss,total_correct=total_loss+loss.item(),total_correct+get_num_correct(reconstructed_data,actual_data)
total_loss,total_correct=total_loss+loss.item(),index*len(actual_data)+len(actual_data)

tb.add_scalar("Loss",total_loss,epoch)
tb.add_scalar("Number Correct",total_correct,epoch)
tb.add_scalar("Accuracy",total_correct/size,epoch)

tb.add_histogram("conv1.bias",model.encoder_network[0].bias,epoch)
tb.add_histogram("conv1.weight",model.encoder_network[0].weight,epoch)
tb.add_histogram("conv2.bias",model.encoder_network[3].bias,epoch)
tb.add_histogram("conv2.weight",model.encoder_network[3].weight,epoch)
tb.add_histogram("deconv1.bias",model.decoder_network[0].bias,epoch)
tb.add_histogram("deconv1.weight",model.decoder_network[0].weight,epoch)
tb.add_histogram("deconv2.bias",model.decoder_network[2].bias,epoch)
tb.add_histogram("deconv2.weight",model.decoder_network[2].weight,epoch)
tb.close()
#print(f"loss:{total_loss} total_correct:{total_correct}")
print(f"loss:{total_loss} total_correct:{total_correct}")
for t in range(epoch):
print(f" Epoch {t+1}\n--------------------------------------------")
traning_loop(training_dataloader=training_dataloader,
Expand Down Expand Up @@ -182,3 +210,5 @@ def unnormalize(img):
axes[1, i].axis('off')
plt.tight_layout()
plt.show()
# Run tensorboard to visualize the training process
# tensorboard --logdir=runs
Binary file removed finding-lanes/conv_autoencoder.pth
Binary file not shown.
Binary file not shown.

0 comments on commit bee9fb0

Please sign in to comment.