Skip to content

Commit

Permalink
Updated GammaVAE
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Jan 24, 2020
1 parent 8ad5855 commit 2946d7a
Show file tree
Hide file tree
Showing 35 changed files with 149 additions and 53 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,18 @@ exp_params:
img_size: 64 # Models are designed to work for this size
batch_size: 64 # Better to have a square number
LR: 0.005
. # Other arguments required for training like scheduler etc.
weight_decay:
. # Other arguments required for training, like scheduler etc.
.
.

trainer_params:
gpus: 1
max_nb_epochs: 50
gradient_clip_val: 0.005
.
.
.

logging_params:
save_dir: "logs/"
Expand Down
1 change: 1 addition & 0 deletions configs/bhvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.0005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
1 change: 1 addition & 0 deletions configs/bvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.0005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
23 changes: 23 additions & 0 deletions configs/cvae.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
model_params:
name: 'ConditionalVAE'
in_channels: 3
num_classes: 40
latent_dim: 128

exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
gpus: 1
max_nb_epochs: 50


logging_params:
save_dir: "logs/"
name: "ConditionalVAE"
manual_seed: 1265
1 change: 1 addition & 0 deletions configs/dfc_vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
1 change: 1 addition & 0 deletions configs/factorvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
LR_2: 0.005
scheduler_gamma_2: 0.95
Expand Down
6 changes: 4 additions & 2 deletions configs/gammavae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ exp_params:
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
scheduler_gamma: 0.95
LR: 0.003
weight_decay: 0.0005


trainer_params:
gpus: 1
max_nb_epochs: 50
gradient_clip_val: 0.8


logging_params:
Expand Down
1 change: 1 addition & 0 deletions configs/hvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
1 change: 1 addition & 0 deletions configs/iwae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.007
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
1 change: 1 addition & 0 deletions configs/mssim_vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
3 changes: 2 additions & 1 deletion configs/vae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
gpus: 0
gpus: 1
max_nb_epochs: 50


Expand Down
1 change: 1 addition & 0 deletions configs/vampvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
1 change: 1 addition & 0 deletions configs/wae_mmd_imq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
1 change: 1 addition & 0 deletions configs/wae_mmd_rbf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ exp_params:
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95

trainer_params:
Expand Down
27 changes: 17 additions & 10 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,27 @@ def validation_end(self, outputs):
return {'val_loss': avg_loss, 'log': tensorboard_logs}

def sample_images(self):
samples = self.model.sample(self.params['batch_size'], self.curr_device).cpu()
vutils.save_image(samples.data,
f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
f"{self.logger.name}_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params['batch_size'])))

# Get sample reconstruction image
test_input, _ = next(iter(self.sample_dataloader))
test_input, test_label = next(iter(self.sample_dataloader))
test_input = test_input.to(self.curr_device)
recons = self.model.generate(test_input)
test_label = test_label.to(self.curr_device)
recons = self.model.generate(test_input, labels = test_label)
vutils.save_image(recons.data,
f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
f"recons_{self.logger.name}_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params['batch_size'])))

samples = self.model.sample(self.params['batch_size'],
self.curr_device,
labels = test_label).cpu()
vutils.save_image(samples.data,
f"{self.logger.save_dir}{self.logger.name}/version_{self.logger.version}/"
f"{self.logger.name}_{self.current_epoch}.png",
normalize=True,
nrow=int(math.sqrt(self.params['batch_size'])))


del test_input, recons, samples

# def backward(self, use_amp, loss, optimizer):
Expand All @@ -97,7 +102,9 @@ def configure_optimizers(self):
optims = []
scheds = []

optimizer = optim.Adam(self.model.parameters(), lr=self.params['LR'])
optimizer = optim.Adam(self.model.parameters(),
lr=self.params['LR'],
weight_decay=self.params['weight_decay'])
optims.append(optimizer)
# Check if more than 1 optimizer is required (Used for adversarial training)
try:
Expand Down
Binary file modified models/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file modified models/__pycache__/base.cpython-37.pyc
Binary file not shown.
Binary file modified models/__pycache__/beta_vae.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/beta_vae.cpython-37.pyc
Binary file not shown.
Binary file modified models/__pycache__/gamma_vae.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/gamma_vae.cpython-37.pyc
Binary file not shown.
Binary file modified models/__pycache__/vanilla_vae.cpython-36.pyc
Binary file not shown.
Binary file modified models/__pycache__/vanilla_vae.cpython-37.pyc
Binary file not shown.
Binary file modified models/__pycache__/wae_mmd.cpython-36.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion models/beta_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def loss_function(self,
else:
raise ValueError('Undefined loss type.')

return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':kld_loss}
return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

def sample(self,
num_samples:int,
Expand Down
15 changes: 9 additions & 6 deletions models/cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
return eps * std + mu

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
y = kwargs['labels']
y = kwargs['labels'].float()
embedded_class = self.embed_class(y)
embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
embedded_input = self.embed_data(input)
Expand All @@ -144,31 +144,34 @@ def loss_function(self,
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

loss = recons_loss + kld_weight * kld_loss
return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}
return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

def sample(self,
num_samples:int,
current_device: int) -> Tensor:
current_device: int,
**kwargs) -> Tensor:
"""
Samples from the latent space and return the corresponding
image space map.
:param num_samples: (Int) Number of samples
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
y = kwargs['labels'].float()
z = torch.randn(num_samples,
self.latent_dim)

z = z.cuda(current_device)
z = z.to(current_device)

z = torch.cat([z, y], dim=1)
samples = self.decode(z)
return samples

def generate(self, x: Tensor) -> Tensor:
def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image
:param x: (Tensor) [B x C x H x W]
:return: (Tensor) [B x C x H x W]
"""

return self.forward(x)[0]
return self.forward(x, **kwargs)[0]
2 changes: 1 addition & 1 deletion models/dfcvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def loss_function(self,
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_weight * kld_loss
return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}
return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

def sample(self,
num_samples:int,
Expand Down
Loading

0 comments on commit 2946d7a

Please sign in to comment.