Commit 0de6b0b: Added factor VAE

AntixK committed Jan 22, 2020 (1 parent: 48830d1)
Showing 7 changed files with 137 additions and 11 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -22,7 +22,7 @@ Here are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.
### Requirements
- Python >= 3.5
- PyTorch >= 1.3
- Pytorch Lightning >= 0.5.3 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning/tree/deb1581e26b7547baf876b7a94361e60bb200d32))
- Pytorch Lightning >= 0.6.0 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning/tree/deb1581e26b7547baf876b7a94361e60bb200d32))
- CUDA enabled computing device

### Installation
@@ -43,12 +43,18 @@ model_params:
name: "<name of VAE model>"
in_channels: 3
latent_dim:
. # Other parameters required by the model
.
.

exp_params:
data_path: "<path to the celebA dataset>"
img_size: 64 # Models are designed to work for this size
batch_size: 64 # Better to have a square number
LR: 0.005
. # Other arguments required for training like scheduler etc.
.
.

trainer_params:
gpus: 1
24 changes: 24 additions & 0 deletions configs/factorvae.yaml
@@ -0,0 +1,24 @@
model_params:
name: 'FactorVAE'
in_channels: 3
latent_dim: 128
gamma: 40

exp_params:
  data_path: "../../shared/Data/"
  img_size: 64
  batch_size: 144  # Better to have a square number
  LR: 0.005
  scheduler_gamma: 0.95
  LR_2: 0.005          # Learning rate for the discriminator's optimizer
  scheduler_gamma_2: 0.95
  require_secondary_input: True  # Read by experiment.py's training_step to fetch the discriminator's second batch

trainer_params:
gpus: [2]
max_nb_epochs: 30


logging_params:
save_dir: "logs/"
name: "FactorVAE"
manual_seed: 1265
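For orientation, a minimal sketch of how a config like this is consumed, assuming PyYAML and a `vae_models` name-to-class registry in `models/__init__.py` (the repository's actual entry point may differ):

```python
import yaml
from models import vae_models  # assumed model-name -> class registry

with open('configs/factorvae.yaml', 'r') as f:
    config = yaml.safe_load(f)

# 'name' selects the class; the remaining model_params are its kwargs
model = vae_models[config['model_params']['name']](**config['model_params'])
```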
21 changes: 20 additions & 1 deletion experiment.py
@@ -31,9 +31,20 @@ def training_step(self, batch, batch_idx, optimizer_idx = 0):

results = self.forward(real_img, labels = labels)

real_img2 = None

        # Required for FactorVAE: fetch a second batch for the discriminator update
        if self.params.get('require_secondary_input', False):
            real_img2, _ = next(iter(self.sample_dataloader))
            real_img2 = real_img2.to(self.curr_device)

train_loss = self.model.loss_function(*results,
M_N = self.params['batch_size']/ self.num_train_imgs,
optimizer_idx = optimizer_idx)
optimizer_idx=optimizer_idx,
secondary_input = real_img2)

self.logger.experiment.log({key: val.item() for key, val in train_loss.items()})
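An aside on the lines above: `next(iter(self.sample_dataloader))` rebuilds the iterator on every step, which is slow and, for a non-shuffled loader, always returns the same first batch. A possible variant (not part of this commit; `_next_secondary_batch` is a hypothetical helper) keeps a persistent iterator:

```python
    def _next_secondary_batch(self):
        # Keep one iterator alive across steps; restart it when exhausted
        if not hasattr(self, '_sec_iter'):
            self._sec_iter = iter(self.sample_dataloader)
        try:
            batch, _ = next(self._sec_iter)
        except StopIteration:
            self._sec_iter = iter(self.sample_dataloader)
            batch, _ = next(self._sec_iter)
        return batch.to(self.curr_device)
```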

@@ -75,6 +86,14 @@ def sample_images(self):
nrow=int(math.sqrt(self.params['batch_size'])))
del test_input, recons, samples

    def on_before_backward(self, loss, optimizer_idx):
        # Retain the computation graph on the VAE pass (optimizer 0) so the
        # discriminator update can backpropagate through D_z_reserve afterwards
        opt = {'loss': loss,
               'skip_backward': False,
               'retain_graph': False}
        if optimizer_idx < 1:
            opt['retain_graph'] = True
        return opt

def configure_optimizers(self):

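The body of `configure_optimizers` is collapsed in this diff. A hedged sketch of a two-optimizer setup consistent with the `LR`, `LR_2`, `scheduler_gamma` and `scheduler_gamma_2` keys in configs/factorvae.yaml (the commit's actual implementation may differ):

```python
import torch.optim as optim

def configure_optimizers(self):
    # Optimizer 0 updates the VAE; optimizer 1 updates the TC discriminator
    optimizer = optim.Adam(self.model.parameters(), lr=self.params['LR'])
    optimizer2 = optim.Adam(self.model.discriminator.parameters(),
                            lr=self.params['LR_2'])
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=self.params['scheduler_gamma'])
    scheduler2 = optim.lr_scheduler.ExponentialLR(optimizer2,
                                                  gamma=self.params['scheduler_gamma_2'])
    return [optimizer, optimizer2], [scheduler, scheduler2]
```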
34 changes: 26 additions & 8 deletions models/fvae.py
@@ -11,10 +11,12 @@ def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
gamma: float = 40.,
**kwargs) -> None:
super(FactorVAE, self).__init__()

self.latent_dim = latent_dim
self.gamma = gamma

modules = []
if hidden_dims is None:
@@ -73,6 +75,7 @@ def __init__(self,
kernel_size= 3, padding= 1),
nn.Tanh())

# Discriminator network for the Total Correlation (TC) loss
        self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000),
nn.BatchNorm1d(1000),
nn.LeakyReLU(0.2),
@@ -83,6 +86,8 @@
nn.BatchNorm1d(1000),
nn.LeakyReLU(0.2),
nn.Linear(1000, 2))
        self.D_z_reserve = None  # Caches the discriminator's output on q(z) for reuse in the discriminator update


def encode(self, input: Tensor) -> List[Tensor]:
"""
@@ -129,7 +134,7 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return [self.decode(z), input, mu, log_var]
return [self.decode(z), input, mu, log_var, z]

def permute_latent(self, z: Tensor) -> Tensor:
"""
@@ -157,27 +162,40 @@ def loss_function(self,
input = args[1]
mu = args[2]
log_var = args[3]
z = args[4]

kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
optimizer_idx = kwargs['optimizer_idx']

# Update the VAE
if optimizer_idx == 0:
            recons_loss = F.mse_loss(recons, input)


kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

loss = recons_loss + kld_weight * kld_loss
return {'loss': loss,
'Reconstruction Loss':recons_loss,
'KLD':-kld_loss}
            self.D_z_reserve = self.discriminator(z)
            # Density-ratio trick: the logit difference approximates
            # log q(z) - log prod_j q(z_j), i.e. the Total Correlation
            vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()

            loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss
            return {'loss': loss,
                    'Reconstruction Loss': recons_loss,
                    'KLD': -kld_loss,
                    'VAE_TC Loss': vae_tc_loss}

# Update the Discriminator
elif optimizer_idx == 1:
pass
            true_labels = torch.ones(input.size(0), dtype=torch.long,
                                     device=input.device)
            false_labels = torch.zeros(input.size(0), dtype=torch.long,
                                       device=input.device)

real_img2 = kwargs['secondary_input']

result = self.forward(real_img2)
z2 = result[4].detach() # Detach so that VAE is not trained again
z2_perm = self.permute_latent(z2)
            D_z2_perm = self.discriminator(z2_perm)
            # Train the discriminator to label q(z) samples as class 0 (false)
            # and permuted samples as class 1 (true)
            D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) +
                               F.cross_entropy(D_z2_perm, true_labels))

return {'loss': D_tc_loss}

def sample(self,
num_samples:int,
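The body of `permute_latent` is collapsed above; for the discriminator update to be meaningful it should shuffle each latent dimension independently across the batch, so the permuted codes approximate samples from the product of marginals. A sketch under that assumption (not the commit's exact code):

```python
    def permute_latent(self, z: Tensor) -> Tensor:
        # Independently permute every latent dimension across the batch;
        # the result approximates samples from prod_j q(z_j)
        B, D = z.size()
        return torch.stack([z[torch.randperm(B), d] for d in range(D)], dim=1)
```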
11 changes: 11 additions & 0 deletions requirements.txt
@@ -0,0 +1,11 @@
pytorch-lightning==0.6.0
PyYAML==5.1.2
tensorboard==2.1.0
tensorboardX==1.6
terminado==0.8.1
test-tube==0.7.0
torch==1.2.0
torchfile==0.1.0
torchnet==0.0.4
torchsummary==1.5.1
torchvision==0.4.0
2 changes: 1 addition & 1 deletion tests/test_dfc.py
@@ -4,7 +4,7 @@
from torchsummary import summary


class TestIWAE(unittest.TestCase):
class TestDFCVAE(unittest.TestCase):

def setUp(self) -> None:
# self.model2 = VAE(3, 10)
48 changes: 48 additions & 0 deletions tests/test_fvae.py
@@ -0,0 +1,48 @@
import torch
import unittest
from models import FactorVAE
from torchsummary import summary


class TestFactorVAE(unittest.TestCase):

def setUp(self) -> None:
# self.model2 = VAE(3, 10)
self.model = FactorVAE(3, 10)

def test_summary(self):
print(summary(self.model, (3, 64, 64), device='cpu'))
#
# print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))

# print(summary(self.model2, (3, 64, 64), device='cpu'))

def test_forward(self):
x = torch.randn(16, 3, 64, 64)
y = self.model(x)
print("Model Output size:", y[0].size())

# print("Model2 Output size:", self.model2(x)[0].size())

def test_loss(self):
x = torch.randn(16, 3, 64, 64)
x2 = torch.randn(16,3, 64, 64)

result = self.model(x)
        loss = self.model.loss_function(*result, M_N=0.005, optimizer_idx=0, secondary_input=x2)
        print(loss)
        loss = self.model.loss_function(*result, M_N=0.005, optimizer_idx=1, secondary_input=x2)
        print(loss)

def test_optim(self):
optim1 = torch.optim.Adam(self.model.parameters(), lr = 0.001)
        optim2 = torch.optim.Adam(self.model.discriminator.parameters(), lr = 0.001)

    def test_sample(self):
        # Sampling requires a GPU here; skip gracefully on CPU-only machines
        if torch.cuda.is_available():
            self.model.cuda()
            y = self.model.sample(144, 0)




if __name__ == '__main__':
unittest.main()
