diff --git a/README.md b/README.md
index 5ec6a9bd..7344968a 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,12 @@
-PyTorch-VAE
+
+ PyTorch VAE
+
+
+| Model | Reconstruction | Samples |
+|-------|----------------|---------|
+| | | |
+| | | |
+| | | |
TODO
- [x] VanillaVAE
@@ -15,4 +23,5 @@ TODO
- [ ] IWAE
- [ ] VLAE
- [ ] FactorVAE
+- [ ] PixelVAE
diff --git a/assets/WAE_IMQ_15.png b/assets/WAE_IMQ_15.png
new file mode 100644
index 00000000..e4c3abfd
Binary files /dev/null and b/assets/WAE_IMQ_15.png differ
diff --git a/assets/WAE_RBF_17.png b/assets/WAE_RBF_17.png
new file mode 100644
index 00000000..85b110aa
Binary files /dev/null and b/assets/WAE_RBF_17.png differ
diff --git a/configs/vae.yaml b/configs/vae.yaml
index b146f315..a319cbd9 100644
--- a/configs/vae.yaml
+++ b/configs/vae.yaml
@@ -3,21 +3,19 @@ model_params:
in_channels: 3
latent_dim: 128
-data_params:
+exp_params:
data_path: "../../shared/Data/"
img_size: 64
- batch_size: 144
-
-optimizer_params:
- LR: 5e-3
+  batch_size: 144 # A perfect square keeps the saved image grids square (nrow = sqrt(batch_size))
+ LR: 0.005
scheduler_gamma: 0.95
- gpus: 1
- num_epochs: 50
-
-logging_params:
- save_dir: "logs/",
- name: "VanillaVAE",
-
-
+trainer_params:
+ gpus: [0, 1, 2]
+ max_nb_epochs: 50
+ distributed_backend: 'DP'
+logging_params:
+ save_dir: "logs/"
+ name: "VanillaVAE"
+ manual_seed: 1265
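
The config is regrouped so that each top-level block maps onto exactly one consumer in `run.py`: `model_params` feed the model constructor, `exp_params` the `VAEXperiment` module, `trainer_params` the Lightning `Trainer`, and `logging_params` the `TestTubeLogger`. A minimal sketch of how such a file is read and dispatched (the file path and keys come from this diff; everything else is illustrative):

```python
import yaml

# Each top-level block of the YAML file is handed to exactly one consumer.
with open('configs/vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

model_args   = config['model_params']    # -> vae_models[name](**model_args)
exp_args     = config['exp_params']      # -> VAEXperiment(model, exp_args)
trainer_args = config['trainer_params']  # -> Trainer(**trainer_args)
logger_args  = config['logging_params']  # -> TestTubeLogger(save_dir=..., name=...)

print(exp_args['batch_size'], exp_args['LR'])  # 144 0.005 with the values above
```
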
diff --git a/configs/wae_mmd_imq.yaml b/configs/wae_mmd_imq.yaml
new file mode 100644
index 00000000..9eb75457
--- /dev/null
+++ b/configs/wae_mmd_imq.yaml
@@ -0,0 +1,27 @@
+model_params:
+ name: 'WAE_MMD'
+ in_channels: 3
+ latent_dim: 128
+ reg_weight: 100
+ kernel_type: 'imq'
+
+exp_params:
+ data_path: "../../shared/Data/"
+ img_size: 64
+  batch_size: 144 # A perfect square keeps the saved image grids square (nrow = sqrt(batch_size))
+ LR: 0.005
+ scheduler_gamma: 0.95
+
+trainer_params:
+ gpus: [0, 1, 2]
+ max_nb_epochs: 50
+ distributed_backend: 'DP'
+
+logging_params:
+ save_dir: "logs/"
+ name: "WassersteinVAE"
+ manual_seed: 1265
+
+
+
+
diff --git a/configs/wae_mmd_rbf.yaml b/configs/wae_mmd_rbf.yaml
new file mode 100644
index 00000000..c92ec42b
--- /dev/null
+++ b/configs/wae_mmd_rbf.yaml
@@ -0,0 +1,27 @@
+model_params:
+ name: 'WAE_MMD'
+ in_channels: 3
+ latent_dim: 128
+ reg_weight: 100
+ kernel_type: 'rbf'
+
+exp_params:
+ data_path: "../../shared/Data/"
+ img_size: 64
+  batch_size: 144 # A perfect square keeps the saved image grids square (nrow = sqrt(batch_size))
+ LR: 0.005
+ scheduler_gamma: 0.95
+
+trainer_params:
+ gpus: [0, 1, 2]
+ max_nb_epochs: 50
+ distributed_backend: 'DP'
+
+logging_params:
+ save_dir: "logs/"
+ name: "WassersteinVAE"
+ manual_seed: 1265
+
+
+
+
diff --git a/experiment.py b/experiment.py
index 99c0c38d..6bece210 100644
--- a/experiment.py
+++ b/experiment.py
@@ -32,7 +32,7 @@ def training_step(self, batch, batch_idx):
results = self.forward(real_img, labels = labels)
train_loss = self.model.loss_function(*results,
- M_N = self.params.batch_size/ self.num_train_imgs )
+ M_N = self.params['batch_size']/ self.num_train_imgs )
self.logger.experiment.log({key: val.item() for key, val in train_loss.items()})
@@ -42,7 +42,7 @@ def validation_step(self, batch, batch_idx):
real_img, labels = batch
results = self.forward(real_img, labels = labels)
val_loss = self.model.loss_function(*results,
- M_N = self.params.batch_size/ self.num_train_imgs)
+ M_N = self.params['batch_size']/ self.num_train_imgs)
return val_loss
@@ -53,7 +53,7 @@ def validation_end(self, outputs):
return {'val_loss': avg_loss, 'log': tensorboard_logs}
def sample_images(self):
- z = torch.randn(self.params.batch_size,
+ z = torch.randn(self.params['batch_size'],
self.model.latent_dim)
if self.on_gpu:
@@ -64,25 +64,33 @@ def sample_images(self):
vutils.save_image(samples.data,
f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
normalize=True,
- nrow=int(math.sqrt(self.params.batch_size)))
+ nrow=int(math.sqrt(self.params['batch_size'])))
+        test_input, _ = next(iter(self.val_dataloader()))  # CelebA batches are (image, attrs) tuples
+        recons = self.model(test_input.to(samples.device))[0]  # first element of the output is the reconstruction
+
+ vutils.save_image(recons.data,
+ f"{self.logger.save_dir}/{self.logger.name}/recons_{self.current_epoch}.png",
+ normalize=True,
+ nrow=int(math.sqrt(self.params['batch_size'])))
+ del test_input, recons, samples, z
def configure_optimizers(self):
- optimizer = optim.Adam(self.model.parameters(), lr=self.params.LR)
- scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.params.scheduler_gamma)
+ optimizer = optim.Adam(self.model.parameters(), lr=self.params['LR'])
+ scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.params['scheduler_gamma'])
return [optimizer] #, [scheduler]
@pl.data_loader
def train_dataloader(self):
transform = self.data_transforms()
- dataset = CelebA(root = self.params.data_path,
+ dataset = CelebA(root = self.params['data_path'],
split = "train",
transform=transform,
download=False)
self.num_train_imgs = len(dataset)
return DataLoader(dataset,
- batch_size= self.params.batch_size,
+ batch_size= self.params['batch_size'],
shuffle = True,
drop_last=True)
@@ -90,11 +98,11 @@ def train_dataloader(self):
def val_dataloader(self):
transform = self.data_transforms()
- return DataLoader(CelebA(root = self.params.data_path,
+ return DataLoader(CelebA(root = self.params['data_path'],
split = "test",
transform=transform,
download=False),
- batch_size= self.params.batch_size,
+ batch_size= self.params['batch_size'],
shuffle = True,
drop_last=True)
@@ -102,7 +110,7 @@ def data_transforms(self):
SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
transform = transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.CenterCrop(148),
- transforms.Resize(self.params.img_size),
+ transforms.Resize(self.params['img_size']),
transforms.ToTensor(),
SetRange])
return transform
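
The new reconstruction and sampling grids both use `nrow=int(math.sqrt(batch_size))`, which is why the configs recommend a square batch size: a perfect square makes the saved grid itself square. A small self-contained illustration (the random data and output file name are placeholders):

```python
import math
import torch
import torchvision.utils as vutils

# With a perfect-square batch size, nrow = sqrt(batch_size) is an integer
# and save_image lays the batch out as a square grid (here 12 x 12).
batch_size = 144
fake_batch = torch.rand(batch_size, 3, 64, 64)   # stand-in for model outputs in [0, 1]
vutils.save_image(fake_batch,
                  "grid_example.png",            # hypothetical output path
                  normalize=True,
                  nrow=int(math.sqrt(batch_size)))
```
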
diff --git a/models/__init__.py b/models/__init__.py
index 6826f063..b4ec2960 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -9,3 +9,9 @@
VAE = VanillaVAE
GaussianVAE = VanillaVAE
CVAE = ConditionalVAE
+
+vae_models = {'VanillaVAE':VanillaVAE,
+ 'WAE_MMD':WAE_MMD,
+ 'ConditionalVAE':ConditionalVAE,
+ 'BetaVAE':BetaVAE,
+ 'GammaVAE':GammaVAE}
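
`vae_models` turns the `name` field of a config into a constructor, and the `**kwargs` added to every model `__init__` further down in this diff is the other half of the pattern: it lets the whole `model_params` block be splatted into the constructor even though it carries extra keys such as `name`. A minimal sketch of the idiom with a toy class (the class and values here are illustrative, not the repository's):

```python
from typing import List

class ToyVAE:
    # Stand-in for the repository's models: takes the documented arguments and
    # absorbs any extra config keys (e.g. 'name') via **kwargs.
    def __init__(self, in_channels: int, latent_dim: int,
                 hidden_dims: List = None, **kwargs) -> None:
        self.latent_dim = latent_dim

registry = {'ToyVAE': ToyVAE}
model_params = {'name': 'ToyVAE', 'in_channels': 3, 'latent_dim': 128}

model = registry[model_params['name']](**model_params)  # 'name' is swallowed by **kwargs
print(model.latent_dim)  # 128
```
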
diff --git a/models/__pycache__/__init__.cpython-36.pyc b/models/__pycache__/__init__.cpython-36.pyc
index db7d2be4..ff5fd4af 100644
Binary files a/models/__pycache__/__init__.cpython-36.pyc and b/models/__pycache__/__init__.cpython-36.pyc differ
diff --git a/models/__pycache__/__init__.cpython-37.pyc b/models/__pycache__/__init__.cpython-37.pyc
index 5cca70fd..a5284f22 100644
Binary files a/models/__pycache__/__init__.cpython-37.pyc and b/models/__pycache__/__init__.cpython-37.pyc differ
diff --git a/models/__pycache__/beta_vae.cpython-36.pyc b/models/__pycache__/beta_vae.cpython-36.pyc
index 8cbe4f84..70b68482 100644
Binary files a/models/__pycache__/beta_vae.cpython-36.pyc and b/models/__pycache__/beta_vae.cpython-36.pyc differ
diff --git a/models/__pycache__/beta_vae.cpython-37.pyc b/models/__pycache__/beta_vae.cpython-37.pyc
index 348de533..8cb4bb4a 100644
Binary files a/models/__pycache__/beta_vae.cpython-37.pyc and b/models/__pycache__/beta_vae.cpython-37.pyc differ
diff --git a/models/__pycache__/gamma_vae.cpython-36.pyc b/models/__pycache__/gamma_vae.cpython-36.pyc
index 9cf988d1..4e7f3b12 100644
Binary files a/models/__pycache__/gamma_vae.cpython-36.pyc and b/models/__pycache__/gamma_vae.cpython-36.pyc differ
diff --git a/models/__pycache__/gamma_vae.cpython-37.pyc b/models/__pycache__/gamma_vae.cpython-37.pyc
index 1620ad1f..12d00d00 100644
Binary files a/models/__pycache__/gamma_vae.cpython-37.pyc and b/models/__pycache__/gamma_vae.cpython-37.pyc differ
diff --git a/models/__pycache__/vanilla_vae.cpython-36.pyc b/models/__pycache__/vanilla_vae.cpython-36.pyc
index d7f65211..c4a00a50 100644
Binary files a/models/__pycache__/vanilla_vae.cpython-36.pyc and b/models/__pycache__/vanilla_vae.cpython-36.pyc differ
diff --git a/models/__pycache__/vanilla_vae.cpython-37.pyc b/models/__pycache__/vanilla_vae.cpython-37.pyc
index b1d49aea..c69a8b0c 100644
Binary files a/models/__pycache__/vanilla_vae.cpython-37.pyc and b/models/__pycache__/vanilla_vae.cpython-37.pyc differ
diff --git a/models/__pycache__/wae_mmd.cpython-36.pyc b/models/__pycache__/wae_mmd.cpython-36.pyc
index b3711833..85a01724 100644
Binary files a/models/__pycache__/wae_mmd.cpython-36.pyc and b/models/__pycache__/wae_mmd.cpython-36.pyc differ
diff --git a/models/beta_vae.py b/models/beta_vae.py
index 1f65d6ef..4e9bb3bc 100644
--- a/models/beta_vae.py
+++ b/models/beta_vae.py
@@ -11,7 +11,8 @@ def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
- beta: int = 1) -> None:
+ beta: int = 1,
+ **kwargs) -> None:
super(BetaVAE, self).__init__()
self.latent_dim = latent_dim
diff --git a/models/cvae.py b/models/cvae.py
index 2253a379..3bffc64f 100644
--- a/models/cvae.py
+++ b/models/cvae.py
@@ -12,7 +12,8 @@ def __init__(self,
num_classes: int,
latent_dim: int,
hidden_dims: List = None,
- img_size:int = 64) -> None:
+ img_size:int = 64,
+ **kwargs) -> None:
super(ConditionalVAE, self).__init__()
self.latent_dim = latent_dim
diff --git a/models/gamma_vae.py b/models/gamma_vae.py
index bb40b009..ac26ebcc 100644
--- a/models/gamma_vae.py
+++ b/models/gamma_vae.py
@@ -14,7 +14,8 @@ class GammaVAE(BaseVAE):
def __init__(self,
in_channels: int,
latent_dim: int,
- hidden_dims: List = None) -> None:
+ hidden_dims: List = None,
+ **kwargs) -> None:
super(GammaVAE, self).__init__()
self.latent_dim = latent_dim
diff --git a/models/vanilla_vae.py b/models/vanilla_vae.py
index 6ff242d9..d9bdbe15 100644
--- a/models/vanilla_vae.py
+++ b/models/vanilla_vae.py
@@ -10,7 +10,8 @@ class VanillaVAE(BaseVAE):
def __init__(self,
in_channels: int,
latent_dim: int,
- hidden_dims: List = None) -> None:
+ hidden_dims: List = None,
+ **kwargs) -> None:
super(VanillaVAE, self).__init__()
self.latent_dim = latent_dim
diff --git a/models/wae_mmd.py b/models/wae_mmd.py
index 16d3f449..a1be794d 100644
--- a/models/wae_mmd.py
+++ b/models/wae_mmd.py
@@ -13,7 +13,8 @@ def __init__(self,
hidden_dims: List = None,
reg_weight: int = 100,
kernel_type: str = 'imq',
- latent_var: float = 2.) -> None:
+ latent_var: float = 2.,
+ **kwargs) -> None:
super(WAE_MMD, self).__init__()
self.latent_dim = latent_dim
@@ -133,7 +134,7 @@ def compute_kernel(self,
x2 = x2.unsqueeze(-3) # Make it into a row tensor
"""
- Usually this is not required, especially in our case,
+        Usually the lines below are not required, especially in our case,
but this is useful when x1 and x2 have different sizes
along the 0th dimension.
"""
diff --git a/run.py b/run.py
index 00223c8b..817c4a3a 100644
--- a/run.py
+++ b/run.py
@@ -1,45 +1,49 @@
-import torch
-from models import VanillaVAE, WAE_MMD, CVAE
+import yaml
+import argparse
+from models import *
from experiment import VAEXperiment
+import torch.backends.cudnn as cudnn
from pytorch_lightning import Trainer
from pytorch_lightning.logging import TestTubeLogger
+parser = argparse.ArgumentParser(description='Generic runner for VAE models')
+parser.add_argument('--config', '-c',
+ dest="filename",
+ metavar='FILE',
+ help = 'path to the config file',
+ default='configs/vae.yaml')
+
+args = parser.parse_args()
+with open(args.filename, 'r') as file:
+ try:
+ config = yaml.safe_load(file)
+ except yaml.YAMLError as exc:
+ print(exc)
+
+
tt_logger = TestTubeLogger(
- save_dir="logs/",
- name="WassersteinVAE",
+ save_dir=config['logging_params']['save_dir'],
+ name=config['logging_params']['name'],
debug=False,
create_git_tag=False,
)
-
-class hparams(object):
- def __init__(self):
- self.LR = 5e-3
- self.scheduler_gamma = 0.95
- self.gpus = 1
- self.data_path = "../../shared/Data/"
- self.batch_size = 144
- self.img_size = 64
- self.manual_seed = 1256
-
-
-hyper_params = hparams()
-torch.manual_seed = hyper_params.manual_seed
-# model = VanillaVAE(in_channels=3, latent_dim=128)
-# model = CVAE(in_channels=3, latent_dim=128, num_classes=40, img_size=64)
-model = WAE_MMD(in_channels=3, latent_dim=128, reg_weight=100)
+torch.manual_seed(config['logging_params']['manual_seed'])
+cudnn.deterministic = True
+model = vae_models[config['model_params']['name']](**config['model_params'])
+# # model = CVAE(in_channels=3, latent_dim=128, num_classes=40, img_size=64)
+# # model = WAE_MMD(in_channels=3, latent_dim=128, reg_weight=100)
experiment = VAEXperiment(model,
- hyper_params)
-
+ config['exp_params'])
-runner = Trainer(gpus=hyper_params.gpus,
- default_save_path=f"{tt_logger.save_dir}",
+runner = Trainer(default_save_path=f"{tt_logger.save_dir}",
min_nb_epochs=1,
- max_nb_epochs= 50,
logger=tt_logger,
log_save_interval=100,
train_percent_check=1.,
- val_percent_check=1.)
+ val_percent_check=1.,
+ **config['trainer_params'])
+print(f"======= Training {config['model_params']['name']} =======")
runner.fit(experiment)
\ No newline at end of file
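
With these changes a single entry point drives every model: e.g. `python run.py -c configs/wae_mmd_rbf.yaml`, or simply `python run.py` to fall back to the default `configs/vae.yaml`.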