Updated config files
AntixK committed Jan 17, 2020
1 parent e2ed62a commit 5bd2c42
Showing 23 changed files with 142 additions and 58 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -1,4 +1,12 @@
-PyTorch-VAE
+<p align="center">
+<b>PyTorch VAE</b><br>
+</p>
 
+| Model | Reconstruction | Samples |
+|-------|----------------|---------|
+| | | |
+| | | |
+| | | |
+
 TODO
 - [x] VanillaVAE
@@ -15,4 +23,5 @@ TODO
 - [ ] IWAE
 - [ ] VLAE
 - [ ] FactorVAE
+- [ ] PixelVAE
Binary file added assets/WAE_IMQ_15.png
Binary file added assets/WAE_RBF_17.png
24 changes: 11 additions & 13 deletions configs/vae.yaml
@@ -3,21 +3,19 @@ model_params:
   in_channels: 3
   latent_dim: 128
 
-data_params:
+exp_params:
   data_path: "../../shared/Data/"
   img_size: 64
-  batch_size: 144
-
-optimizer_params:
-  LR: 5e-3
+  batch_size: 144 # Better to have a square number
+  LR: 0.005
   scheduler_gamma: 0.95
-  gpus: 1
-  num_epochs: 50
+
+trainer_params:
+  gpus: [0, 1, 2]
+  max_nb_epochs: 50
+  distributed_backend: 'DP'
 
 logging_params:
-  save_dir: "logs/",
-  name: "VanillaVAE",
-
-
-
-
+  save_dir: "logs/"
+  name: "VanillaVAE"
+  manual_seed: 1265
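The old data_params and optimizer_params sections are folded into a single exp_params block consumed by VAEXperiment, while trainer_params is splatted directly into the PyTorch Lightning Trainer. A minimal sketch of how the sections are routed, mirroring run.py later in this commit (file path assumed):

```python
# Minimal sketch, assuming the layout above: each top-level section of
# the YAML feeds exactly one consumer.
import yaml

with open('configs/vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

model_args   = config['model_params']    # -> vae_models[name](**model_args)
exp_args     = config['exp_params']      # -> VAEXperiment(model, exp_args)
trainer_args = config['trainer_params']  # -> Trainer(**trainer_args)
```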
27 changes: 27 additions & 0 deletions configs/wae_md_imq.yaml
@@ -0,0 +1,27 @@
model_params:
name: 'WAE_MMD'
in_channels: 3
latent_dim: 128
reg_weight: 100
kernel_type: 'imq'

exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
scheduler_gamma: 0.95

trainer_params:
gpus: [0, 1, 2]
max_nb_epochs: 50
distributed_backend: 'DP'

logging_params:
save_dir: "logs/"
name: "WassersteinVAE"
manual_seed: 1265




27 changes: 27 additions & 0 deletions configs/wae_mmd_rbf.yaml
@@ -0,0 +1,27 @@
model_params:
name: 'WAE_MMD'
in_channels: 3
latent_dim: 128
reg_weight: 100
kernel_type: 'rbf'

exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
scheduler_gamma: 0.95

trainer_params:
gpus: [0, 1, 2]
max_nb_epochs: 50
distributed_backend: 'DP'

logging_params:
save_dir: "logs/"
name: "WassersteinVAE"
manual_seed: 1265
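The two new WAE configs are identical apart from kernel_type, which selects the MMD kernel inside WAE_MMD. For reference, the standard forms of the two kernels are given below; the exact constants (e.g. how latent_var sets the bandwidth) live in models/wae_mmd.py, not in this diff:

$$k_{\mathrm{RBF}}(x, y) = \exp\!\left(-\frac{\lVert x - y \rVert^2}{2\sigma^2}\right), \qquad k_{\mathrm{IMQ}}(x, y) = \frac{C}{C + \lVert x - y \rVert^2}$$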




30 changes: 19 additions & 11 deletions experiment.py
@@ -32,7 +32,7 @@ def training_step(self, batch, batch_idx):
         results = self.forward(real_img, labels = labels)
 
         train_loss = self.model.loss_function(*results,
-                                              M_N = self.params.batch_size/ self.num_train_imgs )
+                                              M_N = self.params['batch_size']/ self.num_train_imgs )
 
         self.logger.experiment.log({key: val.item() for key, val in train_loss.items()})
 
@@ -42,7 +42,7 @@ def validation_step(self, batch, batch_idx):
         real_img, labels = batch
         results = self.forward(real_img, labels = labels)
         val_loss = self.model.loss_function(*results,
-                                            M_N = self.params.batch_size/ self.num_train_imgs)
+                                            M_N = self.params['batch_size']/ self.num_train_imgs)
 
         return val_loss
 
@@ -53,7 +53,7 @@ def validation_end(self, outputs):
         return {'val_loss': avg_loss, 'log': tensorboard_logs}
 
     def sample_images(self):
-        z = torch.randn(self.params.batch_size,
+        z = torch.randn(self.params['batch_size'],
                         self.model.latent_dim)
 
         if self.on_gpu:
@@ -64,45 +64,53 @@ def sample_images(self):
         vutils.save_image(samples.data,
                           f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
                           normalize=True,
-                          nrow=int(math.sqrt(self.params.batch_size)))
+                          nrow=int(math.sqrt(self.params['batch_size'])))
+        test_input = next(iter(self.val_dataloader))
+        recons = self.model(test_input)
+
+        vutils.save_image(recons.data,
+                          f"{self.logger.save_dir}/{self.logger.name}/recons_{self.current_epoch}.png",
+                          normalize=True,
+                          nrow=int(math.sqrt(self.params['batch_size'])))
+        del test_input, recons, samples, z
 
 
 
     def configure_optimizers(self):
-        optimizer = optim.Adam(self.model.parameters(), lr=self.params.LR)
-        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.params.scheduler_gamma)
+        optimizer = optim.Adam(self.model.parameters(), lr=self.params['LR'])
+        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.params['scheduler_gamma'])
         return [optimizer] #, [scheduler]
 
     @pl.data_loader
     def train_dataloader(self):
         transform = self.data_transforms()
-        dataset = CelebA(root = self.params.data_path,
+        dataset = CelebA(root = self.params['data_path'],
                          split = "train",
                          transform=transform,
                          download=False)
         self.num_train_imgs = len(dataset)
         return DataLoader(dataset,
-                          batch_size= self.params.batch_size,
+                          batch_size= self.params['batch_size'],
                           shuffle = True,
                           drop_last=True)
 
     @pl.data_loader
     def val_dataloader(self):
         transform = self.data_transforms()
 
-        return DataLoader(CelebA(root = self.params.data_path,
+        return DataLoader(CelebA(root = self.params['data_path'],
                                  split = "test",
                                  transform=transform,
                                  download=False),
-                          batch_size= self.params.batch_size,
+                          batch_size= self.params['batch_size'],
                           shuffle = True,
                           drop_last=True)
 
     def data_transforms(self):
         SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
         transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                         transforms.CenterCrop(148),
-                                        transforms.Resize(self.params.img_size),
+                                        transforms.Resize(self.params['img_size']),
                                         transforms.ToTensor(),
                                         SetRange])
         return transform
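Throughout experiment.py, attribute access (self.params.batch_size) becomes dictionary access (self.params['batch_size']), because the experiment now receives the raw dict produced by yaml.safe_load rather than an hparams object. The M_N argument is the minibatch weight batch_size / num_train_imgs used to scale the regularization term; a small illustrative computation (dataset size approximate, not from this diff):

```python
# Illustrative only: M_N rescales the per-batch KL/MMD term so the
# minibatch loss approximates the full-dataset objective.
batch_size = 144                 # from the configs above
num_train_imgs = 162770          # CelebA train split size (approximate)
M_N = batch_size / num_train_imgs
print(f"regularizer weight per batch: {M_N:.6f}")
```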
6 changes: 6 additions & 0 deletions models/__init__.py
@@ -9,3 +9,9 @@
 VAE = VanillaVAE
 GaussianVAE = VanillaVAE
 CVAE = ConditionalVAE
+
+vae_models = {'VanillaVAE':VanillaVAE,
+              'WAE_MMD':WAE_MMD,
+              'ConditionalVAE':ConditionalVAE,
+              'BetaVAE':BetaVAE,
+              'GammaVAE':GammaVAE}
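The vae_models dict acts as a simple name-to-class registry: run.py (below) looks a model up by the name field of the config and splats the remaining model_params into its constructor, as in this line quoted from run.py in the same commit:

```python
# Quoted from run.py below; the extra 'name' key is absorbed by the
# **kwargs that each model constructor gains in this commit.
model = vae_models[config['model_params']['name']](**config['model_params'])
```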
Binary file modified models/__pycache__/__init__.cpython-36.pyc
Binary file modified models/__pycache__/__init__.cpython-37.pyc
Binary file modified models/__pycache__/beta_vae.cpython-36.pyc
Binary file modified models/__pycache__/beta_vae.cpython-37.pyc
Binary file modified models/__pycache__/gamma_vae.cpython-36.pyc
Binary file modified models/__pycache__/gamma_vae.cpython-37.pyc
Binary file modified models/__pycache__/vanilla_vae.cpython-36.pyc
Binary file modified models/__pycache__/vanilla_vae.cpython-37.pyc
Binary file modified models/__pycache__/wae_mmd.cpython-36.pyc
3 changes: 2 additions & 1 deletion models/beta_vae.py
@@ -11,7 +11,8 @@ def __init__(self,
                  in_channels: int,
                  latent_dim: int,
                  hidden_dims: List = None,
-                 beta: int = 1) -> None:
+                 beta: int = 1,
+                 **kwargs) -> None:
         super(BetaVAE, self).__init__()
 
         self.latent_dim = latent_dim
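Every model constructor in this commit gains a trailing **kwargs for the same reason: the registry splat passes the whole model_params section, and keys a given model does not declare (most obviously name) must be swallowed rather than raise a TypeError. A minimal sketch of the pattern (toy class, not the repo's actual model):

```python
# Minimal sketch: **kwargs silently absorbs config keys the
# constructor does not declare, such as 'name'.
class ToyVAE:
    def __init__(self, in_channels: int, latent_dim: int, **kwargs) -> None:
        self.in_channels = in_channels
        self.latent_dim = latent_dim   # kwargs == {'name': 'ToyVAE'} here

model = ToyVAE(**{'name': 'ToyVAE', 'in_channels': 3, 'latent_dim': 128})
```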
3 changes: 2 additions & 1 deletion models/cvae.py
@@ -12,7 +12,8 @@ def __init__(self,
                  num_classes: int,
                  latent_dim: int,
                  hidden_dims: List = None,
-                 img_size:int = 64) -> None:
+                 img_size:int = 64,
+                 **kwargs) -> None:
         super(ConditionalVAE, self).__init__()
 
         self.latent_dim = latent_dim
3 changes: 2 additions & 1 deletion models/gamma_vae.py
@@ -14,7 +14,8 @@ class GammaVAE(BaseVAE):
     def __init__(self,
                  in_channels: int,
                  latent_dim: int,
-                 hidden_dims: List = None) -> None:
+                 hidden_dims: List = None,
+                 **kwargs) -> None:
         super(GammaVAE, self).__init__()
         self.latent_dim = latent_dim
3 changes: 2 additions & 1 deletion models/vanilla_vae.py
@@ -10,7 +10,8 @@ class VanillaVAE(BaseVAE):
     def __init__(self,
                  in_channels: int,
                  latent_dim: int,
-                 hidden_dims: List = None) -> None:
+                 hidden_dims: List = None,
+                 **kwargs) -> None:
         super(VanillaVAE, self).__init__()
 
         self.latent_dim = latent_dim
5 changes: 3 additions & 2 deletions models/wae_mmd.py
@@ -13,7 +13,8 @@ def __init__(self,
                  hidden_dims: List = None,
                  reg_weight: int = 100,
                  kernel_type: str = 'imq',
-                 latent_var: float = 2.) -> None:
+                 latent_var: float = 2.,
+                 **kwargs) -> None:
         super(WAE_MMD, self).__init__()
 
         self.latent_dim = latent_dim
@@ -133,7 +134,7 @@ def compute_kernel(self,
         x2 = x2.unsqueeze(-3) # Make it into a row tensor
 
         """
-        Usually this is not required, especially in our case,
+        Usually the below lines are not required, especially in our case,
         but this is useful when x1 and x2 have different sizes
         along the 0th dimension.
         """
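The unsqueeze calls in compute_kernel set up broadcasting so that all pairs of latent codes are compared at once. A small self-contained illustration of the shape arithmetic (tensor sizes arbitrary):

```python
# (N, D) and (M, D) become (N, 1, D) and (1, M, D); subtraction then
# broadcasts to an (N, M, D) grid of pairwise differences.
import torch

x1 = torch.randn(5, 3).unsqueeze(-2)   # column tensor: (5, 1, 3)
x2 = torch.randn(7, 3).unsqueeze(-3)   # row tensor:    (1, 7, 3)
sq_dist = ((x1 - x2) ** 2).sum(-1)     # (5, 7) pairwise squared distances
print(sq_dist.shape)                   # torch.Size([5, 7])
```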
58 changes: 31 additions & 27 deletions run.py
@@ -1,45 +1,49 @@
 import torch
-from models import VanillaVAE, WAE_MMD, CVAE
+import yaml
+import argparse
+from models import *
 from experiment import VAEXperiment
+import torch.backends.cudnn as cudnn
 from pytorch_lightning import Trainer
 from pytorch_lightning.logging import TestTubeLogger
 
 
+parser = argparse.ArgumentParser(description='Generic runner for VAE models')
+parser.add_argument('--config', '-c',
+                    dest="filename",
+                    metavar='FILE',
+                    help = 'path to the config file',
+                    default='configs/vae.yaml')
+
+args = parser.parse_args()
+with open(args.filename, 'r') as file:
+    try:
+        config = yaml.safe_load(file)
+    except yaml.YAMLError as exc:
+        print(exc)
+
+
 tt_logger = TestTubeLogger(
-    save_dir="logs/",
-    name="WassersteinVAE",
+    save_dir=config['logging_params']['save_dir'],
+    name=config['logging_params']['name'],
     debug=False,
     create_git_tag=False,
 )
 
 
-class hparams(object):
-    def __init__(self):
-        self.LR = 5e-3
-        self.scheduler_gamma = 0.95
-        self.gpus = 1
-        self.data_path = "../../shared/Data/"
-        self.batch_size = 144
-        self.img_size = 64
-        self.manual_seed = 1256
-
-
-hyper_params = hparams()
-torch.manual_seed = hyper_params.manual_seed
-# model = VanillaVAE(in_channels=3, latent_dim=128)
-# model = CVAE(in_channels=3, latent_dim=128, num_classes=40, img_size=64)
-model = WAE_MMD(in_channels=3, latent_dim=128, reg_weight=100)
+torch.manual_seed = config['logging_params']['manual_seed']
+cudnn.deterministic = True
+model = vae_models[config['model_params']['name']](**config['model_params'])
+# # model = CVAE(in_channels=3, latent_dim=128, num_classes=40, img_size=64)
+# # model = WAE_MMD(in_channels=3, latent_dim=128, reg_weight=100)
 experiment = VAEXperiment(model,
-                          hyper_params)
+                          config['exp_params'])
 
-runner = Trainer(gpus=hyper_params.gpus,
-                 default_save_path=f"{tt_logger.save_dir}",
+runner = Trainer(default_save_path=f"{tt_logger.save_dir}",
                  min_nb_epochs=1,
-                 max_nb_epochs= 50,
                  logger=tt_logger,
                  log_save_interval=100,
                  train_percent_check=1.,
-                 val_percent_check=1.)
+                 val_percent_check=1.,
+                 **config['trainer_params'])
 
+print(f"======= Training {config['model_params']['name']} =======")
 runner.fit(experiment)
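With this change the runner is fully config-driven; a run now looks like `python run.py -c configs/wae_mmd_rbf.yaml`, with configs/vae.yaml as the default. Note that `torch.manual_seed = config['logging_params']['manual_seed']` assigns an int over the torch.manual_seed function rather than calling it, so the RNG is never actually seeded; the presumably intended call is:

```python
# torch.manual_seed is a function and must be invoked, not overwritten.
# 1265 is the manual_seed value from the new configs.
import torch
torch.manual_seed(1265)
```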
