diff --git a/README.md b/README.md
index 5ec6a9bd..7344968a 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,12 @@
-PyTorch-VAE
+
+ PyTorch VAE
+
+
+| Model | Reconstruction | Samples |
+|-------|----------------|---------|
+| | | |
+| | | |
+| | | |
 
 TODO
 - [x] VanillaVAE
@@ -15,4 +23,5 @@ TODO
 - [ ] IWAE
 - [ ] VLAE
 - [ ] FactorVAE
+- [ ] PixelVAE
diff --git a/assets/WAE_IMQ_15.png b/assets/WAE_IMQ_15.png
new file mode 100644
index 00000000..e4c3abfd
Binary files /dev/null and b/assets/WAE_IMQ_15.png differ
diff --git a/assets/WAE_RBF_17.png b/assets/WAE_RBF_17.png
new file mode 100644
index 00000000..85b110aa
Binary files /dev/null and b/assets/WAE_RBF_17.png differ
diff --git a/configs/vae.yaml b/configs/vae.yaml
index b146f315..a319cbd9 100644
--- a/configs/vae.yaml
+++ b/configs/vae.yaml
@@ -3,21 +3,19 @@ model_params:
   in_channels: 3
   latent_dim: 128
 
-data_params:
+exp_params:
   data_path: "../../shared/Data/"
   img_size: 64
-  batch_size: 144
-
-optimizer_params:
-  LR: 5e-3
+  batch_size: 144 # Better to have a square number
+  LR: 0.005
   scheduler_gamma: 0.95
-  gpus: 1
-  num_epochs: 50
-
-logging_params:
-  save_dir: "logs/",
-  name: "VanillaVAE",
-
-
+trainer_params:
+  gpus: [0, 1, 2]
+  max_nb_epochs: 50
+  distributed_backend: 'DP'
+logging_params:
+  save_dir: "logs/"
+  name: "VanillaVAE"
+  manual_seed: 1265
diff --git a/configs/wae_md_imq.yaml b/configs/wae_md_imq.yaml
new file mode 100644
index 00000000..9eb75457
--- /dev/null
+++ b/configs/wae_md_imq.yaml
@@ -0,0 +1,27 @@
+model_params:
+  name: 'WAE_MMD'
+  in_channels: 3
+  latent_dim: 128
+  reg_weight: 100
+  kernel_type: 'imq'
+
+exp_params:
+  data_path: "../../shared/Data/"
+  img_size: 64
+  batch_size: 144 # Better to have a square number
+  LR: 0.005
+  scheduler_gamma: 0.95
+
+trainer_params:
+  gpus: [0, 1, 2]
+  max_nb_epochs: 50
+  distributed_backend: 'DP'
+
+logging_params:
+  save_dir: "logs/"
+  name: "WassersteinVAE"
+  manual_seed: 1265
+
+
+
+
diff --git a/configs/wae_mmd_rbf.yaml b/configs/wae_mmd_rbf.yaml
new file mode 100644
index 00000000..c92ec42b
--- /dev/null
+++ b/configs/wae_mmd_rbf.yaml
@@ -0,0 +1,27 @@
+model_params:
+  name: 'WAE_MMD'
+  in_channels: 3
+  latent_dim: 128
+  reg_weight: 100
+  kernel_type: 'rbf'
+
+exp_params:
+  data_path: "../../shared/Data/"
+  img_size: 64
+  batch_size: 144 # Better to have a square number
+  LR: 0.005
+  scheduler_gamma: 0.95
+
+trainer_params:
+  gpus: [0, 1, 2]
+  max_nb_epochs: 50
+  distributed_backend: 'DP'
+
+logging_params:
+  save_dir: "logs/"
+  name: "WassersteinVAE"
+  manual_seed: 1265
+
+
+
+
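The two new configs differ only in `kernel_type`, which selects the positive-definite kernel used inside the WAE-MMD regulariser (scaled by `reg_weight`). As a rough, illustrative sketch of the two kernel families this flag chooses between (not the exact code in `models/wae_mmd.py`, which may for instance sum the IMQ kernel over several scales), the choice amounts to:

```python
import torch

def pairwise_kernel(z1: torch.Tensor,
                    z2: torch.Tensor,
                    kernel_type: str = 'imq',
                    latent_var: float = 2.) -> torch.Tensor:
    """Illustrative RBF / inverse-multiquadric kernels over batches of latent codes."""
    dists = torch.cdist(z1, z2).pow(2)       # pairwise squared Euclidean distances
    c = 2 * z1.size(-1) * latent_var         # scale tied to latent_dim, a common WAE choice
    if kernel_type == 'rbf':
        return torch.exp(-dists / c)
    if kernel_type == 'imq':
        return c / (c + dists)
    raise ValueError(f"Undefined kernel type: {kernel_type}")
```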
f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png", normalize=True, - nrow=int(math.sqrt(self.params.batch_size))) + nrow=int(math.sqrt(self.params['batch_size']))) + test_input = next(iter(self.val_dataloader)) + recons = self.model(test_input) + + vutils.save_image(recons.data, + f"{self.logger.save_dir}/{self.logger.name}/recons_{self.current_epoch}.png", + normalize=True, + nrow=int(math.sqrt(self.params['batch_size']))) + del test_input, recons, samples, z def configure_optimizers(self): - optimizer = optim.Adam(self.model.parameters(), lr=self.params.LR) - scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.params.scheduler_gamma) + optimizer = optim.Adam(self.model.parameters(), lr=self.params['LR']) + scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.params['scheduler_gamma']) return [optimizer] #, [scheduler] @pl.data_loader def train_dataloader(self): transform = self.data_transforms() - dataset = CelebA(root = self.params.data_path, + dataset = CelebA(root = self.params['data_path'], split = "train", transform=transform, download=False) self.num_train_imgs = len(dataset) return DataLoader(dataset, - batch_size= self.params.batch_size, + batch_size= self.params['batch_size'], shuffle = True, drop_last=True) @@ -90,11 +98,11 @@ def train_dataloader(self): def val_dataloader(self): transform = self.data_transforms() - return DataLoader(CelebA(root = self.params.data_path, + return DataLoader(CelebA(root = self.params['data_path'], split = "test", transform=transform, download=False), - batch_size= self.params.batch_size, + batch_size= self.params['batch_size'], shuffle = True, drop_last=True) @@ -102,7 +110,7 @@ def data_transforms(self): SetRange = transforms.Lambda(lambda X: 2 * X - 1.) 
transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.CenterCrop(148), - transforms.Resize(self.params.img_size), + transforms.Resize(self.params['img_size']), transforms.ToTensor(), SetRange]) return transform diff --git a/models/__init__.py b/models/__init__.py index 6826f063..b4ec2960 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -9,3 +9,9 @@ VAE = VanillaVAE GaussianVAE = VanillaVAE CVAE = ConditionalVAE + +vae_models = {'VanillaVAE':VanillaVAE, + 'WAE_MMD':WAE_MMD, + 'ConditionalVAE':ConditionalVAE, + 'BetaVAE':BetaVAE, + 'GammaVAE':GammaVAE} diff --git a/models/__pycache__/__init__.cpython-36.pyc b/models/__pycache__/__init__.cpython-36.pyc index db7d2be4..ff5fd4af 100644 Binary files a/models/__pycache__/__init__.cpython-36.pyc and b/models/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/__pycache__/__init__.cpython-37.pyc b/models/__pycache__/__init__.cpython-37.pyc index 5cca70fd..a5284f22 100644 Binary files a/models/__pycache__/__init__.cpython-37.pyc and b/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/__pycache__/beta_vae.cpython-36.pyc b/models/__pycache__/beta_vae.cpython-36.pyc index 8cbe4f84..70b68482 100644 Binary files a/models/__pycache__/beta_vae.cpython-36.pyc and b/models/__pycache__/beta_vae.cpython-36.pyc differ diff --git a/models/__pycache__/beta_vae.cpython-37.pyc b/models/__pycache__/beta_vae.cpython-37.pyc index 348de533..8cb4bb4a 100644 Binary files a/models/__pycache__/beta_vae.cpython-37.pyc and b/models/__pycache__/beta_vae.cpython-37.pyc differ diff --git a/models/__pycache__/gamma_vae.cpython-36.pyc b/models/__pycache__/gamma_vae.cpython-36.pyc index 9cf988d1..4e7f3b12 100644 Binary files a/models/__pycache__/gamma_vae.cpython-36.pyc and b/models/__pycache__/gamma_vae.cpython-36.pyc differ diff --git a/models/__pycache__/gamma_vae.cpython-37.pyc b/models/__pycache__/gamma_vae.cpython-37.pyc index 1620ad1f..12d00d00 100644 Binary files a/models/__pycache__/gamma_vae.cpython-37.pyc and b/models/__pycache__/gamma_vae.cpython-37.pyc differ diff --git a/models/__pycache__/vanilla_vae.cpython-36.pyc b/models/__pycache__/vanilla_vae.cpython-36.pyc index d7f65211..c4a00a50 100644 Binary files a/models/__pycache__/vanilla_vae.cpython-36.pyc and b/models/__pycache__/vanilla_vae.cpython-36.pyc differ diff --git a/models/__pycache__/vanilla_vae.cpython-37.pyc b/models/__pycache__/vanilla_vae.cpython-37.pyc index b1d49aea..c69a8b0c 100644 Binary files a/models/__pycache__/vanilla_vae.cpython-37.pyc and b/models/__pycache__/vanilla_vae.cpython-37.pyc differ diff --git a/models/__pycache__/wae_mmd.cpython-36.pyc b/models/__pycache__/wae_mmd.cpython-36.pyc index b3711833..85a01724 100644 Binary files a/models/__pycache__/wae_mmd.cpython-36.pyc and b/models/__pycache__/wae_mmd.cpython-36.pyc differ diff --git a/models/beta_vae.py b/models/beta_vae.py index 1f65d6ef..4e9bb3bc 100644 --- a/models/beta_vae.py +++ b/models/beta_vae.py @@ -11,7 +11,8 @@ def __init__(self, in_channels: int, latent_dim: int, hidden_dims: List = None, - beta: int = 1) -> None: + beta: int = 1, + **kwargs) -> None: super(BetaVAE, self).__init__() self.latent_dim = latent_dim diff --git a/models/cvae.py b/models/cvae.py index 2253a379..3bffc64f 100644 --- a/models/cvae.py +++ b/models/cvae.py @@ -12,7 +12,8 @@ def __init__(self, num_classes: int, latent_dim: int, hidden_dims: List = None, - img_size:int = 64) -> None: + img_size:int = 64, + **kwargs) -> None: super(ConditionalVAE, self).__init__() self.latent_dim = 
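The `vae_models` registry above, together with the `**kwargs` catch-all added to each model constructor in this and the following hunks, is what lets run.py build a model straight from a config's `model_params` block: the whole dict can be splatted into the constructor even though it also carries keys such as `name` that the model never reads. A small sketch of that pattern, using the values shown in the WAE config above (the dict literal is illustrative):

```python
from models import vae_models

# Mirrors what run.py does after loading a YAML config.
model_params = {'name': 'WAE_MMD',
                'in_channels': 3,
                'latent_dim': 128,
                'reg_weight': 100,
                'kernel_type': 'imq'}

model_cls = vae_models[model_params['name']]
model = model_cls(**model_params)   # 'name' is absorbed by **kwargs and ignored
```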
diff --git a/models/gamma_vae.py b/models/gamma_vae.py
index bb40b009..ac26ebcc 100644
--- a/models/gamma_vae.py
+++ b/models/gamma_vae.py
@@ -14,7 +14,8 @@ class GammaVAE(BaseVAE):
     def __init__(self,
                  in_channels: int,
                  latent_dim: int,
-                 hidden_dims: List = None) -> None:
+                 hidden_dims: List = None,
+                 **kwargs) -> None:
         super(GammaVAE, self).__init__()
 
         self.latent_dim = latent_dim
diff --git a/models/vanilla_vae.py b/models/vanilla_vae.py
index 6ff242d9..d9bdbe15 100644
--- a/models/vanilla_vae.py
+++ b/models/vanilla_vae.py
@@ -10,7 +10,8 @@ class VanillaVAE(BaseVAE):
     def __init__(self,
                  in_channels: int,
                  latent_dim: int,
-                 hidden_dims: List = None) -> None:
+                 hidden_dims: List = None,
+                 **kwargs) -> None:
         super(VanillaVAE, self).__init__()
 
         self.latent_dim = latent_dim
diff --git a/models/wae_mmd.py b/models/wae_mmd.py
index 16d3f449..a1be794d 100644
--- a/models/wae_mmd.py
+++ b/models/wae_mmd.py
@@ -13,7 +13,8 @@ def __init__(self,
                  hidden_dims: List = None,
                  reg_weight: int = 100,
                  kernel_type: str = 'imq',
-                 latent_var: float = 2.) -> None:
+                 latent_var: float = 2.,
+                 **kwargs) -> None:
         super(WAE_MMD, self).__init__()
 
         self.latent_dim = latent_dim
@@ -133,7 +134,7 @@ def compute_kernel(self,
         x2 = x2.unsqueeze(-3) # Make it into a row tensor
 
         """
-        Usually this is not required, especially in our case,
+        Usually the below lines are not required, especially in our case,
         but this is useful when x1 and x2 have different sizes
         along the 0th dimension.
         """
diff --git a/run.py b/run.py
index 00223c8b..817c4a3a 100644
--- a/run.py
+++ b/run.py
@@ -1,45 +1,49 @@
-import torch
-from models import VanillaVAE, WAE_MMD, CVAE
+import yaml
+import argparse
+from models import *
 from experiment import VAEXperiment
+import torch.backends.cudnn as cudnn
 from pytorch_lightning import Trainer
 from pytorch_lightning.logging import TestTubeLogger
 
+parser = argparse.ArgumentParser(description='Generic runner for VAE models')
+parser.add_argument('--config', '-c',
+                    dest="filename",
+                    metavar='FILE',
+                    help = 'path to the config file',
+                    default='configs/vae.yaml')
+
+args = parser.parse_args()
+with open(args.filename, 'r') as file:
+    try:
+        config = yaml.safe_load(file)
+    except yaml.YAMLError as exc:
+        print(exc)
+
+
 tt_logger = TestTubeLogger(
-    save_dir="logs/",
-    name="WassersteinVAE",
+    save_dir=config['logging_params']['save_dir'],
+    name=config['logging_params']['name'],
     debug=False,
     create_git_tag=False,
 )
-
-class hparams(object):
-    def __init__(self):
-        self.LR = 5e-3
-        self.scheduler_gamma = 0.95
-        self.gpus = 1
-        self.data_path = "../../shared/Data/"
-        self.batch_size = 144
-        self.img_size = 64
-        self.manual_seed = 1256
-
-
-hyper_params = hparams()
-torch.manual_seed = hyper_params.manual_seed
-# model = VanillaVAE(in_channels=3, latent_dim=128)
-# model = CVAE(in_channels=3, latent_dim=128, num_classes=40, img_size=64)
-model = WAE_MMD(in_channels=3, latent_dim=128, reg_weight=100)
+torch.manual_seed = config['logging_params']['manual_seed']
+cudnn.deterministic = True
+model = vae_models[config['model_params']['name']](**config['model_params'])
+# # model = CVAE(in_channels=3, latent_dim=128, num_classes=40, img_size=64)
+# # model = WAE_MMD(in_channels=3, latent_dim=128, reg_weight=100)
 
 experiment = VAEXperiment(model,
-                          hyper_params)
-
+                          config['exp_params'])
 
-runner = Trainer(gpus=hyper_params.gpus,
-                 default_save_path=f"{tt_logger.save_dir}",
+runner = Trainer(default_save_path=f"{tt_logger.save_dir}",
                  min_nb_epochs=1,
-                 max_nb_epochs= 50,
                  logger=tt_logger,
                  log_save_interval=100,
                  train_percent_check=1.,
-                 val_percent_check=1.)
+                 val_percent_check=1.,
+                 **config['trainer_params'])
 
+print(f"======= Training {config['model_params']['name']} =======")
 runner.fit(experiment)
\ No newline at end of file
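With run.py now driven entirely by YAML, a run is launched by pointing it at one of the config files, e.g. `python run.py -c configs/wae_mmd_rbf.yaml` (the `-c`/`--config` flag defaults to `configs/vae.yaml`). A condensed sketch of the equivalent programmatic flow, using only names introduced in this diff and omitting the TestTube logger setup:

```python
import yaml
from pytorch_lightning import Trainer

from models import vae_models
from experiment import VAEXperiment

# Load one of the new configs; the path here is just an example.
with open('configs/wae_mmd_rbf.yaml', 'r') as f:
    config = yaml.safe_load(f)

model = vae_models[config['model_params']['name']](**config['model_params'])
experiment = VAEXperiment(model, config['exp_params'])

# trainer_params supplies gpus, max_nb_epochs and distributed_backend.
runner = Trainer(**config['trainer_params'])
runner.fit(experiment)
```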