Skip to content

Commit

Permalink
Updated Factor VAE
Browse files Browse the repository at this point in the history
Removed cifar10 dataset
  • Loading branch information
AntixK committed Jan 29, 2020
1 parent d8ebc48 commit cef261c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 66 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,9 @@ logging_params:
- [ ] FactorVAE (in progress)
- [ ] Beta TC-VAE
- [ ] TwoStageVAE
- [ ] VAE-GAN
- [ ] VLAE
- [ ] PixelVAE
- [ ] VQVAE
- [ ] StyleVAE
- [ ] Sequential VAE
Expand Down
3 changes: 2 additions & 1 deletion configs/factorvae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ exp_params:
data_path: "../../shared/Data/"
submodel: 'discriminator'
require_secondary_input: True
retain_first_backpass: True
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
Expand All @@ -18,7 +19,7 @@ exp_params:
scheduler_gamma_2: 0.95

trainer_params:
gpus: [2]
gpus: [3]
max_nb_epochs: 30
max_epochs: 50

Expand Down
58 changes: 11 additions & 47 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import pytorch_lightning as pl
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.datasets import CelebA, CIFAR10
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader



class VAEXperiment(pl.LightningModule):
RETAIN_GRAPH = True

def __init__(self,
vae_model: BaseVAE,
Expand All @@ -22,6 +21,11 @@ def __init__(self,
self.model = vae_model
self.params = params
self.curr_device = None
self.hold_graph = False
try:
self.hold_graph = self.params['retain_first_backpass']
except:
pass

def forward(self, input: Tensor, **kwargs) -> Tensor:
    """Thin delegate: run the wrapped VAE on *input*, forwarding any kwargs."""
    vae = self.model
    return vae(input, **kwargs)
Expand Down Expand Up @@ -99,11 +103,11 @@ def sample_images(self):

del test_input, recons, samples

# def backward(self, use_amp, loss, optimizer):
# print('called during backward')
#
# loss.backward(retain_graph = self.RETAIN_GRAPH)
# RETAIN_GRAP
def backward(self, use_amp, loss, optimizer, optimizer_idx):
    """Custom backward pass for the two-optimizer (VAE + discriminator) setup.

    When ``self.hold_graph`` is set, the first optimizer's backward keeps the
    autograd graph alive (``retain_graph=True``) so the second optimizer can
    still backpropagate through shared activations; every other backward
    frees the graph as usual.
    """
    retain = bool(self.hold_graph and optimizer_idx == 0)
    loss.backward(retain_graph=retain)

def configure_optimizers(self):

Expand Down Expand Up @@ -150,14 +154,6 @@ def train_dataloader(self):
split = "train",
transform=transform,
download=False)
# Required for Categorical VAE
elif self.params['dataset'] == 'cifar10':
target_transforms = self.target_transforms()
dataset = CIFAR10(root = self.params['data_path'],
train = True,
transform=transform,
target_transform=target_transforms,
download=False)
else:
raise ValueError('Undefined dataset type')

Expand All @@ -179,17 +175,6 @@ def val_dataloader(self):
batch_size= self.params['batch_size'],
shuffle = True,
drop_last=True)

elif self.params['dataset'] == 'cifar10':
target_transforms = self.target_transforms()
self.sample_dataloader = DataLoader(CIFAR10(root = self.params['data_path'],
train = False,
transform=transform,
target_transform=target_transforms,
download=False),
batch_size= self.params['batch_size'],
shuffle = True,
drop_last=True)
else:
raise ValueError('Undefined dataset type')
return self.sample_dataloader
Expand All @@ -204,28 +189,7 @@ def data_transforms(self):
transforms.Resize(self.params['img_size']),
transforms.ToTensor(),
SetRange])

elif self.params['dataset'] == 'cifar10':
transform = transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Lambda(lambda img:
torch.nn.functional.upsample_bilinear(
img.unsqueeze(0),
self.params['img_size']).squeeze()
),
SetRange])
else:
raise ValueError('Undefined dataset type')
return transform

def target_transforms(self):
    """Return a transform that one-hot encodes an integer class label.

    The transform maps a label (int in ``[0, 10)``) to a ``(1, 10)`` float
    tensor with a single 1 at the label's index, suitable as the
    ``target_transform`` of a 10-class torchvision dataset.

    NOTE(review): the original body built the transform but left the
    ``return`` commented out, so the method always returned ``None``;
    the return is restored here since constructing the transform without
    returning it is useless. Confirm no caller relied on ``None``.
    """
    transform = transforms.Compose([
        transforms.Lambda(
            # scatter_ writes a 1 into column `label` of a (1, 10) zero tensor
            lambda labels: torch.zeros(1, 10).scatter_(
                1, torch.tensor(labels).view(-1, 1), 1))
    ])
    return transform


30 changes: 14 additions & 16 deletions models/fvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,32 +175,30 @@ def loss_function(self,
self.D_z_reserve = self.discriminator(z)
vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()

loss = recons_loss + kld_weight * kld_loss - self.gamma * vae_tc_loss
loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss

# print(f' recons: {recons_loss}, kld: {kld_loss}, VAE_TC_loss: {vae_tc_loss}')
return {'loss': loss} #,
# 'Reconstruction Loss':recons_loss,
# 'KLD':-kld_loss,
# 'VAE_TC Loss': vae_tc_loss}
return {'loss': loss,
'Reconstruction_Loss':recons_loss,
'KLD':-kld_loss,
'VAE_TC_Loss': vae_tc_loss}

# Update the Discriminator
elif optimizer_idx == 1:

device = input.device
true_labels = torch.ones(input.size(0), dtype= torch.long,
requires_grad=False).to(device)
false_labels = torch.zeros(input.size(0), dtype= torch.long,
requires_grad=False).to(device)

real_img2 = kwargs['secondary_input']

result = self.forward(real_img2)
z2 = result[4].detach() # Detach so that VAE is not trained again
z2_perm = self.permute_latent(z2)
D_z2_perm = self.discriminator(z2_perm)
D_tc_loss = -0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) +
F.cross_entropy(D_z2_perm, true_labels))
print(f'D_TC: {D_tc_loss}')
return {'loss': D_tc_loss}
z = z.detach() # Detach so that VAE is not trained again
z_perm = self.permute_latent(z)
D_z_perm = self.discriminator(z_perm)
D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) +
F.cross_entropy(D_z_perm, true_labels))
# print(f'D_TC: {D_tc_loss}')
return {'loss': D_tc_loss,
'D_TC_Loss':D_tc_loss}

def sample(self,
num_samples:int,
Expand Down

0 comments on commit cef261c

Please sign in to comment.