Commit fde3b4d: Update
Laknath1996 committed Mar 25, 2024 (initial commit)
Showing 10 changed files with 367 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
wandb
output
data
Binary file added wgan/__pycache__/model.cpython-39.pyc
Binary file not shown.
Binary file added wgan/__pycache__/trainer.cpython-39.pyc
Binary file not shown.
Binary file added wgan/__pycache__/utils.cpython-39.pyc
Binary file not shown.
11 changes: 11 additions & 0 deletions wgan/config.yaml
@@ -0,0 +1,11 @@
deploy: True        # enable Weights & Biases logging
device: "cuda:1"
seed: 1996
dim: 32             # base channel width for the generator and critic
image_size: 32
lr: 0.0001
batch_size: 64
iters: 100000       # number of generator updates
n_critic: 5         # critic updates per generator update
gp: False           # gradient-penalty flag (imported in trainer.py but not yet wired in)
workers: 8          # dataloader workers
69 changes: 69 additions & 0 deletions wgan/model.py
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn

class Discriminator(nn.Module):
def __init__(self, dim):
super().__init__()

self.main = nn.Sequential(
nn.Conv2d(1, dim, 3, bias=False),
nn.LeakyReLU(0.2, True),

nn.Conv2d(dim, 2*dim, 3, bias=False),
nn.BatchNorm2d(2*dim),
nn.LeakyReLU(0.2, True),

nn.Conv2d(2*dim, 4*dim, 3, bias=False),
nn.BatchNorm2d(4*dim),
nn.LeakyReLU(0.2, True),

nn.Conv2d(4*dim, 8*dim, 3, bias=False),
nn.BatchNorm2d(8*dim),
nn.LeakyReLU(0.2, True),

nn.Conv2d(8*dim, 1, 3),
)

    def forward(self, input):
        out = self.main(input)
        # flatten batch and spatial dimensions into one vector of critic scores;
        # the Wasserstein loss in the trainer takes the mean over all elements
        out = torch.flatten(out)
        return out


class Generator(nn.Module):
def __init__(self, dim):
super().__init__()

self.main = nn.Sequential(
nn.ConvTranspose2d(100, 8*dim, 2, 1, 0, bias=False),
nn.BatchNorm2d(8*dim),
nn.ReLU(True),

nn.ConvTranspose2d(8*dim, 4*dim, 3, 2, 1, 1, bias=False),
nn.BatchNorm2d(4*dim),
nn.ReLU(True),

nn.ConvTranspose2d(4*dim, 2*dim, 3, 2, 1, 1, bias=False),
nn.BatchNorm2d(2*dim),
nn.ReLU(True),

nn.ConvTranspose2d(2*dim, dim, 3, 2, 1, 1, bias=False),
nn.BatchNorm2d(dim),
nn.ReLU(True),

nn.ConvTranspose2d(dim, 1, 3, 2, 1, 1),
nn.Tanh()
)

def forward(self, input):
out = self.main(input)
return out


def discriminator(dim):
model = Discriminator(dim)
return model

def generator(dim):
model = Generator(dim)
return model
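
# A quick shape sanity check may help when reading the architectures above
# (a sketch, assuming dim=32 and image_size=32 from config.yaml; not part of this commit):

import torch
from model import discriminator, generator

G, D = generator(32), discriminator(32)
z = torch.randn(4, 100, 1, 1)   # latent batch
x = G(z)
print(x.shape)                  # torch.Size([4, 1, 32, 32])
scores = D(x)
print(scores.shape)             # torch.Size([1936]), i.e. 4 * 1 * 22 * 22, since
                                # the critic's forward flattens every element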
61 changes: 61 additions & 0 deletions wgan/train.ipynb
@@ -0,0 +1,61 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from omegaconf import OmegaConf\n",
"from trainer import Trainer\n",
"import wandb"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"args = OmegaConf.load('config.yaml')\n",
"trainer = Trainer(args)\n",
"# trainer.run()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "prol",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
6 changes: 6 additions & 0 deletions wgan/train.py
@@ -0,0 +1,6 @@
from omegaconf import OmegaConf
from trainer import Trainer

args = OmegaConf.load('config.yaml')
trainer = Trainer(args)
trainer.run()
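
# As a hedged aside (not part of this commit), OmegaConf also supports command-line
# overrides; the loader above could merge them, assuming an invocation such as
# `python train.py lr=0.0002`:

from omegaconf import OmegaConf

args = OmegaConf.merge(OmegaConf.load('config.yaml'), OmegaConf.from_cli())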
150 changes: 150 additions & 0 deletions wgan/trainer.py
@@ -0,0 +1,150 @@
import os
import math
import wandb

import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from tqdm import tqdm

from model import discriminator, generator
from utils import weights_init, init_torch_seeds, gradient_penalty, init_wandb

class Trainer():
def __init__(self, args):
self.args = args
dataset = torchvision.datasets.MNIST(
root='../data',
train=True,
download=True,
            transform=transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.ToTensor(),
                # note: std=0.25 maps pixels to roughly [-2, 2], while the generator's
                # Tanh output lies in [-1, 1]; std=0.5 would match the two ranges
                transforms.Normalize(mean=[0.50], std=[0.25])
            ])
)
self.dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
pin_memory=True,
num_workers=int(args.workers),
shuffle=True
)
self.device = torch.device(args.device if torch.cuda.is_available() else "cpu")

self.discriminator = discriminator(args.dim).to(self.device)
self.generator = generator(args.dim).to(self.device)

self.discriminator = self.discriminator.apply(weights_init)
self.generator = self.generator.apply(weights_init)

self.optimizer_d = torch.optim.RMSprop(self.discriminator.parameters(), lr=args.lr)
self.optimizer_g = torch.optim.RMSprop(self.generator.parameters(), lr=args.lr)

def get_infinite_batches(self, dataloader):
while True:
for data, _ in dataloader:
yield data

def run(self):
args = self.args

# logging
init_wandb(args, project_name='wgan')

        # Set the random seed for reproducibility.
init_torch_seeds(args.seed)

self.discriminator.train()
self.generator.train()

# fixed noise to generate the same set of images
fixed_noise = torch.randn(args.batch_size, 100, 1, 1, device=self.device)

# infinite data iterator
self.dataiterator = self.get_infinite_batches(self.dataloader)

for it in range(args.iters):

            # re-enable gradient computation for the critic (disabled below during the G step)
for p in self.discriminator.parameters():
p.requires_grad = True

for _ in range(args.n_critic):

data = next(self.dataiterator)
real_images = data.to(self.device)
batch_size = real_images.size(0)
noise = torch.randn(batch_size, 100, 1, 1, device=self.device)

                # (1) Update D (critic): maximize E[D(real)] - E[D(fake)]

# Set D gradients to zero.
self.discriminator.zero_grad()

# clip parameters of D to a range [-c, c]
for p in self.discriminator.parameters():
p.data.clamp_(-0.01, 0.01)

# Pass real images through D
real_output = self.discriminator(real_images)
errD_real = torch.mean(real_output)

# Generate fake image batch with G
fake_images = self.generator(noise)

# Pass fake images through D
fake_output = self.discriminator(fake_images.detach())
errD_fake = torch.mean(fake_output)

# compute the D loss
errD = -errD_real + errD_fake

# compute the D loss gradients
errD.backward()

# Update D
self.optimizer_d.step()

            # (2) Update G: maximize E[D(G(z))], i.e. minimize -E[D(G(z))]

for p in self.discriminator.parameters():
p.requires_grad = False # to avoid computation

# Set generator gradients to zero
self.generator.zero_grad()

# Generate fake image batch with G
noise = torch.randn(batch_size, 100, 1, 1, device=self.device)
fake_images = self.generator(noise)
fake_output = self.discriminator(fake_images)
errG = -torch.mean(fake_output)
errG.backward()
self.optimizer_g.step()

info = {
'iter': it + 1,
'Loss_D': np.round(errD.item(), 4),
'Loss_G': np.round(errG.item(), 4)
}
print(info)

if args.deploy:
wandb.log(info)

            # Save sample images every 1000 iterations.
            if (it + 1) % 1000 == 0:
                os.makedirs("output", exist_ok=True)
                vutils.save_image(real_images,
                                  os.path.join("output", "real_samples.png"),
                                  normalize=True)
                fake = self.generator(fixed_noise)
                vutils.save_image(fake.detach(),
                                  os.path.join("output", f"fake_samples_{it + 1}.png"),
                                  normalize=True)

        if args.deploy:
            wandb.finish()
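
# Note that gradient_penalty is imported above but never called, and config.yaml sets
# gp: False. A hedged sketch of how the WGAN-GP variant might be wired into the critic
# step when args.gp is True; the coefficient 10 follows Gulrajani et al. (2017) and is
# not defined anywhere in this commit. With the penalty active, the weight clipping
# above would normally be dropped:

if args.gp:
    gp = gradient_penalty(self.discriminator, real_images, fake_images, self.device)
    errD = -errD_real + errD_fake + 10.0 * gp
else:
    errD = -errD_real + errD_fake
errD.backward()
self.optimizer_d.step()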

67 changes: 67 additions & 0 deletions wgan/utils.py
@@ -0,0 +1,67 @@
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb
from omegaconf import OmegaConf

def init_torch_seeds(seed: int = 0):
    r"""Sets the seed for generating random numbers.

    Args:
        seed (int): The desired seed.
    """

# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if seed == 0: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True

np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# custom weights initialization called on netG and netD
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.zeros_(m.bias)

def init_wandb(args, project_name):
if args.deploy:
wandb.init(project=project_name)
wandb.run.name = wandb.run.id
wandb.run.save()
wandb.config.update(OmegaConf.to_container(args))

def gradient_penalty(model, real_images, fake_images, device):
    """Calculates the gradient penalty loss for WGAN-GP."""
    # Random weight for interpolating between real and fake data,
    # drawn uniformly from [0, 1] (torch.rand, not torch.randn)
    alpha = torch.rand((real_images.size(0), 1, 1, 1), device=device)
# Get random interpolation between real and fake data
interpolates = (alpha * real_images + ((1 - alpha) * fake_images)).requires_grad_(True)

model_interpolates = model(interpolates)
grad_outputs = torch.ones(model_interpolates.size(), device=device, requires_grad=False)

# Get gradient w.r.t. interpolates
gradients = torch.autograd.grad(
outputs=model_interpolates,
inputs=interpolates,
grad_outputs=grad_outputs,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
return gradient_penalty
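
# A quick smoke test of gradient_penalty (a sketch under shapes assumed from
# config.yaml defaults; not part of this commit):

import torch
from model import discriminator
from utils import gradient_penalty

D = discriminator(32)
real = torch.randn(4, 1, 32, 32)
fake = torch.randn(4, 1, 32, 32)
print(gradient_penalty(D, real, fake, torch.device("cpu")))  # a scalar tensor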
