Skip to content

Commit

Permalink
Updated experiment setup
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Jan 16, 2020
1 parent 20c4dfa commit 79520be
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 258 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ logs/
VanillaVAE/version_0/

__pycache__/
.ipynb_checkpoints/
241 changes: 0 additions & 241 deletions .ipynb_checkpoints/Run-checkpoint.ipynb

This file was deleted.

2 changes: 2 additions & 0 deletions LICENSE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Copyright Anand Krishnamoorthy Subramanian
[email protected]

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

Expand Down
23 changes: 12 additions & 11 deletions experiment.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import math
import torch
import pytorch_lightning as pl
from torch import optim
from models import BaseVAE
from models.types_ import *
import pytorch_lightning as pl
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.datasets import CelebA
from torch import optim
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from models.types_ import *
import math



class VAEXperiment(pl.LightningModule):
Expand All @@ -21,14 +22,14 @@ def __init__(self,
self.params = params
self.curr_device = None

def forward(self, input: Tensor):
return self.model(input)
def forward(self, input: Tensor, **kwargs) -> Tensor:
    """Delegate the forward pass to the wrapped VAE model.

    Any extra keyword arguments (e.g. ``labels`` for conditional
    models) are forwarded untouched to the underlying model.
    """
    # Thin pass-through: the experiment wrapper adds no computation here.
    model_output = self.model(input, **kwargs)
    return model_output

def training_step(self, batch, batch_idx):
real_img, _ = batch
real_img, labels = batch
self.curr_device = real_img.device

results = self.forward(real_img)
results = self.forward(real_img, labels = labels)

train_loss = self.model.loss_function(*results,
M_N = self.params.batch_size/ self.num_train_imgs )
Expand All @@ -38,8 +39,8 @@ def training_step(self, batch, batch_idx):
return train_loss

def validation_step(self, batch, batch_idx):
real_img, _ = batch
results = self.forward(real_img)
real_img, labels = batch
results = self.forward(real_img, labels = labels)
val_loss = self.model.loss_function(*results,
M_N = self.params.batch_size/ self.num_train_imgs)

Expand Down
2 changes: 1 addition & 1 deletion models/beta_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor) -> Tensor:
def forward(self, input: Tensor, **kwargs) -> Tensor:
    """Run the full VAE pipeline: encode, reparameterize, decode.

    Returns the reconstruction together with the posterior mean and
    log-variance.  NOTE(review): the annotation says ``Tensor`` but a
    3-tuple is returned — consider ``Tuple``/``List[Tensor]``; confirm
    against ``loss_function``'s expected argument order.
    """
    # Posterior parameters produced by the encoder.
    mu, log_var = self.encode(input)
    # Draw a latent sample via the reparameterization trick.
    latent = self.reparameterize(mu, log_var)
    reconstruction = self.decode(latent)
    return reconstruction, mu, log_var
Expand Down
3 changes: 2 additions & 1 deletion models/cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor, y: Tensor) -> List[Tensor]:
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
y = kwargs['labels']
embedded_class = self.embed_class(y)
embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
embedded_input = self.embed_data(input)
Expand Down
2 changes: 1 addition & 1 deletion models/gamma_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor) -> Tensor:
def forward(self, input: Tensor, **kwargs) -> Tensor:
    """Encode the input, sample a latent code, and decode it.

    NOTE(review): a 3-tuple ``(reconstruction, mu, log_var)`` is
    returned despite the ``Tensor`` annotation — verify against the
    model's ``loss_function`` before tightening the type.
    """
    # Encoder yields the variational posterior's mean / log-variance.
    mu, log_var = self.encode(input)
    # Reparameterized sample keeps the graph differentiable.
    z_sample = self.reparameterize(mu, log_var)
    return self.decode(z_sample), mu, log_var
Expand Down
2 changes: 1 addition & 1 deletion models/vanilla_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor) -> List[Tensor]:
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """Full VAE forward pass.

    Returns ``[reconstruction, input, mu, log_var]`` — the input is
    echoed back so the loss function can compute the reconstruction
    term from ``*results`` alone.
    """
    mu, log_var = self.encode(input)
    # Sample from N(mu, sigma^2) via the reparameterization trick.
    sample = self.reparameterize(mu, log_var)
    decoded = self.decode(sample)
    return [decoded, input, mu, log_var]
Expand Down
2 changes: 1 addition & 1 deletion models/wae_mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def decode(self, z: Tensor) -> Tensor:
result = self.final_layer(result)
return result

def forward(self, input: Tensor) -> List[Tensor]:
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """Deterministic encode/decode pass.

    Unlike the variational models, no reparameterization happens here:
    the encoder output is used directly as the latent code.  The input
    and latent are returned alongside the reconstruction for the loss.
    """
    latent = self.encode(input)
    reconstruction = self.decode(latent)
    return [reconstruction, input, latent]

Expand Down
3 changes: 2 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from models import VanillaVAE, WAE_MMD
from models import VanillaVAE, WAE_MMD, CVAE
from experiment import VAEXperiment
from pytorch_lightning import Trainer
from pytorch_lightning.logging import TestTubeLogger
Expand Down Expand Up @@ -27,6 +27,7 @@ def __init__(self):
hyper_params = hparams()
torch.manual_seed = hyper_params.manual_seed
# model = VanillaVAE(in_channels=3, latent_dim=128)
# model = CVAE(in_channels=3, latent_dim=128, num_classes=40, img_size=64)
model = WAE_MMD(in_channels=3, latent_dim=128, reg_weight=100)
experiment = VAEXperiment(model,
hyper_params)
Expand Down

0 comments on commit 79520be

Please sign in to comment.