Skip to content

Commit

Permalink
Add plotting code
Browse files Browse the repository at this point in the history
  • Loading branch information
hrishi508 committed Jun 29, 2022
1 parent 7bb3bf8 commit 4e06102
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 6 deletions.
21 changes: 17 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from tqdm import tqdm
import argparse
from datetime import datetime
import matplotlib.pyplot as plt

import torch
from torch import nn
Expand All @@ -17,7 +18,7 @@

def train(dataloader, model, loss_fn_arr, train_loss_arr, optimizer, scheduler, cfg):
# size = len(dataloader.dataset)
size = 100 # size of dataset
size = 20 # size of dataset
num_batches = len(dataloader)
batch_size = int(size/num_batches)

Expand Down Expand Up @@ -143,7 +144,7 @@ def train(dataloader, model, loss_fn_arr, train_loss_arr, optimizer, scheduler,

def test(dataloader, model, loss_fn_arr, test_loss_arr, cfg):
# size = len(dataloader.dataset)
size = 100 # size of dataset
size = 20 # size of dataset
num_batches = len(dataloader)
batch_size = int(size/num_batches)
test_loss = 0
Expand Down Expand Up @@ -275,7 +276,7 @@ def main(args):

# creating a random dataset (same shape as the facial dataset we will be using) for testing the code logic
dataloader = []
for i in range(10):
for i in range(2):
X_tmp = torch.randn((10, 3, 112, 112))
# y = torch.tensor([[0, 1, 2, 0], [0, 1, 2, 0], [0, 1, 2, 0]])
# assuming 4 classes each for gender, age, race and id
Expand All @@ -300,7 +301,19 @@ def main(args):
if cfg.save_model_weights_every > 0 and (t + 1)%cfg.save_model_weights_every == 0:
now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H:%M:%S_")
torch.save(model.state_dict(), cfg.model_weights_dir + dt_string + f"weights_epoch_{t+1}.pth")
torch.save(model.state_dict(), cfg.model_weights_dir + dt_string + f"debface_epoch_{t+1}_trial_" + cfg.trial_number + ".pth")

if cfg.plot_losses:
x = [i+1 for i in range(cfg.num_epoch)]
plt.plot(x, train_loss_arr, 'g', label='train')
plt.plot(x, test_loss_arr, 'r', label='test')
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend()

now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H:%M:%S_")
plt.savefig(cfg.plots_dir + dt_string + "debface_trial_" + cfg.trial_number + ".png")

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down
14 changes: 13 additions & 1 deletion train_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,19 @@ def main(args):
if cfg.save_model_weights_every > 0 and (t + 1)%cfg.save_model_weights_every == 0:
now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H:%M:%S_")
torch.save(model.decoder.state_dict(), cfg.model_weights_dir + dt_string + f"decoder_weights_epoch_{t+1}.pth")
torch.save(model.decoder.state_dict(), cfg.model_weights_dir + dt_string + f"decoder_epoch_{t+1}_trial_" + cfg.trial_number + ".pth")

if cfg.plot_losses:
x = [i+1 for i in range(cfg.num_epoch)]
plt.plot(x, train_loss_arr, 'g', label='train')
plt.plot(x, test_loss_arr, 'r', label='test')
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend()

now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H:%M:%S_")
plt.savefig(cfg.plots_dir + dt_string + "autoencoder_trial_" + cfg.trial_number + ".png")

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down
14 changes: 13 additions & 1 deletion train_without_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,19 @@ def main(args):
if cfg.save_model_weights_every > 0 and (t + 1)%cfg.save_model_weights_every == 0:
now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H:%M:%S_")
torch.save(model.state_dict(), cfg.model_weights_dir + dt_string + f"weights_epoch_{t+1}.pth")
torch.save(model.state_dict(), cfg.model_weights_dir + dt_string + f"debface_epoch_{t+1}_trial_" + cfg.trial_number + ".pth")

if cfg.plot_losses:
x = [i+1 for i in range(cfg.num_epoch)]
plt.plot(x, train_loss_arr, 'g', label='train')
plt.plot(x, test_loss_arr, 'r', label='test')
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend()

now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H:%M:%S_")
plt.savefig(cfg.plots_dir + dt_string + "debface_trial_" + cfg.trial_number + ".png")

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down
3 changes: 3 additions & 0 deletions utils/utils_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def __init__(self, cfg):
self.save_model_weights_every = int(cfg["save_model_weights_every"])
self.load_weights = (cfg["load_weights"] == "True")
self.load_weights_file = cfg["load_weights_file"]
self.plots_dir = cfg["plots_dir"]
self.plot_losses = (cfg["plot_losses"] == "True")
self.trial_number = cfg["trial_number"]

def get_config(config_file):
config_obj = configparser.ConfigParser(inline_comment_prefixes="#")
Expand Down

0 comments on commit 4e06102

Please sign in to comment.