Commit
GunwooHan committed Jun 2, 2024
2 parents 44619cb + 9d703b7 commit 867a38a
Showing 3 changed files with 74 additions and 40 deletions.
3 changes: 0 additions & 3 deletions data/bapps.py
@@ -76,7 +76,6 @@ def train_dataloader(self):
self.bapps_train,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True,
pin_memory=True,
shuffle=True
)
@@ -86,7 +85,6 @@ def val_dataloader(self):
self.bapps_val,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True,
pin_memory=True,
shuffle=False
)
@@ -96,7 +94,6 @@ def test_dataloader(self):
self.bapps_test,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True,
pin_memory=True,
shuffle=False
)
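Likely rationale (not stated in the commit): PyTorch's DataLoader rejects persistent_workers=True when num_workers is 0, which is the default in train.py, so removing the flag avoids a startup error. A guarded alternative, sketched here with a hypothetical make_loader helper rather than the module's own code, could keep persistent workers whenever worker processes are actually used:

from torch.utils.data import DataLoader

def make_loader(dataset, batch_size, num_workers, shuffle):
    # persistent_workers is only valid when num_workers > 0,
    # so enable it conditionally instead of hard-coding True.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        pin_memory=True,
        shuffle=shuffle,
    )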
72 changes: 67 additions & 5 deletions e_latent_lpips/e_latent_lpips.py
@@ -2,12 +2,11 @@

from typing import Any

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR, ExponentialLR, ReduceLROnPlateau, CosineAnnealingWarmRestarts, \
CosineAnnealingLR, PolynomialLR
from torch.optim.lr_scheduler import StepLR, ExponentialLR, ReduceLROnPlateau, CosineAnnealingLR

from .pretrained_networks import LatentVGG16, VGG16

@@ -95,6 +94,8 @@ def configure_optimizers(self):
optimizer = optim.Rprop(self.parameters(), lr=self.args.learning_rate)
else:
raise ValueError(f"Unsupported optimizer type: {self.args.optimizer}")


if self.args.lr_scheduler == 'constant':
return optimizer
elif self.args.lr_scheduler == 'step':
@@ -104,7 +105,7 @@
elif self.args.lr_scheduler == 'cosine_anneling':
scheduler = CosineAnnealingLR(optimizer, T_max=self.args.t_max, eta_min=1e-6)
elif self.args.lr_scheduler == 'cosine_anneling_warmup_restarts':
scheduler = CosineAnnealingWarmRestarts(optimizer, self.args.t_max, self.args.t_mult, eta_min=1e-6)
scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=10, T_mult=1, eta_max=0.001, T_up=5, gamma=self.args.gamma)
elif self.args.lr_scheduler == 'reduce_on_plateau':
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=self.args.gamma, patience=self.args.patience)
else:
@@ -113,7 +114,7 @@
return {
'optimizer': optimizer,
'lr_scheduler': scheduler,
'monitor': 'val_loss' # This is necessary for ReduceLROnPlateau
'monitor': 'val_total_loss' # This is necessary for ReduceLROnPlateau
}

def load_checkpoint(self, model_path):
@@ -255,3 +256,64 @@ def forward(self, d0, d1, judge):
per = judge
self.logit = self.net.forward(d0, d1)
return self.loss(self.logit, per)


import math
from torch.optim.lr_scheduler import LRScheduler


class CosineAnnealingWarmUpRestarts(LRScheduler):
def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
if T_up < 0 or not isinstance(T_up, int):
raise ValueError("Expected positive integer T_up, but got {}".format(T_up))
self.T_0 = T_0
self.T_mult = T_mult
self.base_eta_max = eta_max
self.eta_max = eta_max
self.T_up = T_up
self.T_i = T_0
self.gamma = gamma
self.cycle = 0
self.T_cur = last_epoch
super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)

def get_lr(self):
if self.T_cur == -1:
return self.base_lrs
elif self.T_cur < self.T_up:
return [(self.eta_max - base_lr) * self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
else:
return [base_lr + (self.eta_max - base_lr) * (
1 + math.cos(math.pi * (self.T_cur - self.T_up) / (self.T_i - self.T_up))) / 2
for base_lr in self.base_lrs]

def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.T_cur = self.T_cur + 1
if self.T_cur >= self.T_i:
self.cycle += 1
self.T_cur = self.T_cur - self.T_i
self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
else:
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0
self.cycle = epoch // self.T_0
else:
n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
self.cycle = n
self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
self.T_i = self.T_0 * self.T_mult ** (n)
else:
self.T_i = self.T_0
self.T_cur = epoch

self.eta_max = self.base_eta_max * (self.gamma ** self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
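For reference, a minimal sketch (not part of the commit) of how this scheduler behaves: the learning rate ramps linearly from the optimizer's base lr to eta_max over the first T_up epochs, follows a cosine decay back toward the base lr for the rest of the T_0-epoch cycle, then restarts, with eta_max multiplied by gamma after each restart. The base lr therefore acts as the schedule's floor and is typically set small, while eta_max carries the peak value. Assuming the package layout above makes the class importable as e_latent_lpips.e_latent_lpips.CosineAnnealingWarmUpRestarts, the same hard-coded hyperparameters used in configure_optimizers (with the train.py default gamma of 0.5) give:

import torch.nn as nn
import torch.optim as optim

from e_latent_lpips.e_latent_lpips import CosineAnnealingWarmUpRestarts

# Toy model and optimizer; lr here is the floor of the schedule, eta_max the peak.
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=1e-6)

scheduler = CosineAnnealingWarmUpRestarts(
    optimizer, T_0=10, T_mult=1, eta_max=0.001, T_up=5, gamma=0.5
)

for epoch in range(30):
    # ... one training epoch would run here ...
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])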
39 changes: 7 additions & 32 deletions train.py
@@ -2,7 +2,7 @@
import argparse

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from e_latent_lpips import e_latent_lpips
@@ -12,32 +12,20 @@
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=1e-5)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--wandb', type=bool, default=True)
parser.add_argument('--optimizer', type=str, default="sgd")
parser.add_argument('--step_size', type=int, default=1)
parser.add_argument('--step_size', type=int, default=10)
parser.add_argument('--gamma', type=float, default=0.5)
parser.add_argument('--lr_scheduler', type=str, default="cosine_anneling")
parser.add_argument('--lr_scheduler', type=str, default="step")
parser.add_argument('--factor', type=float, default=0.5)
parser.add_argument('--patience', type=int, default=5)
parser.add_argument('--t_max', type=int, default=5)
parser.add_argument('--t_mult', type=int, default=5)
parser.add_argument('--swa_lr', type=float, default=1e-4)

# parser.add_argument('--crop_image_size', type=int, default=512)
# parser.add_argument('--ShiftScaleRotateMode', type=int, default=4)
# parser.add_argument('--ShiftScaleRotate', type=float, default=0.2)
# parser.add_argument('--horizontal_flip', type=float, default=0.2)
# parser.add_argument('--rotate_90_degrees', type=float, default=0.2)
# parser.add_argument('--VerticalFlip', type=float, default=0.2)

parser.add_argument('--blit', type=bool, default=False)
parser.add_argument('--geometric', type=bool, default=False)
parser.add_argument('--color', type=bool, default=False)
parser.add_argument('--cutout', type=bool, default=False)

parser.add_argument('--model', type=str, default='vgg')
parser.add_argument('--checkpoints_dir', type=str, default='checkpoints')
parser.add_argument('--pretrained', type=bool, default=True)
@@ -54,7 +42,6 @@
'val/superres'])

args = parser.parse_args()
print(args)

pl.seed_everything(args.seed)

@@ -70,22 +57,10 @@
)

swa_callback = StochasticWeightAveraging(swa_lrs=args.swa_lr)
lr_callback = LearningRateMonitor(logging_interval='step')

if args.wandb:
tag = []

if args.blit:
tag.append('blit')
if args.geometric:
tag.append('geometric')
if args.cutout:
tag.append('cutout')
if args.color:
tag.append('color')

if args.latent_mode:
tag.append('Latent')

tag += args.train_dataset_dir
tag += args.val_dataset_dir

@@ -106,7 +81,7 @@
trainer = pl.Trainer(
devices=1,
max_epochs=args.epochs,
callbacks=[checkpoint_callback, swa_callback],
callbacks=[checkpoint_callback, swa_callback, lr_callback],
logger=wandb_logger if args.wandb else None,
)
trainer.fit(model, dm)
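Usage note (not part of the commit): with the new default --lr_scheduler step, the warm-up restart schedule has to be requested explicitly; T_0, T_mult, eta_max and T_up are hard-coded in configure_optimizers, while gamma comes from the CLI, for example:

python train.py --lr_scheduler cosine_anneling_warmup_restarts --gamma 0.5

With a logger configured (W&B is enabled by default), the new LearningRateMonitor callback logs the learning rate at every step, which makes the warm-up ramp and restart points visible in the run dashboard.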
