Added IWAE
AntixK committed Jan 20, 2020
1 parent 8336b21 commit cf58fa4
Showing 27 changed files with 392 additions and 38 deletions.
15 changes: 9 additions & 6 deletions README.md
@@ -36,11 +36,13 @@ $ python run.py -c configs/<config-file-name.yaml>

----

| Model | Paper |Reconstruction | Samples |
|--------------------------|----------------------------------|---------------|---------|
| VAE |https://arxiv.org/abs/1312.6114 | ![][2] | ![][1] |
| WAE - MMD (RBF Kernel) |https://arxiv.org/abs/1711.01558 | ![][4] | ![][3] |
| WAE - MMD (IMQ Kernel) |https://arxiv.org/abs/1711.01558 | ![][6] | ![][5] |
| Model | Paper |Reconstruction | Samples |
|-----------------------|--------------------------------------------------|---------------|---------|
| VAE |[Link](https://arxiv.org/abs/1312.6114) | ![][2] | ![][1] |
| WAE - MMD (RBF Kernel)|[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] |
| WAE - MMD (IMQ Kernel)|[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
| Beta-VAE |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
| Disentangled Beta-VAE |[Link](https://arxiv.org/abs/1804.03599) | ![][10] | ![][9] |



@@ -55,7 +57,8 @@ $ python run.py -c configs/<config-file-name.yaml>
- [ ] AAE
- [ ] TwoStageVAE
- [ ] VAE-GAN
- [x] HVAE (VAE with Vamp Prior)
- [ ] Vamp VAE
- [ ] HVAE (VAE with Vamp Prior)
- [ ] IWAE
- [ ] VLAE
- [ ] FactorVAE
25 changes: 25 additions & 0 deletions configs/bhvae.yaml
@@ -0,0 +1,25 @@
model_params:
name: 'BetaVAE'
in_channels: 3
latent_dim: 128
loss_type: 'H'
gamma: 1000.0
max_capacity: 25
Capacity_max_iter: 10000

exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.0005
scheduler_gamma: 0.95

trainer_params:
gpus: 1
max_nb_epochs: 50


logging_params:
save_dir: "logs/"
name: "BetaVAE_H"
manual_seed: 1265
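The three top-level sections of each config feed different parts of the pipeline: `model_params` are passed to the model constructor (the `name` key selects the class), `exp_params` configure the data and optimiser side, and `trainer_params` go to the PyTorch Lightning trainer. `run.py` is not part of this diff, so the wiring below is a minimal illustrative sketch rather than the repository's exact code:

```python
# Illustrative sketch (assumed wiring; run.py itself is not shown in this diff).
import yaml
from models import BetaVAE

with open('configs/bhvae.yaml') as f:
    config = yaml.safe_load(f)

# model_params become constructor keyword arguments; extra keys such as 'name'
# are absorbed by the model's **kwargs.
model = BetaVAE(**config['model_params'])

exp_params = config['exp_params']          # data_path, img_size, batch_size, LR, scheduler_gamma
trainer_params = config['trainer_params']  # gpus, max_nb_epochs (PyTorch Lightning)
```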
23 changes: 23 additions & 0 deletions configs/bvae.yaml
@@ -0,0 +1,23 @@
model_params:
name: 'BetaVAE'
in_channels: 3
latent_dim: 128
beta: 10
loss_type: 'B'

exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.0005
scheduler_gamma: 0.95

trainer_params:
gpus: [2]
max_nb_epochs: 50


logging_params:
save_dir: "logs/"
name: "BetaVAE_B"
manual_seed: 1265
21 changes: 21 additions & 0 deletions configs/vampvae.yaml
@@ -0,0 +1,21 @@
model_params:
name: 'VampVAE'
in_channels: 3
latent_dim: 128

exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
scheduler_gamma: 0.95

trainer_params:
gpus: 1
max_nb_epochs: 50


logging_params:
save_dir: "logs/"
name: "VampVAE"
manual_seed: 1265
12 changes: 6 additions & 6 deletions experiment.py
@@ -55,7 +55,7 @@ def validation_end(self, outputs):
return {'val_loss': avg_loss, 'log': tensorboard_logs}

def sample_images(self):
# # samples = self.model.sample(self.params['batch_size'], self.curr_device).cpu()
samples = self.model.sample(self.params['batch_size'], self.curr_device).cpu()
# z = torch.randn(self.params['batch_size'],
# self.model.latent_dim)
#
@@ -64,10 +64,10 @@ def sample_images(self):
#
# samples = self.model.decode(z).cpu()
#
# vutils.save_image(samples.data,
# f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
# normalize=True,
# nrow=int(math.sqrt(self.params['batch_size'])))
vutils.save_image(samples.data,
f"{self.logger.save_dir}/{self.logger.name}/sample_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params['batch_size'])))

# Get sample reconstruction image
test_input, _ = next(iter(self.sample_dataloader))
@@ -78,7 +78,7 @@ def sample_images(self):
f"{self.logger.save_dir}/{self.logger.name}/recons_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params['batch_size'])))
del test_input, recons #, samples, z
del test_input, recons, samples



5 changes: 4 additions & 1 deletion models/__init__.py
@@ -6,6 +6,7 @@
from .cvae import *
from .hvae import *
from .vampvae import *
from .iwae import *

# Aliases
VAE = VanillaVAE
@@ -17,4 +18,6 @@
'ConditionalVAE':ConditionalVAE,
'BetaVAE':BetaVAE,
'GammaVAE':GammaVAE,
'HVAE':HVAE}
'HVAE':HVAE,
'VampVAE':VampVAE,
'IWAE':IWAE}
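The dictionary extended here is the name-to-class registry that lets a config's `name` field (for example `'VampVAE'` or `'IWAE'`) resolve to the right model class without touching the training code. A hedged sketch of how that lookup is typically used; the registry's variable name and the calling code are assumptions, since neither appears in this excerpt:

```python
# Sketch only: assumes the registry dict is exposed as `vae_models`
# and that the caller forwards the remaining model_params as kwargs.
from models import vae_models

model_name = 'IWAE'                      # would normally come from the YAML config
model_cls = vae_models[model_name]       # e.g. IWAE, VampVAE, BetaVAE, ...
model = model_cls(in_channels=3, latent_dim=128)
```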
Binary file modified models/__pycache__/__init__.cpython-36.pyc
Binary file modified models/__pycache__/__init__.cpython-37.pyc
Binary file modified models/__pycache__/base.cpython-37.pyc
Binary file modified models/__pycache__/beta_vae.cpython-36.pyc
Binary file modified models/__pycache__/beta_vae.cpython-37.pyc
Binary file modified models/__pycache__/gamma_vae.cpython-36.pyc
Binary file modified models/__pycache__/gamma_vae.cpython-37.pyc
Binary file modified models/__pycache__/types_.cpython-37.pyc
Binary file modified models/__pycache__/vanilla_vae.cpython-36.pyc
Binary file modified models/__pycache__/vanilla_vae.cpython-37.pyc
Binary file modified models/__pycache__/wae_mmd.cpython-36.pyc
31 changes: 25 additions & 6 deletions models/beta_vae.py
@@ -7,16 +7,26 @@

class BetaVAE(BaseVAE):

num_iter = 0 # Global static variable to keep track of iterations

def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
beta: int = 1,
beta: int = 4,
gamma:float = 1000.,
max_capacity: int = 25,
Capacity_max_iter: int = 1e5,
loss_type:str = 'B',
**kwargs) -> None:
super(BetaVAE, self).__init__()

self.latent_dim = latent_dim
self.beta = beta
self.gamma = gamma
self.loss_type = loss_type
self.C_max = torch.Tensor([max_capacity])  # 1-element tensor holding the capacity value (FloatTensor(max_capacity) would allocate an uninitialised tensor of that size)
self.C_stop_iter = Capacity_max_iter

modules = []
if hidden_dims is None:
@@ -114,29 +124,38 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
def forward(self, input: Tensor, **kwargs) -> Tensor:
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var
return [self.decode(z), input, mu, log_var]

def loss_function(self,
*args,
**kwargs) -> dict:
self.num_iter += 1
recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset

recons_loss = F.mse_loss(recons, input)

kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

loss = recons_loss + self.beta * kld_loss
if self.loss_type == 'B':
loss = recons_loss + self.beta * kld_weight * kld_loss
elif self.loss_type == 'H':
self.C_max = self.C_max.to(input.device)  # .to() also works on CPU, unlike .cuda()
C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
else:
raise ValueError('Undefined loss type.')

return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':kld_loss}

def sample(self, batch_size:int, current_device: int) -> Tensor:
z = torch.randn(batch_size,
self.latent_dim)

if self.on_gpu:
z = z.cuda(current_device)
z = z.cuda(current_device)

samples = self.model.decode(z)
samples = self.decode(z)
return samples
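The `loss_type` switch added to `loss_function` selects between two objectives: `'B'` is the standard beta-VAE loss, with the KL term weighted by `beta` and the minibatch weight `M_N`, while `'H'` is the capacity-annealing variant (Burgess et al., arXiv:1804.03599), where a capacity `C` ramps linearly from 0 to `max_capacity` over `Capacity_max_iter` iterations and the KL divergence is pulled toward `C` rather than toward 0. A self-contained restatement of the schedule, for illustration only:

```python
def beta_vae_objective(recons_loss, kld_loss, kld_weight, num_iter,
                       loss_type='B', beta=4.0, gamma=1000.0,
                       c_max=25.0, c_stop_iter=1e5):
    """Illustrative restatement of the two loss variants in the hunk above
    (plain scalars here; the model itself operates on torch tensors)."""
    if loss_type == 'B':
        # Standard beta-VAE: KL term weighted by beta and the minibatch weight M_N.
        return recons_loss + beta * kld_weight * kld_loss
    if loss_type == 'H':
        # Capacity annealing: C grows linearly from 0 to c_max over c_stop_iter
        # iterations; the KL term is penalised for deviating from C.
        C = min(c_max / c_stop_iter * num_iter, c_max)
        return recons_loss + gamma * kld_weight * abs(kld_loss - C)
    raise ValueError('Undefined loss type.')
```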
5 changes: 2 additions & 3 deletions models/cvae.py
@@ -150,8 +150,7 @@ def sample(self, batch_size:int, current_device: int) -> Tensor:
z = torch.randn(batch_size,
self.latent_dim)

if self.on_gpu:
z = z.cuda(current_device)
z = z.cuda(current_device)

samples = self.model.decode(z)
samples = self.decode(z)
return samples
5 changes: 2 additions & 3 deletions models/gamma_vae.py
@@ -134,8 +134,7 @@ def sample(self, batch_size:int, current_device: int) -> Tensor:
z = torch.randn(batch_size,
self.latent_dim)

if self.on_gpu:
z = z.cuda(current_device)
z = z.cuda(current_device)

samples = self.model.decode(z)
samples = self.decode(z)
return samples
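The commit's headline addition, `models/iwae.py`, is registered in `models/__init__.py` above, but its diff falls outside the files shown here. For orientation only, a minimal sketch of the importance-weighted bound (Burda et al., arXiv:1509.00519) that the name refers to; this is not the repository's implementation, and the tensor shapes below are assumptions:

```python
import math
import torch
import torch.nn.functional as F

def iwae_bound(x, x_recons, mu, log_var, z):
    """Illustrative K-sample importance-weighted bound (not the repo's iwae.py).

    x:           input batch, shape [B, C, H, W]
    x_recons:    K decoded samples per input, shape [B, K, C, H, W]
    mu, log_var: encoder outputs, shape [B, latent_dim]
    z:           K posterior samples per input, shape [B, K, latent_dim]
    """
    K = z.shape[1]
    # log p(x|z): Gaussian reconstruction term, summed over pixels.
    log_px_z = -F.mse_loss(x_recons, x.unsqueeze(1).expand_as(x_recons),
                           reduction='none').flatten(2).sum(-1)           # [B, K]
    # log p(z) and log q(z|x) for a standard-normal prior and diagonal Gaussian
    # posterior (the shared additive constants cancel in the difference).
    log_pz = -0.5 * (z ** 2).sum(-1)
    log_qz_x = -0.5 * (((z - mu.unsqueeze(1)) ** 2) / log_var.exp().unsqueeze(1)
                       + log_var.unsqueeze(1)).sum(-1)
    log_w = log_px_z + log_pz - log_qz_x                                  # [B, K]
    # IWAE bound: log-mean of the K importance weights, averaged over the batch.
    # Maximise this (or minimise its negative as the training loss).
    return (torch.logsumexp(log_w, dim=1) - math.log(K)).mean()
```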