Commit fde3b4d (0 parents)
Laknath1996 committed Mar 25, 2024
Showing 10 changed files with 367 additions and 0 deletions.
File: presumably a .gitignore (the name is not visible in this view; inferred from its contents)
@@ -0,0 +1,3 @@
wandb
output
data
3 binary files not shown.
File: config.yaml (name confirmed by `OmegaConf.load('config.yaml')` in the driver script below)
@@ -0,0 +1,11 @@
deploy: True
device: "cuda:1"
seed: 1996
dim: 32
image_size: 32
lr: 0.0001
batch_size: 64
iters: 100000
n_critic: 5
gp: False
workers: 8
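Everything below is driven by this file. As a quick illustration (not part of the commit), the values can be loaded and overridden with OmegaConf; the command-line merge is a hypothetical usage, though `OmegaConf.from_cli()` is real API:

from omegaconf import OmegaConf

# Load the YAML above into a dot-accessible config object.
args = OmegaConf.load('config.yaml')
print(args.device, args.batch_size)  # cuda:1 64

# Hypothetical: let CLI flags such as `python main.py lr=0.0002`
# take precedence over the file values.
args = OmegaConf.merge(args, OmegaConf.from_cli())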
File: model.py (name confirmed by `from model import discriminator, generator` in trainer.py)
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """WGAN critic: stride-1 convolutions mapping a 1-channel image to a
    grid of unbounded scores (no sigmoid, per the WGAN objective)."""

    def __init__(self, dim):
        super().__init__()

        self.main = nn.Sequential(
            nn.Conv2d(1, dim, 3, bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(dim, 2*dim, 3, bias=False),
            nn.BatchNorm2d(2*dim),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(2*dim, 4*dim, 3, bias=False),
            nn.BatchNorm2d(4*dim),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(4*dim, 8*dim, 3, bias=False),
            nn.BatchNorm2d(8*dim),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(8*dim, 1, 3),
        )

    def forward(self, input):
        out = self.main(input)
        # Flatten all per-location scores into one vector; the trainer
        # averages them, so the exact shape does not matter.
        out = torch.flatten(out)
        return out


class Generator(nn.Module):
    """Maps (batch, 100, 1, 1) noise to a (batch, 1, 32, 32) image in [-1, 1];
    spatial sizes go 1 -> 2 -> 4 -> 8 -> 16 -> 32 through the transposed convs."""

    def __init__(self, dim):
        super().__init__()

        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 8*dim, 2, 1, 0, bias=False),
            nn.BatchNorm2d(8*dim),
            nn.ReLU(True),

            nn.ConvTranspose2d(8*dim, 4*dim, 3, 2, 1, 1, bias=False),
            nn.BatchNorm2d(4*dim),
            nn.ReLU(True),

            nn.ConvTranspose2d(4*dim, 2*dim, 3, 2, 1, 1, bias=False),
            nn.BatchNorm2d(2*dim),
            nn.ReLU(True),

            nn.ConvTranspose2d(2*dim, dim, 3, 2, 1, 1, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),

            nn.ConvTranspose2d(dim, 1, 3, 2, 1, 1),
            nn.Tanh()
        )

    def forward(self, input):
        out = self.main(input)
        return out


def discriminator(dim):
    model = Discriminator(dim)
    return model


def generator(dim):
    model = Generator(dim)
    return model
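A quick shape check, not part of the commit, clarifies what these builders produce. With dim=32 and 32x32 inputs, the critic's stride-1 convolutions shrink the feature map to 22x22 before the single-channel head, so its flattened output has batch_size * 484 entries:

import torch
from model import discriminator, generator

G = generator(32)
D = discriminator(32)

z = torch.randn(4, 100, 1, 1)   # latent batch
fake = G(z)
print(fake.shape)               # torch.Size([4, 1, 32, 32])

scores = D(fake)
print(scores.shape)             # torch.Size([1936]) == 4 * 22 * 22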
File: a Jupyter notebook (.ipynb; the filename is not visible in this view)
@@ -0,0 +1,61 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "from trainer import Trainer\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = OmegaConf.load('config.yaml')\n",
    "trainer = Trainer(args)\n",
    "# trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "prol",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
File: the driver script (perhaps main.py; the name is not visible in this view)
@@ -0,0 +1,6 @@
from omegaconf import OmegaConf
from trainer import Trainer

args = OmegaConf.load('config.yaml')
trainer = Trainer(args)
trainer.run()
File: trainer.py (name confirmed by `from trainer import Trainer` above)
@@ -0,0 +1,150 @@
import os
import math
import wandb

import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from tqdm import tqdm

from model import discriminator, generator
from utils import weights_init, init_torch_seeds, gradient_penalty, init_wandb


class Trainer():
    def __init__(self, args):
        self.args = args
        dataset = torchvision.datasets.MNIST(
            root='../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.ToTensor(),
                # Note: mean 0.5 / std 0.25 maps pixels into [-2, 2], while the
                # generator's Tanh emits [-1, 1]; std 0.5 would align the ranges.
                transforms.Normalize(mean=[0.50], std=[0.25])
            ])
        )
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=int(args.workers),
            shuffle=True
        )
        self.device = torch.device(args.device if torch.cuda.is_available() else "cpu")

        self.discriminator = discriminator(args.dim).to(self.device)
        self.generator = generator(args.dim).to(self.device)

        self.discriminator = self.discriminator.apply(weights_init)
        self.generator = self.generator.apply(weights_init)

        self.optimizer_d = torch.optim.RMSprop(self.discriminator.parameters(), lr=args.lr)
        self.optimizer_g = torch.optim.RMSprop(self.generator.parameters(), lr=args.lr)

    def get_infinite_batches(self, dataloader):
        # Cycle through the dataloader forever, so the critic can draw
        # n_critic fresh batches per generator step.
        while True:
            for data, _ in dataloader:
                yield data

    def run(self):
        args = self.args

        # logging
        init_wandb(args, project_name='wgan')

        # Set the random seed so runs are easy to reproduce.
        init_torch_seeds(args.seed)

        self.discriminator.train()
        self.generator.train()

        # fixed noise, so the saved sample grids stay comparable across iterations
        fixed_noise = torch.randn(args.batch_size, 100, 1, 1, device=self.device)

        # infinite data iterator
        self.dataiterator = self.get_infinite_batches(self.dataloader)

        for it in range(args.iters):

            # (1) Update the D (critic) network: re-enable its gradients
            for p in self.discriminator.parameters():
                p.requires_grad = True

            for _ in range(args.n_critic):

                data = next(self.dataiterator)
                real_images = data.to(self.device)
                batch_size = real_images.size(0)
                noise = torch.randn(batch_size, 100, 1, 1, device=self.device)

                # Set D gradients to zero.
                self.discriminator.zero_grad()

                # Clip the parameters of D to [-c, c] (vanilla WGAN weight
                # clipping; the args.gp flag and gradient_penalty are unused here)
                for p in self.discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)

                # Pass real images through D
                real_output = self.discriminator(real_images)
                errD_real = torch.mean(real_output)

                # Generate a fake image batch with G
                fake_images = self.generator(noise)

                # Pass fake images through D
                fake_output = self.discriminator(fake_images.detach())
                errD_fake = torch.mean(fake_output)

                # Critic loss: maximize E[D(real)] - E[D(fake)], i.e. minimize the negative
                errD = -errD_real + errD_fake

                errD.backward()
                self.optimizer_d.step()

            # (2) Update the G network: freeze the critic to avoid wasted computation
            for p in self.discriminator.parameters():
                p.requires_grad = False

            # Set generator gradients to zero
            self.generator.zero_grad()

            # Generator loss: maximize E[D(fake)], i.e. minimize the negative
            noise = torch.randn(batch_size, 100, 1, 1, device=self.device)
            fake_images = self.generator(noise)
            fake_output = self.discriminator(fake_images)
            errG = -torch.mean(fake_output)
            errG.backward()
            self.optimizer_g.step()

            info = {
                'iter': it + 1,
                'Loss_D': np.round(errD.item(), 4),
                'Loss_G': np.round(errG.item(), 4)
            }
            print(info)

            if args.deploy:
                wandb.log(info)

            # Save sample grids every 1000 iterations (not epochs).
            if (it + 1) % 1000 == 0:
                vutils.save_image(real_images,
                                  os.path.join("output", "real_samples.png"),
                                  normalize=True)
                fake = self.generator(fixed_noise)
                vutils.save_image(fake.detach(),
                                  os.path.join("output", f"fake_samples_{it}.png"),
                                  normalize=True)

        if args.deploy:
            wandb.finish()
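One loose end: config.yaml defines `gp` and utils.py ships `gradient_penalty`, but run() never consults either and always uses weight clipping. A hypothetical wiring of the flag into the critic update, with the conventional penalty coefficient of 10 from the WGAN-GP paper (the coefficient and branch structure are assumptions, not in this commit), could look like:

# inside the critic loop, after real_output and fake_output are computed
if args.gp:
    gp_term = gradient_penalty(self.discriminator, real_images,
                               fake_images.detach(), self.device)
    errD = -errD_real + errD_fake + 10.0 * gp_term  # 10.0: usual WGAN-GP weight
else:
    errD = -errD_real + errD_fake
# (with gp enabled, the clamp earlier in the loop would also be skipped)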
File: utils.py (name confirmed by the imports in trainer.py)
@@ -0,0 +1,67 @@
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb
from omegaconf import OmegaConf


def init_torch_seeds(seed: int = 0):
    r"""Sets the seed for generating random numbers.

    Args:
        seed (int): The desired seed.
    """

    # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
    if seed == 0:  # slower, more reproducible
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:  # faster, less reproducible
        cudnn.deterministic = False
        cudnn.benchmark = True

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Custom weight initialization, applied to netG and netD via Module.apply()
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


def init_wandb(args, project_name):
    if args.deploy:
        wandb.init(project=project_name)
        wandb.run.name = wandb.run.id
        wandb.run.save()
        wandb.config.update(OmegaConf.to_container(args))


def gradient_penalty(model, real_images, fake_images, device):
    """Calculates the gradient penalty loss for WGAN-GP."""
    # Random weight for interpolating between real and fake data, drawn
    # uniformly from [0, 1) as in the WGAN-GP paper (the original line used
    # torch.randn, which is not a valid mixing weight).
    alpha = torch.rand((real_images.size(0), 1, 1, 1), device=device)
    # Random interpolation between real and fake data
    interpolates = (alpha * real_images + ((1 - alpha) * fake_images)).requires_grad_(True)

    model_interpolates = model(interpolates)
    grad_outputs = torch.ones(model_interpolates.size(), device=device, requires_grad=False)

    # Gradient of the critic scores w.r.t. the interpolates
    gradients = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    # Penalize deviation of the per-sample gradient norm from 1
    gradient_penalty = torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
    return gradient_penalty
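A minimal standalone check of `gradient_penalty` (assumed usage, not part of the commit; any critic mapping images to scores works, run on CPU here for simplicity):

import torch
from model import discriminator
from utils import gradient_penalty

D = discriminator(32)
real = torch.randn(8, 1, 32, 32)
fake = torch.randn(8, 1, 32, 32)

gp = gradient_penalty(D, real, fake, torch.device("cpu"))
print(gp.item())  # scalar: mean of (||grad D(x_hat)||_2 - 1)^2 over the batch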