
Commit

Added DIP VAE
AntixK committed Feb 26, 2020
1 parent 3c8fd61 commit a36d705
Showing 8 changed files with 287 additions and 28 deletions.
14 changes: 8 additions & 6 deletions README.md
@@ -103,6 +103,7 @@ $ tensorboard --logdir tf
| LogCosh VAE ([Code][logcoshvae_code], [Config][logcoshvae_config]) |[Link](https://openreview.net/forum?id=rkglvsC9Ym)| ![][26] | ![][25] |
| SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] |
| VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937) | ![][31] | **N/A** |
| DIP VAE ([Code][dipvae_code], [Config][dipvae_config]) |[Link](https://arxiv.org/abs/1711.00848) | ![][36] | ![][35] |


<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
@@ -115,20 +116,19 @@ $ tensorboard --logdir tf
- [x] IWAE
- [x] MIWAE
- [x] WAE-MMD
- [x] Conditional VAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [x] Joint VAE
- [x] Disentangled beta-VAE
- [x] InfoVAE
- [x] LogCosh VAE
- [x] SWAE
- [x] VQVAE
- [ ] Beta TC-VAE (in progress)
- [x] Beta TC-VAE
- [ ] DIP VAE (In progress)
- [ ] Ladder VAE (Doesn't work well)
- [ ] Gamma VAE (Doesn't work well)
- [ ] Vamp VAE (Doesn't work well)
- [ ] PixelVAE



### Contributing
@@ -178,6 +178,7 @@ doesn't seem to work well.
[catvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cat_vae.py
[infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py
[vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
[dipvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dip_vae.py

[vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml
[cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml
@@ -196,6 +197,7 @@ doesn't seem to work well.
[catvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cat_vae.yaml
[infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml
[vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml
[dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml

[1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png
[2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png
@@ -228,8 +230,8 @@ doesn't seem to work well.
[29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png
[30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png
[31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png
[33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_20.png
[34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_20.png
[33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_49.png
[34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png

[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
4 changes: 2 additions & 2 deletions configs/betatc_vae.yaml
@@ -1,7 +1,7 @@
model_params:
name: 'BetaTCVAE'
in_channels: 3
latent_dim: 128
latent_dim: 10
anneal_steps: 10000
alpha: 1.
beta: 6.
@@ -14,7 +14,7 @@ exp_params:
batch_size: 144 # Better to have a square number
LR: 0.001
weight_decay: 0.0
scheduler_gamma: 0.99
# scheduler_gamma: 0.99

trainer_params:
gpus: 1
4 changes: 1 addition & 3 deletions configs/bhvae.yaml
@@ -3,9 +3,7 @@ model_params:
in_channels: 3
latent_dim: 128
loss_type: 'H'
gamma: 1000.0
max_capacity: 25
Capacity_max_iter: 10000
beta: 10.

exp_params:
dataset: celeba
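For context on the bhvae.yaml change above: the repository's β-VAE config selects one of two objectives via `loss_type`. The 'H' variant (Higgins et al.) only needs a fixed weight `beta` on the KL term, which is why the capacity-annealing parameters (`gamma`, `max_capacity`, `Capacity_max_iter`) used by the 'B' variant (Burgess et al.) can be dropped here. A minimal, hypothetical sketch of the difference, assuming the usual formulations from the two papers rather than the repository's exact code:

```python
import torch

def beta_vae_regulariser(kld_loss: torch.Tensor,
                         loss_type: str = 'H',
                         beta: float = 10.0,
                         gamma: float = 1000.0,
                         C_max: float = 25.0,
                         C_stop_iter: int = 10000,
                         num_iter: int = 0) -> torch.Tensor:
    """Sketch of the two beta-VAE KL penalties (hypothetical helper).

    'H': a constant weight beta on the KL term.
    'B': a capacity C annealed from 0 to C_max over C_stop_iter steps;
         the KL is pulled towards C and the gap weighted by gamma.
    """
    if loss_type == 'H':
        return beta * kld_loss
    # 'B' variant: linearly anneal the capacity target
    C = min(C_max / C_stop_iter * num_iter, C_max)
    return gamma * (kld_loss - C).abs()
```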
26 changes: 26 additions & 0 deletions configs/dip_vae.yaml
@@ -0,0 +1,26 @@
model_params:
name: 'DIPVAE'
in_channels: 3
latent_dim: 128
lambda_diag: 10.
lambda_offdiag: 5.


exp_params:
dataset: celeba
data_path: "../../shared/momo/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.001
weight_decay: 0.0
# scheduler_gamma: 0.99

trainer_params:
gpus: 1
max_nb_epochs: 50
max_epochs: 50

logging_params:
save_dir: "logs/"
name: "DIPVAE"
manual_seed: 1265
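The new DIP VAE model file itself (models/dip_vae.py) is not shown in this excerpt, but the two hyper-parameters introduced in this config, `lambda_diag` and `lambda_offdiag`, weight the DIP-VAE regulariser from the linked paper (arXiv:1711.00848), which pushes the covariance of the inferred latents towards the identity matrix. A rough DIP-VAE-I style sketch of that penalty (a hypothetical helper, not the repository's implementation):

```python
import torch

def dip_vae_regulariser(mu: torch.Tensor,
                        lambda_diag: float = 10.0,
                        lambda_offdiag: float = 5.0) -> torch.Tensor:
    """Sketch of a DIP-VAE-I penalty on Cov[mu].

    mu: [B x D] batch of posterior means.
    Off-diagonal covariance entries are pushed towards 0,
    diagonal entries towards 1.
    """
    centred = mu - mu.mean(dim=0, keepdim=True)
    cov = centred.t().matmul(centred) / (mu.size(0) - 1)   # [D x D]
    diag = torch.diagonal(cov)
    offdiag = cov - torch.diag_embed(diag)
    return (lambda_offdiag * (offdiag ** 2).sum()
            + lambda_diag * ((diag - 1.0) ** 2).sum())
```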
2 changes: 2 additions & 0 deletions models/__init__.py
@@ -20,6 +20,7 @@
from .miwae import *
from .vq_vae import *
from .betatc_vae import *
from .dip_vae import *


# Aliases
@@ -35,6 +36,7 @@
'MIWAE':MIWAE,
'VQVAE':VQVAE,
'DFCVAE':DFCVAE,
'DIPVAE':DIPVAE,
'BetaVAE':BetaVAE,
'InfoVAE':InfoVAE,
'WAE_MMD':WAE_MMD,
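The alias added to models/__init__.py is what lets the YAML config above select the model by name. A hedged usage sketch, assuming the alias dictionary is exported as `vae_models` and the training driver instantiates models this way (the driver script is not part of this diff):

```python
import yaml
from models import vae_models  # alias dict shown above (name assumed)

# Hypothetical driver snippet: resolve the class from the config name.
with open('configs/dip_vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

model_cls = vae_models[config['model_params']['name']]   # -> DIPVAE
model = model_cls(**{k: v for k, v in config['model_params'].items()
                     if k != 'name'})
```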
35 changes: 18 additions & 17 deletions models/betatc_vae.py
@@ -29,7 +29,7 @@ def __init__(self,

modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]
hidden_dims = [32, 32, 32, 32]

# Build Encoder
for h_dim in hidden_dims:
@@ -42,14 +42,16 @@ def __init__(self,
in_channels = h_dim

self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)

self.fc = nn.Linear(hidden_dims[-1]*16, 256)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_var = nn.Linear(256, latent_dim)


# Build Decoder
modules = []

self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
self.decoder_input = nn.Linear(latent_dim, 256 * 2)

hidden_dims.reverse()

@@ -65,8 +67,6 @@ def __init__(self,
nn.LeakyReLU())
)



self.decoder = nn.Sequential(*modules)

self.final_layer = nn.Sequential(
@@ -89,8 +89,9 @@ def encode(self, input: Tensor) -> List[Tensor]:
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)

result = torch.flatten(result, start_dim=1)
result = self.fc(result)
# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.fc_mu(result)
@@ -106,7 +107,7 @@ def decode(self, z: Tensor) -> Tensor:
:return: (Tensor) [B x C x H x W]
"""
result = self.decoder_input(z)
result = result.view(-1, 512, 2, 2)
result = result.view(-1, 32, 4, 4)
result = self.decoder(result)
result = self.final_layer(result)
return result
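A quick sanity check on the new shapes in this file (a sketch, assuming 64×64 inputs as in the CelebA configs): four stride-2 convolutions with 32 channels each reduce 64×64 to 4×4, so the flattened encoder output has 32·4·4 = 512 features, matching `hidden_dims[-1]*16`, and the decoder input of 256·2 = 512 features reshapes back to [B, 32, 4, 4]:

```python
# Shape bookkeeping for the new BetaTCVAE encoder/decoder
# (sketch, assuming 64x64 inputs as in the CelebA configs).
img_size, hidden_dims = 64, [32, 32, 32, 32]
feat = img_size // 2 ** len(hidden_dims)      # 4: each conv halves H and W
flat = hidden_dims[-1] * feat * feat          # 32 * 4 * 4 = 512
assert flat == hidden_dims[-1] * 16           # matches nn.Linear(hidden_dims[-1]*16, 256)
assert 256 * 2 == flat                        # decoder_input maps latent -> 512 features
# which the decoder then reshapes with result.view(-1, 32, 4, 4)
```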
@@ -131,9 +132,9 @@ def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
"""
Computes the log pdf of the Gaussian with parameters mu and logvar at x
:param x:
:param mu:
:param logvar:
:param x: (Tensor) Point at which Gaussian PDF is to be evaluated
:param mu: (Tensor) Mean of the Gaussian distribution
:param logvar: (Tensor) Log variance of the Gaussian distribution
:return:
"""
norm = - 0.5 * (math.log(2 * math.pi) + logvar)
@@ -159,7 +160,7 @@ def loss_function(self,

weight = 1 #kwargs['M_N'] # Account for the minibatch samples from the dataset

recons_loss =F.mse_loss(recons, input)
recons_loss =F.mse_loss(recons, input, reduction='sum')

log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim = 1)

@@ -174,21 +175,21 @@
# Reference
# [1] https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/utils/math.py#L54
dataset_size = (1 / kwargs['M_N']) * batch_size # dataset size
strat_weight = (dataset_size - batch_size - 1) / (dataset_size * (batch_size - 1))
strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size -1)).to(input.device)
importance_weights.view(-1)[::batch_size] = 1 / dataset_size
importance_weights.view(-1)[1::batch_size] = strat_weight
importance_weights[batch_size - 2, 0] = strat_weight
log_importance_weights = importance_weights.log()

mat_log_q_z += log_importance_weights.unsqueeze(2)
mat_log_q_z += log_importance_weights.view(batch_size, batch_size, 1)

log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

kld_loss = (log_prod_q_z - log_p_z).mean()
tc_loss = (log_q_z - log_prod_q_z).mean()
mi_loss = (log_q_zx - log_q_z).mean()
tc_loss = (log_q_z - log_prod_q_z).mean()
kld_loss = (log_prod_q_z - log_p_z).mean()

# kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

@@ -198,7 +199,7 @@ def loss_function(self,
else:
anneal_rate = 1.

loss = recons_loss + \
loss = recons_loss/batch_size + \
self.alpha * mi_loss + \
weight * (self.beta * tc_loss +
anneal_rate * self.gamma * kld_loss)
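For reference, the three means computed in loss_function above correspond to the standard β-TC-VAE decomposition of the aggregate posterior KL; with the α, β, γ from the constructor, the objective being minimised is, roughly sketched:

```latex
-\,\mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z)\big]
\;+\; \alpha \, I_q(x;z)
\;+\; \beta \, \mathrm{KL}\!\Big(q(z)\,\Big\|\,\prod_j q(z_j)\Big)
\;+\; \gamma \sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)
```

where `mi_loss`, `tc_loss`, and `kld_loss` estimate the mutual-information, total-correlation, and dimension-wise KL terms respectively, and `anneal_rate` ramps the last term in over `anneal_steps` training iterations.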
