diff --git a/.gitignore b/.gitignore
index b75acb16..3ad9ed51 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,5 @@
-data/
+Data/
 logs/
 VanillaVAE/version_0/
diff --git a/README.md b/README.md
index dc0cf7a8..487f3f70 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ $ cd PyTorch-VAE
 $ python run.py -c configs/<config-file-name.yaml>
 ```
 **Config file template**
+
 ```yaml
 model_params:
   name: ""
@@ -48,10 +49,15 @@ model_params:
     .
     .
 
-exp_params:
+data_params:
   data_path: ""
-  img_size: 64    # Models are designed to work for this size
-  batch_size: 64  # Better to have a square number
+  train_batch_size: 64 # Better to have a square number
+  val_batch_size: 64
+  patch_size: 64 # Models are designed to work for this size
+  num_workers: 4
+
+exp_params:
+  manual_seed: 1265
   LR: 0.005
   weight_decay:
     .   # Other arguments required for training, like scheduler etc.
@@ -60,7 +66,7 @@ exp_params:
 
 trainer_params:
   gpus: 1
-  max_nb_epochs: 50
+  max_epochs: 100
   gradient_clip_val: 1.5
     .
     .
@@ -69,15 +75,17 @@ trainer_params:
 logging_params:
   save_dir: "logs/"
   name: ""
-  manual_seed:
 ```
 
 **View TensorBoard Logs**
 ```
 $ cd logs/<experiment name>/version_<version number>
-$ tensorboard --logdir tf
+$ tensorboard --logdir .
 ```
 
+**Note:** The default dataset is CelebA. However, there have been many issues with downloading the dataset from Google Drive (owing to some file-structure changes). The recommendation is therefore to download the [file](https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing) from Google Drive directly and extract it to a path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`, but you can change it according to your preference.
+
+
 ----

 Results
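Since the README note above tells users to download the CelebA archive from Google Drive and extract it by hand, here is a minimal sketch of that step (not part of the diff). It assumes the downloaded file was saved locally as `celeba.zip` (a hypothetical filename) and that the archive unpacks to `celeba/img_align_celeba/`, which produces the default `Data/celeba/img_align_celeba` layout the updated configs expect; `zipfile` and `pathlib` are both standard-library modules, and the new `dataset.py` below already imports `zipfile`.

```python
# Minimal sketch of the manual download-and-extract step from the README note.
# Assumes the Google Drive file was saved as "celeba.zip" (hypothetical name)
# and that it unpacks to celeba/img_align_celeba/, yielding the default
# Data/celeba/img_align_celeba/ layout assumed by the config files.
import zipfile
from pathlib import Path

data_root = Path("Data")
data_root.mkdir(parents=True, exist_ok=True)

with zipfile.ZipFile("celeba.zip") as archive:
    archive.extractall(data_root)
```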
diff --git a/configs/bbvae.yaml b/configs/bbvae.yaml
index 9f9ef240..b1baba39 100644
--- a/configs/bbvae.yaml
+++ b/configs/bbvae.yaml
@@ -7,21 +7,25 @@ model_params:
   max_capacity: 25
   Capacity_max_iter: 10000
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
-  LR: 0.0005
+  LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
-  name: "BetaVAE_B"
   manual_seed: 1265
+  name: 'BetaVAE'
diff --git a/configs/betatc_vae.yaml b/configs/betatc_vae.yaml
index 67b79f2f..e1072d3e 100644
--- a/configs/betatc_vae.yaml
+++ b/configs/betatc_vae.yaml
@@ -7,21 +7,25 @@ model_params:
   beta: 6.
   gamma: 1.
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/momo/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
-  LR: 0.001
+  LR: 0.005
   weight_decay: 0.0
-#  scheduler_gamma: 0.99
+  scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
-  name: "BetaTCVAE"
-  manual_seed: 1265
+  name: 'BetaTCVAE'
diff --git a/configs/bhvae.yaml b/configs/bhvae.yaml
index 65adcc7b..e9777abe 100644
--- a/configs/bhvae.yaml
+++ b/configs/bhvae.yaml
@@ -5,21 +5,25 @@ model_params:
   loss_type: 'H'
   beta: 10.
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
-  LR: 0.0005
+  LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
-  name: "BetaVAE_H"
-  manual_seed: 1265
+  name: 'BetaVAE'
diff --git a/configs/cat_vae.yaml b/configs/cat_vae.yaml
index d67eaae8..625d2344 100644
--- a/configs/cat_vae.yaml
+++ b/configs/cat_vae.yaml
@@ -8,21 +8,25 @@ model_params:
   anneal_interval: 100
   alpha: 1.0
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
   gpus: [1]
-  max_nb_epochs: 50
-  max_epochs: 50
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "CategoricalVAE"
-  manual_seed: 1265
diff --git a/configs/cvae.yaml b/configs/cvae.yaml
index 9b3ab62b..df950a27 100644
--- a/configs/cvae.yaml
+++ b/configs/cvae.yaml
@@ -4,21 +4,25 @@ model_params:
   num_classes: 40
   latent_dim: 128
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
-  name: "ConditionalVAE"
-  manual_seed: 1265
+  name: "ConditionalVAE"
\ No newline at end of file
diff --git a/configs/dfc_vae.yaml b/configs/dfc_vae.yaml
index 755bf8a8..a02f85f0 100644
--- a/configs/dfc_vae.yaml
+++ b/configs/dfc_vae.yaml
@@ -3,21 +3,25 @@ model_params:
   in_channels: 3
   latent_dim: 128
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "DFCVAE"
-  manual_seed: 1265
diff --git a/configs/dip_vae.yaml b/configs/dip_vae.yaml
index f13593ce..e4308f8e 100644
--- a/configs/dip_vae.yaml
+++ b/configs/dip_vae.yaml
@@ -6,19 +6,24 @@ model_params:
   lambda_offdiag: 0.1
 
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/momo/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.001
   weight_decay: 0.0
   scheduler_gamma: 0.97
+  kld_weight: 1
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
diff --git a/configs/factorvae.yaml b/configs/factorvae.yaml
index 29dd0878..fe0bf09c 100644
--- a/configs/factorvae.yaml
+++ b/configs/factorvae.yaml
@@ -4,25 +4,31 @@ model_params:
   latent_dim: 128
   gamma: 6.4
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
   submodel: 'discriminator'
   retain_first_backpass: True
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
-  scheduler_gamma: 0.95
   LR_2: 0.005
   scheduler_gamma_2: 0.95
+  scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: [3]
-  max_nb_epochs: 30
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
-  name: "FactorVAE"
-  manual_seed: 1265
+  name: "FactorVAE"
+
+
diff --git a/configs/gammavae.yaml b/configs/gammavae.yaml
index 465d7d97..8d449ef7 100644
--- a/configs/gammavae.yaml
+++ b/configs/gammavae.yaml
@@ -6,22 +6,27 @@ model_params:
   prior_shape: 2.
   prior_rate: 1.
 
+
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.003
-  weight_decay: 0.0005
-
+  weight_decay: 0.00005
+  scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
+  gpus: [1]
+  max_epochs: 10
   gradient_clip_val: 0.8
-  max_epochs: 50
 
 logging_params:
   save_dir: "logs/"
   name: "GammaVAE"
-  manual_seed: 1265
diff --git a/configs/hvae.yaml b/configs/hvae.yaml
index 6f331178..9a26ef11 100644
--- a/configs/hvae.yaml
+++ b/configs/hvae.yaml
@@ -5,21 +5,25 @@ model_params:
   latent2_dim: 64
   pseudo_input_size: 128
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "VampVAE"
-  manual_seed: 1265
diff --git a/configs/infovae.yaml b/configs/infovae.yaml
index e2b700c2..2a1e3118 100644
--- a/configs/infovae.yaml
+++ b/configs/infovae.yaml
@@ -7,19 +7,24 @@ model_params:
   alpha: -9.0  # KLD weight
   beta: 10.5   # Reconstruction weight
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: [3]
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
   gradient_clip_val: 0.8
 
 logging_params:
diff --git a/configs/iwae.yaml b/configs/iwae.yaml
index efad623e..efcc0b5b 100644
--- a/configs/iwae.yaml
+++ b/configs/iwae.yaml
@@ -4,21 +4,25 @@ model_params:
   latent_dim: 128
   num_samples: 5
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.007
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: [3]
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "IWAE"
-  manual_seed: 1265
diff --git a/configs/joint_vae.yaml b/configs/joint_vae.yaml
index e20b81f5..c9ed632b 100644
--- a/configs/joint_vae.yaml
+++ b/configs/joint_vae.yaml
@@ -16,21 +16,26 @@ model_params:
   anneal_interval: 100
   alpha: 10.0
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
   gpus: [1]
-  max_nb_epochs: 50
-  max_epochs: 50
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "JointVAE"
-  manual_seed: 1265
+
diff --git a/configs/logcosh_vae.yaml b/configs/logcosh_vae.yaml
index 468d7a13..ed91b273 100644
--- a/configs/logcosh_vae.yaml
+++ b/configs/logcosh_vae.yaml
@@ -5,21 +5,26 @@ model_params:
   alpha: 10.0
   beta: 1.0
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/momo/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.97
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "LogCoshVAE"
-  manual_seed: 1265
+
diff --git a/configs/lvae.yaml b/configs/lvae.yaml
index 1983ee04..df73baa6 100644
--- a/configs/lvae.yaml
+++ b/configs/lvae.yaml
@@ -4,22 +4,25 @@ model_params:
   latent_dims: [4,8,16,32,128]
   hidden_dims: [32, 64,128, 256, 512]
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/momo/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: [2]
-  max_nb_epochs: 50
-  max_epochs: 30
-  gradient_clip_val: .5
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "LVAE"
-  manual_seed: 1265
diff --git a/configs/miwae.yaml b/configs/miwae.yaml
index c2c95413..5c0544b7 100644
--- a/configs/miwae.yaml
+++ b/configs/miwae.yaml
@@ -5,21 +5,26 @@ model_params:
   num_samples: 5
   num_estimates: 3
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/momo/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 30
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "MIWAE"
-  manual_seed: 1265
+
diff --git a/configs/mssim_vae.yaml b/configs/mssim_vae.yaml
index 376f6e2b..61e6354c 100644
--- a/configs/mssim_vae.yaml
+++ b/configs/mssim_vae.yaml
@@ -3,21 +3,25 @@ model_params:
   in_channels: 3
   latent_dim: 128
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 30
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "MSSIMVAE"
-  manual_seed: 1265
diff --git a/configs/swae.yaml b/configs/swae.yaml
index 24b7b635..d6427019 100644
--- a/configs/swae.yaml
+++ b/configs/swae.yaml
@@ -7,24 +7,29 @@ model_params:
   num_projections: 200
   projection_dist: "normal" #"cauchy"
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/momo/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "SWAE"
-  manual_seed: 1265
+
diff --git a/configs/vae.yaml b/configs/vae.yaml
index 77259b36..385ddc17 100644
--- a/configs/vae.yaml
+++ b/configs/vae.yaml
@@ -3,21 +3,27 @@ model_params:
   in_channels: 3
   latent_dim: 128
 
+
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: 1
-  max_nb_epochs: 50
-  max_epochs: 30
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "VanillaVAE"
-  manual_seed: 1265
+
diff --git a/configs/vq_vae.yaml b/configs/vq_vae.yaml
index 43163c17..505425f7 100644
--- a/configs/vq_vae.yaml
+++ b/configs/vq_vae.yaml
@@ -6,21 +6,25 @@ model_params:
   img_size: 64
   beta: 0.25
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/momo/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
-  LR: 0.001
+  LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.0
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: [2]
-  max_nb_epochs: 50
-  max_epochs: 30
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
-  name: "VQVAE"
-  manual_seed: 1265
+  name: 'VQVAE'
diff --git a/configs/wae_mmd_imq.yaml b/configs/wae_mmd_imq.yaml
index fbf077fb..c9b59409 100644
--- a/configs/wae_mmd_imq.yaml
+++ b/configs/wae_mmd_imq.yaml
@@ -5,24 +5,29 @@ model_params:
   reg_weight: 100
   kernel_type: 'imq'
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: [2]
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "WassersteinVAE_IMQ"
-  manual_seed: 1265
+
diff --git a/configs/wae_mmd_rbf.yaml b/configs/wae_mmd_rbf.yaml
index 62a46ec1..e0cb4ec5 100644
--- a/configs/wae_mmd_rbf.yaml
+++ b/configs/wae_mmd_rbf.yaml
@@ -5,24 +5,29 @@ model_params:
   reg_weight: 5000
   kernel_type: 'rbf'
 
+data_params:
+  data_path: "Data/"
+  train_batch_size: 64
+  val_batch_size: 64
+  patch_size: 64
+  num_workers: 4
+
+
 exp_params:
-  dataset: celeba
-  data_path: "../../shared/Data/"
-  img_size: 64
-  batch_size: 144 # Better to have a square number
   LR: 0.005
   weight_decay: 0.0
   scheduler_gamma: 0.95
+  kld_weight: 0.00025
+  manual_seed: 1265
 
 trainer_params:
-  gpus: [2]
-  max_nb_epochs: 50
-  max_epochs: 50
+  gpus: [1]
+  max_epochs: 10
 
 logging_params:
   save_dir: "logs/"
   name: "WassersteinVAE_RBF"
-  manual_seed: 1265
+
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 00000000..9a56fe3e
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,188 @@
+import os
+import torch
+from torch import Tensor
+from pathlib import Path
+from typing import List, Optional, Sequence, Union, Any, Callable
+from torchvision.datasets.folder import default_loader
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from torchvision.datasets import CelebA
+import zipfile
+
+
+# class MyDataset(Dataset):
+#     def __init__(self):
+#         pass
+
+#     def __len__(self):
+#         pass
+#     def __getitem__(self, idx):
+#         pass
+
+class MyCelebA(CelebA):
+    """
+    Download and Extract
+    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
+    """
+
+    def _check_integrity(self) -> bool:
+        return True
+
+class Food101(Dataset):
+    """
+    URL : https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/
+    """
+    def __init__(self):
+        pass
+
+    def __len__(self):
+        pass
+
+    def __getitem__(self, idx):
+        pass
+
+class OxfordPets(Dataset):
+    """
+    URL = https://www.robots.ox.ac.uk/~vgg/data/pets/
+    """
+    def __init__(self,
+                 data_path: str,
+                 split: str,
+                 transform: Callable,
+                 **kwargs):
+        self.data_dir = Path(data_path)
+        self.transforms = transform
+        imgs = sorted([f for f in self.data_dir.iterdir() if f.suffix == '.jpg'])
+
+        self.imgs = imgs[:int(len(imgs) * 0.75)] if split == "train" else imgs[int(len(imgs) * 0.75):]
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def __getitem__(self, idx):
+        img = default_loader(self.imgs[idx])
+
+        if self.transforms is not None:
+            img = self.transforms(img)
+
+        return img, 0.0  # dummy data to prevent breaking
+
+class VAEDataset(LightningDataModule):
+    """
+    PyTorch Lightning data module
+
+    Args:
+        data_path: root directory of your dataset.
+        train_batch_size: the batch size to use during training.
+        val_batch_size: the batch size to use during validation.
+        patch_size: the size of the crop to take from the original images.
+        num_workers: the number of parallel workers to create to load data
+            items (see PyTorch's Dataloader documentation for more details).
+        pin_memory: whether prepared items should be loaded into pinned memory
+            or not. This can improve performance on GPUs.
+    """
+
+    def __init__(
+        self,
+        data_path: str,
+        train_batch_size: int = 8,
+        val_batch_size: int = 8,
+        patch_size: Union[int, Sequence[int]] = (256, 256),
+        num_workers: int = 0,
+        pin_memory: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.data_dir = data_path
+        self.train_batch_size = train_batch_size
+        self.val_batch_size = val_batch_size
+        self.patch_size = patch_size
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+
+    def setup(self, stage: Optional[str] = None) -> None:
+#       =========================  OxfordPets Dataset  =========================
+
+#       train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+#                                              transforms.CenterCrop(self.patch_size),
+#                                             # transforms.Resize(self.patch_size),
+#                                              transforms.ToTensor(),
+#                                              transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
+
+#       val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+#                                            transforms.CenterCrop(self.patch_size),
+#                                           # transforms.Resize(self.patch_size),
+#                                            transforms.ToTensor(),
+#                                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
+
+#       self.train_dataset = OxfordPets(
+#           self.data_dir,
+#           split='train',
+#           transform=train_transforms,
+#       )
+
+#       self.val_dataset = OxfordPets(
+#           self.data_dir,
+#           split='val',
+#           transform=val_transforms,
+#       )
+
+#       =========================  CelebA Dataset  =========================
+
+        train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+                                               transforms.CenterCrop(148),
+                                               transforms.Resize(self.patch_size),
+                                               transforms.ToTensor(),])
+
+        val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+                                             transforms.CenterCrop(148),
+                                             transforms.Resize(self.patch_size),
+                                             transforms.ToTensor(),])
+
+        self.train_dataset = MyCelebA(
+            self.data_dir,
+            split='train',
+            transform=train_transforms,
+            download=False,
+        )
+
+        # Replace CelebA with your dataset
+        self.val_dataset = MyCelebA(
+            self.data_dir,
+            split='test',
+            transform=val_transforms,
+            download=False,
+        )
+#       ===============================================================
+
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.train_batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+            pin_memory=self.pin_memory,
+        )
+
+    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.val_batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+            pin_memory=self.pin_memory,
+        )
+
+    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
+        return DataLoader(
+            self.val_dataset,
+            batch_size=144,
+            num_workers=self.num_workers,
+            shuffle=True,
+            pin_memory=self.pin_memory,
+        )
\ No newline at end of file
diff --git a/experiment.py b/experiment.py
index c5b1ba5f..8763a006 100644
--- a/experiment.py
+++ b/experiment.py
@@ -1,3 +1,4 @@
+import os
 import math
 import torch
 from torch import optim
@@ -36,13 +37,13 @@ def training_step(self, batch, batch_idx, optimizer_idx = 0):
         results = self.forward(real_img, labels = labels)
         train_loss = self.model.loss_function(*results,
-                                              M_N = self.params['batch_size']/ self.num_train_imgs,
+                                              M_N = self.params['kld_weight'], #real_img.shape[0]/ self.num_train_imgs,
                                               optimizer_idx=optimizer_idx,
                                               batch_idx = batch_idx)
 
-        self.logger.experiment.log({key: val.item() for key, val in train_loss.items()})
+        self.log_dict({key: val.item() for key, val in train_loss.items()}, sync_dist=True)
 
-        return train_loss
+        return train_loss['loss']
 
     def validation_step(self, batch, batch_idx, optimizer_idx = 0):
         real_img, labels = batch
@@ -50,52 +51,44 @@ def validation_step(self, batch, batch_idx, optimizer_idx = 0):
         results = self.forward(real_img, labels = labels)
         val_loss = self.model.loss_function(*results,
-                                            M_N = self.params['batch_size']/ self.num_val_imgs,
+                                            M_N = 1.0, #real_img.shape[0]/ self.num_val_imgs,
                                             optimizer_idx = optimizer_idx,
                                             batch_idx = batch_idx)
 
-        return val_loss
+        self.log_dict({f"val_{key}": val.item() for key, val in val_loss.items()}, sync_dist=True)
 
-    def validation_end(self, outputs):
-        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
-        tensorboard_logs = {'avg_val_loss': avg_loss}
+
+    def on_validation_end(self) -> None:
         self.sample_images()
-        return {'val_loss': avg_loss, 'log': tensorboard_logs}
-
+
     def sample_images(self):
-        # Get sample reconstruction image
-        test_input, test_label = next(iter(self.sample_dataloader))
+        # Get sample reconstruction image
+        test_input, test_label = next(iter(self.trainer.datamodule.test_dataloader()))
         test_input = test_input.to(self.curr_device)
         test_label = test_label.to(self.curr_device)
+
+#       test_input, test_label = batch
         recons = self.model.generate(test_input, labels = test_label)
         vutils.save_image(recons.data,
-                          f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
-                          f"recons_{self.logger.name}_{self.current_epoch}.png",
+                          os.path.join(self.logger.log_dir,
+                                       "Reconstructions",
+                                       f"recons_{self.logger.name}_Epoch_{self.current_epoch}.png"),
                           normalize=True,
                           nrow=12)
 
-        # vutils.save_image(test_input.data,
-        #                   f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
-        #                   f"real_img_{self.logger.name}_{self.current_epoch}.png",
-        #                   normalize=True,
-        #                   nrow=12)
-
         try:
             samples = self.model.sample(144,
                                         self.curr_device,
                                         labels = test_label)
             vutils.save_image(samples.cpu().data,
-                              f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
-                              f"{self.logger.name}_{self.current_epoch}.png",
+                              os.path.join(self.logger.log_dir,
+                                           "Samples",
+                                           f"{self.logger.name}_Epoch_{self.current_epoch}.png"),
                               normalize=True,
                               nrow=12)
-        except:
+        except Warning:
             pass
-
-        del test_input, recons #, samples
-
-
     def configure_optimizers(self):
 
         optims = []
@@ -131,55 +124,3 @@ def configure_optimizers(self):
             return optims, scheds
         except:
             return optims
-
-    @data_loader
-    def train_dataloader(self):
-        transform = self.data_transforms()
-
-        if self.params['dataset'] == 'celeba':
-            dataset = CelebA(root = self.params['data_path'],
-                             split = "train",
-                             transform=transform,
-                             download=False)
-        else:
-            raise ValueError('Undefined dataset type')
-
-        self.num_train_imgs = len(dataset)
-        return DataLoader(dataset,
-                          batch_size= self.params['batch_size'],
-                          shuffle = True,
-                          drop_last=True)
-
-    @data_loader
-    def val_dataloader(self):
-        transform = self.data_transforms()
-
-        if self.params['dataset'] == 'celeba':
-            self.sample_dataloader = DataLoader(CelebA(root = self.params['data_path'],
-                                                       split = "test",
-                                                       transform=transform,
-                                                       download=False),
-                                                batch_size= 144,
-                                                shuffle = True,
-                                                drop_last=True)
-            self.num_val_imgs = len(self.sample_dataloader)
-        else:
-            raise ValueError('Undefined dataset type')
-
-        return self.sample_dataloader
-
-    def data_transforms(self):
-
-        SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
-        SetScale = transforms.Lambda(lambda X: X/X.sum(0).expand_as(X))
-
-        if self.params['dataset'] == 'celeba':
-            transform = transforms.Compose([transforms.RandomHorizontalFlip(),
-                                            transforms.CenterCrop(148),
-                                            transforms.Resize(self.params['img_size']),
-                                            transforms.ToTensor(),
-                                            SetRange])
-        else:
-            raise ValueError('Undefined dataset type')
-        return transform
-
diff --git a/models/base.py b/models/base.py
index ddca2710..86be9aba 100644
--- a/models/base.py
+++ b/models/base.py
@@ -14,7 +14,7 @@ def decode(self, input: Tensor) -> Any:
         raise NotImplementedError
 
     def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
-        raise RuntimeWarning()
+        raise NotImplementedError
 
     def generate(self, x: Tensor, **kwargs) -> Tensor:
         raise NotImplementedError
diff --git a/models/dip_vae.py b/models/dip_vae.py
index a19c9e45..e88cf067 100644
--- a/models/dip_vae.py
+++ b/models/dip_vae.py
@@ -137,7 +137,7 @@ def loss_function(self,
         mu = args[2]
         log_var = args[3]
 
-        kld_weight = 1 #* kwargs['M_N'] # Account for the minibatch samples from the dataset
+        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
 
         recons_loss =F.mse_loss(recons, input, reduction='sum')
diff --git a/models/vanilla_vae.py b/models/vanilla_vae.py
index 568b63b4..768d25b5 100644
--- a/models/vanilla_vae.py
+++ b/models/vanilla_vae.py
@@ -143,7 +143,7 @@ def loss_function(self,
         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
 
         loss = recons_loss + kld_weight * kld_loss
-        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
+        return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()}
 
     def sample(self,
                num_samples:int,
diff --git a/requirements.txt b/requirements.txt
index 8cf7b1ba..31a2ab04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,6 @@
-pytorch-lightning==0.6.0
-PyYAML==5.1.2
+pytorch-lightning==1.5.6
+PyYAML==6.0
 tensorboard==2.1.0
-tensorboardX==1.6
-terminado==0.8.1
-test-tube==0.7.0
-torch==1.2.0
-torchfile==0.1.0
-torchnet==0.0.4
+torch>=1.6.1
 torchsummary==1.5.1
-torchvision==0.4.0
+torchvision>=0.11.2
diff --git a/run.py b/run.py
index d95dd657..370debee 100644
--- a/run.py
+++ b/run.py
@@ -1,12 +1,17 @@
+import os
 import yaml
 import argparse
 import numpy as np
-
+from pathlib import Path
 from models import *
 from experiment import VAEXperiment
 import torch.backends.cudnn as cudnn
 from pytorch_lightning import Trainer
-from pytorch_lightning.logging import TestTubeLogger
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from dataset import VAEDataset
+from pytorch_lightning.plugins import DDPPlugin
 
 
 parser = argparse.ArgumentParser(description='Generic runner for VAE models')
@@ -24,32 +29,33 @@
     print(exc)
 
 
-tt_logger = TestTubeLogger(
-    save_dir=config['logging_params']['save_dir'],
-    name=config['logging_params']['name'],
-    debug=False,
-    create_git_tag=False,
-)
+
+tb_logger = TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
+                              name=config['model_params']['name'],)
 
 # For reproducibility
-torch.manual_seed(config['logging_params']['manual_seed'])
-np.random.seed(config['logging_params']['manual_seed'])
-cudnn.deterministic = True
-cudnn.benchmark = False
+seed_everything(config['exp_params']['manual_seed'], True)
 
 model = vae_models[config['model_params']['name']](**config['model_params'])
 experiment = VAEXperiment(model,
                           config['exp_params'])
 
-runner = Trainer(default_save_path=f"{tt_logger.save_dir}",
-                 min_nb_epochs=1,
-                 logger=tt_logger,
-                 log_save_interval=100,
-                 train_percent_check=1.,
-                 val_percent_check=1.,
-                 num_sanity_val_steps=5,
-                 early_stop_callback = False,
+data = VAEDataset(**config["data_params"], pin_memory=len(config['trainer_params']['gpus']) != 0)
+
+runner = Trainer(logger=tb_logger,
+                 callbacks=[
+                     LearningRateMonitor(),
+                     ModelCheckpoint(save_top_k=2,
+                                     dirpath=os.path.join(tb_logger.log_dir, "checkpoints"),
+                                     monitor="val_loss",
+                                     save_last=True),
+                 ],
+                 strategy=DDPPlugin(find_unused_parameters=False),
                  **config['trainer_params'])
 
+
+Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
+Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)
+
+
 print(f"======= Training {config['model_params']['name']} =======")
-runner.fit(experiment)
\ No newline at end of file
+runner.fit(experiment, datamodule=data)
\ No newline at end of file
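As a quick sanity check of the new data pipeline before a full training run, the datamodule can be exercised on its own. The sketch below (not part of the diff) constructs `VAEDataset` with the same `data_params` the updated configs use and pulls a single training batch; it assumes CelebA has already been extracted under `Data/` as described in the README note.

```python
# Smoke test for the new LightningDataModule: build it the same way run.py
# does from data_params, then fetch one training batch. Assumes CelebA is
# already extracted under Data/ (see the README note above).
from dataset import VAEDataset

data = VAEDataset(data_path="Data/",
                  train_batch_size=64,
                  val_batch_size=64,
                  patch_size=64,
                  num_workers=4)
data.setup()

imgs, labels = next(iter(data.train_dataloader()))
print(imgs.shape)  # expected: torch.Size([64, 3, 64, 64])
```

If the batch shape checks out, `python run.py -c configs/vae.yaml` should train end to end on the new PyTorch Lightning 1.5.6 stack.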