Commit f1891e1: Updated results

AntixK committed Mar 2, 2020
1 parent 4510762
Showing 8 changed files with 14 additions and 5 deletions.
1 change: 1 addition & 0 deletions .idea/PyTorch-VAE.iml


1 change: 1 addition & 0 deletions .idea/modules.xml


1 change: 1 addition & 0 deletions .idea/vcs.xml


2 changes: 2 additions & 0 deletions README.md
@@ -232,6 +232,8 @@ doesn't seem to work well.
 [31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png
 [33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_49.png
 [34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png
+[35]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DIPVAE_83.png
+[36]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BDIPVAE_83.png

 [python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
 [python-url]: https://www.python.org/
Binary file added assets/DIPVAE_83.png
Binary file added assets/recons_DIPVAE_83.png
6 changes: 3 additions & 3 deletions configs/dip_vae.yaml
@@ -2,8 +2,8 @@ model_params:
   name: 'DIPVAE'
   in_channels: 3
   latent_dim: 128
-  lambda_diag: 4.
-  lambda_offdiag: 2.
+  lambda_diag: 0.05
+  lambda_offdiag: 0.1


 exp_params:
@@ -13,7 +13,7 @@ exp_params:
   batch_size: 144 # Better to have a square number
   LR: 0.001
   weight_decay: 0.0
-  # scheduler_gamma: 0.99
+  scheduler_gamma: 0.97

 trainer_params:
   gpus: 1
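The updated hyperparameters are plain YAML and can be read back with any YAML loader. A minimal sketch, assuming PyYAML is installed (the repo's own experiment runner may load the file differently):

```python
# Sketch: parsing the updated dip_vae.yaml values with PyYAML.
# The inline string mirrors the config shown in this commit.
import yaml

config_text = """
model_params:
  name: 'DIPVAE'
  in_channels: 3
  latent_dim: 128
  lambda_diag: 0.05
  lambda_offdiag: 0.1

exp_params:
  batch_size: 144
  LR: 0.001
  weight_decay: 0.0
  scheduler_gamma: 0.97
"""

config = yaml.safe_load(config_text)
model_params = config["model_params"]  # dict with the DIP-VAE weights
exp_params = config["exp_params"]      # dict with the training settings
```

Note that `scheduler_gamma` was previously commented out, so older configs parsed without that key; code reading it should use `exp_params.get("scheduler_gamma")` if it must handle both versions.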
8 changes: 6 additions & 2 deletions models/dip_vae.py
@@ -141,12 +141,16 @@ def loss_function(self,
         recons_loss =F.mse_loss(recons, input, reduction='sum')


-        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
+        kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

         # DIP Loss
         centered_mu = mu - mu.mean(dim=1, keepdim = True) # [B x D]
         cov_mu = centered_mu.t().matmul(centered_mu).squeeze() # [D X D]
-        cov_z = cov_mu + torch.mean(torch.diagonal((2 * log_var).exp(), dim1 = 0), dim = 0) # [D x D]
+
+        # Add Variance for DIP Loss II
+        cov_z = cov_mu + torch.mean(torch.diagonal((2. * log_var).exp(), dim1 = 0), dim = 0) # [D x D]
+        # For DIp Loss I
+        # cov_z = cov_mu

         cov_diag = torch.diag(cov_z) # [D]
         cov_offdiag = cov_z - torch.diag(cov_diag) # [D x D]
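For context, the `lambda_diag` and `lambda_offdiag` weights retuned in this commit penalize the covariance of the aggregate posterior. The following is a sketch of the DIP-VAE-II regularizer in that spirit, not the repo's exact code: the function name is made up, the means are centered over the batch dimension, and the covariance is normalized by batch size, which differs from the snippet in the diff above.

```python
# Sketch of a DIP-VAE-II style regularizer (illustrative, not the repo's code).
# mu, log_var: [B, D] posterior means and log-variances for a batch.
import torch

def dip_vae_ii_regularizer(mu, log_var, lambda_diag=0.05, lambda_offdiag=0.1):
    # Covariance of the posterior means across the batch: Cov[mu]  -> [D, D]
    centered_mu = mu - mu.mean(dim=0, keepdim=True)
    cov_mu = centered_mu.t().matmul(centered_mu) / mu.size(0)
    # DIP-VAE-II adds the expected posterior variance to the diagonal:
    # Cov_q(z) ~= Cov[mu] + E[diag(sigma^2)]
    cov_z = cov_mu + torch.diag(log_var.exp().mean(dim=0))
    cov_diag = torch.diagonal(cov_z)              # [D]
    cov_offdiag = cov_z - torch.diag(cov_diag)    # [D, D], zero diagonal
    # Push off-diagonal entries to 0 and diagonal entries to 1
    return (lambda_offdiag * (cov_offdiag ** 2).sum()
            + lambda_diag * ((cov_diag - 1.) ** 2).sum())

# With zero means and unit variances, cov_z is exactly the identity,
# so the penalty vanishes.
mu = torch.zeros(8, 4)
log_var = torch.zeros(8, 4)
loss = dip_vae_ii_regularizer(mu, log_var)
```

Under this reading, the commit's change from `lambda_diag: 4.` to `0.05` and `lambda_offdiag: 2.` to `0.1` substantially softens both penalties relative to the reconstruction and KLD terms.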
