forked from AntixK/PyTorch-VAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
experiment.py
125 lines (98 loc) · 4.64 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import math
import torch
from torch import optim
from models import BaseVAE
from models.types_ import *
import pytorch_lightning as pl
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
class VAEXperiment(pl.LightningModule):
    """PyTorch Lightning module that trains and validates a VAE on CelebA.

    Args:
        vae_model: The VAE being optimised. This class calls it directly
            (``model(x, labels=...)``) and via ``model.loss_function(...)``.
        params: Experiment configuration. Keys read by this class:
            ``batch_size``, ``LR``, ``data_path`` and ``img_size``.
    """

    def __init__(self,
                 vae_model: BaseVAE,
                 params: dict) -> None:
        super().__init__()

        self.model = vae_model
        self.params = params
        # Device of the most recently seen batch; reused by sample_images().
        self.curr_device = None

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        """Delegate to the wrapped VAE, forwarding any keyword arguments."""
        return self.model(input, **kwargs)

    def training_step(self, batch, batch_idx):
        """Compute the loss dict for one training batch and log each entry.

        ``self.num_train_imgs`` is assigned in ``train_dataloader``, so that
        loader must have been built before the first step runs.

        Returns:
            The dict returned by ``self.model.loss_function``.
        """
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels=labels)
        # M_N = batch_size / training-set size. Presumably the minibatch
        # weight applied to the KL term -- confirm in model.loss_function.
        train_loss = self.model.loss_function(
            *results,
            M_N=self.params['batch_size'] / self.num_train_imgs)

        self.logger.experiment.log(
            {key: val.item() for key, val in train_loss.items()})
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss dict for one validation batch.

        Uses the same ``M_N`` weighting as ``training_step`` (the training-set
        size), so validation losses are scaled like training losses.
        """
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels=labels)
        val_loss = self.model.loss_function(
            *results,
            M_N=self.params['batch_size'] / self.num_train_imgs)
        return val_loss

    def validation_end(self, outputs):
        """Average the per-batch validation losses and save a reconstruction.

        Args:
            outputs: The dicts returned by ``validation_step``; each must
                contain a ``'loss'`` tensor.

        Returns:
            ``{'val_loss': mean_loss, 'log': {'avg_val_loss': mean_loss}}``.
        """
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        tensorboard_logs = {'avg_val_loss': avg_loss}
        self.sample_images()
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def sample_images(self):
        """Reconstruct one batch from the validation loader and save it as a PNG.

        The grid is written to
        ``{logger.save_dir}/{logger.name}/recons_{current_epoch}.png``.
        """
        test_input, test_labels = next(iter(self.sample_dataloader))
        # .to() keeps this device-agnostic (works on CPU as well as GPU).
        test_input = test_input.to(self.curr_device)
        test_labels = test_labels.to(self.curr_device)

        # The reconstruction is only saved to disk, so no autograd graph is needed.
        with torch.no_grad():
            # Pass labels as training/validation steps do, so label-conditioned
            # models are supported here too.
            recons = self.model(test_input, labels=test_labels)

        # assumes element 0 of the model output is the reconstructed batch
        vutils.save_image(recons[0].data,
                          f"{self.logger.save_dir}/{self.logger.name}/recons_{self.current_epoch}.png",
                          normalize=True,
                          nrow=int(math.sqrt(self.params['batch_size'])))

    def configure_optimizers(self):
        """Return a single Adam optimiser over the model parameters.

        No learning-rate scheduler is returned, so this module never decays
        the learning rate set by ``params['LR']``.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=self.params['LR'])
        return [optimizer]

    @pl.data_loader
    def train_dataloader(self):
        """Build the shuffled CelebA training loader.

        Side effect: stores the training-set size in ``self.num_train_imgs``,
        which the step methods use to compute ``M_N``.
        """
        transform = self.data_transforms()
        dataset = CelebA(root=self.params['data_path'],
                         split="train",
                         transform=transform,
                         download=False)
        self.num_train_imgs = len(dataset)

        return DataLoader(dataset,
                          batch_size=self.params['batch_size'],
                          shuffle=True,
                          drop_last=True)

    @pl.data_loader
    def val_dataloader(self):
        """Build the shuffled CelebA test-split loader used for validation.

        The loader is also kept as ``self.sample_dataloader`` so that
        ``sample_images`` can draw a batch from it.
        """
        transform = self.data_transforms()
        self.sample_dataloader = DataLoader(CelebA(root=self.params['data_path'],
                                                   split="test",
                                                   transform=transform,
                                                   download=False),
                                            batch_size=self.params['batch_size'],
                                            shuffle=True,
                                            drop_last=True)
        return self.sample_dataloader

    def data_transforms(self):
        """Return the preprocessing pipeline shared by both data loaders.

        The pipeline randomly flips each image horizontally, centre-crops it
        to 148 px, resizes it to ``params['img_size']``, converts it to a
        tensor, and maps pixel values from [0, 1] to [-1, 1].
        """
        # NOTE(review): the random flip is also applied to the test split.
        SetRange = transforms.Lambda(lambda X: 2 * X - 1.)
        transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                        transforms.CenterCrop(148),
                                        transforms.Resize(self.params['img_size']),
                                        transforms.ToTensor(),
                                        SetRange])
        return transform