FEAT: add color augmentation and add argparser option
GunwooHan committed Jun 7, 2024
1 parent 2736202 commit 8609ead
Showing 2 changed files with 122 additions and 17 deletions.
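The new --bright and --contrast switches (along with the geometric ones) are read from self.args inside create_ensemble_transform below; the parser.add_argument lines themselves fall outside the hunks shown in this commit, so the following is only an assumed sketch of how those flags might be declared:

    # Hypothetical flag definitions matching the self.args attributes used in create_ensemble_transform
    parser.add_argument('--flip', action='store_true')
    parser.add_argument('--rotation', action='store_true')
    parser.add_argument('--translation', action='store_true')
    parser.add_argument('--cutout', action='store_true')
    parser.add_argument('--resize', action='store_true')
    parser.add_argument('--bright', action='store_true')
    parser.add_argument('--contrast', action='store_true')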
71 changes: 71 additions & 0 deletions augmentation_test.py
@@ -0,0 +1,71 @@
import torch.nn as nn
import torch

import torchvision


class ContrastAdjust(nn.Module):
    def __init__(self, contrast_range=(0, 1)):
super().__init__()
# assert len(contrast_range) == 2, "Input tensor must include min, max value"
        if len(contrast_range) == 1:
            self.contrast_range = [0, contrast_range[0]]
        elif len(contrast_range) == 2:
            self.contrast_range = contrast_range
        else:
            raise ValueError("contrast_range must include min and max values; only 1 or 2 elements are supported")

def __call__(self, input_tensor):
batch_size, channels, height, width = input_tensor.shape
        self.contrast_factor = (torch.rand(1) * (self.contrast_range[1] - self.contrast_range[0])
                                + self.contrast_range[0])

output_tensor = torch.empty_like(input_tensor)
for i in range(batch_size):
img = input_tensor[i]
mean = img.mean([1, 2], keepdim=True)
adjusted_img = (img - mean) * self.contrast_factor + mean
output_tensor[i] = adjusted_img.clamp(min=-1, max=1)

return output_tensor


class BrightnessAdjust(nn.Module):
def __init__(self, brightness_range=(0, 0.5)):
super().__init__()
# assert len(contrast_range) == 2, "Input tensor must include min, max value"
        if len(brightness_range) == 1:
            self.brightness_range = [0, brightness_range[0]]
        elif len(brightness_range) == 2:
            self.brightness_range = brightness_range
        else:
            raise ValueError("brightness_range must include min and max values; only 1 or 2 elements are supported")

def __call__(self, input_tensor):
batch_size, channels, height, width = input_tensor.shape
        self.brightness_factor = (torch.rand(1) * (self.brightness_range[1] - self.brightness_range[0])
                                  + self.brightness_range[0])

output_tensor = torch.empty_like(input_tensor)
for i in range(batch_size):
img = input_tensor[i]
max_value = img.max()
adjusted_img = img * self.brightness_factor
output_tensor[i] = adjusted_img.clamp(max=max_value)

return output_tensor


if __name__ == '__main__':
inputs = torch.randn(1, 4, 64, 64)

outputs = []

# transform = ContrastAdjust([0, 1])
transform = BrightnessAdjust([0, 1])
for i in range(8):
outputs.append(transform(inputs))

torchvision.utils.save_image(torch.cat([inputs] + outputs, dim=0), "aug_comp.png", nrow=9, normalize=True)
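For intuition about the formulas above: contrast adjustment is (img - mean) * factor + mean, so a factor of 0 collapses every pixel to its channel mean and a factor of 1 returns the image unchanged, while brightness adjustment simply scales the pixels by the sampled factor and clamps to the original maximum. A minimal sketch with made-up values:

    import torch

    img = torch.tensor([[[0.2, 0.8],
                         [0.4, 0.6]]])           # one channel, 2x2, illustrative values only
    mean = img.mean([1, 2], keepdim=True)        # 0.5
    factor = 0.5
    adjusted = (img - mean) * factor + mean      # pixels pulled halfway toward the mean
    print(adjusted)                              # tensor([[[0.3500, 0.6500], [0.4500, 0.5500]]])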
68 changes: 51 additions & 17 deletions single_reconstruction.py
@@ -7,13 +7,14 @@
from PIL import Image
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DModel
from pytorch_lightning.loggers import WandbLogger
from torch import optim
from torch import optim, nn
from torch.utils.data import DataLoader
from torchmetrics.image import PeakSignalNoiseRatio, LearnedPerceptualImagePatchSimilarity
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR, ExponentialLR, ReduceLROnPlateau

from e_latent_lpips import e_latent_lpips
from augmentation_test import BrightnessAdjust, ContrastAdjust

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42)
@@ -93,14 +94,31 @@ def __init__(self, args):
if self.ensemble_mode:
self.ensemble_transform = self.create_ensemble_transform()

def model_trainalbe_set(self):
for name, param in self.text_encoder.named_parameters():
param.requires_grad = False
for name, param in self.vae.named_parameters():
param.requires_grad = False
for name, param in self.lpips.named_parameters():
param.requires_grad = False

def create_ensemble_transform(self):
transform = []
transform.append(transforms.RandomHorizontalFlip())
transform.append(transforms.RandomRotation(degrees=[90, 90],
interpolation=torchvision.transforms.InterpolationMode.BILINEAR))
transform.append(transforms.RandomAffine(degrees=30, translate=(0.2, 0.2)))
transform.append(transforms.RandomErasing(p=1.0, scale=(0.02, 0.33), ratio=(0.3, 3.3)))
transform.append(transforms.RandomResizedCrop(size=64, scale=(0.8, 1.2), ratio=(1.0, 1.0)))
if self.args.flip:
transform.append(transforms.RandomHorizontalFlip())
if self.args.rotation:
transform.append(transforms.RandomRotation(degrees=[90, 90],
interpolation=torchvision.transforms.InterpolationMode.BILINEAR))
if self.args.translation:
transform.append(transforms.RandomAffine(degrees=30, translate=(0.2, 0.2)))
if self.args.cutout:
transform.append(transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)))
if self.args.resize:
transform.append(transforms.RandomResizedCrop(size=64, scale=(0.8, 1.2), ratio=(1.0, 1.0)))
if self.args.bright:
transform.append(BrightnessAdjust(1))
if self.args.contrast:
transform.append(ContrastAdjust(1))
return transforms.Compose(transform)
# Isotropic scaling (uniform scaling)

@@ -113,14 +131,6 @@ def create_ensemble_transform(self):
# # Random contrast
# random_contrast = transforms.ColorJitter(contrast=0.5)

def model_trainalbe_set(self):
for name, param in self.text_encoder.named_parameters():
param.requires_grad = False
for name, param in self.vae.named_parameters():
param.requires_grad = False
for name, param in self.lpips.named_parameters():
param.requires_grad = False

def training_step(self, batch, batch_idx):
x, y = batch
y = y * 0.18215
@@ -143,8 +153,10 @@ def training_step(self, batch, batch_idx):
transformed_y_y_hat = self.ensemble_transform(torch.cat([y, y_hat], dim=0))
y_transformed = transformed_y_y_hat[:y.size(0), ...]
y_hat_transformed = transformed_y_y_hat[y.size(0):, ...]
lpips_loss = self.lpips(2 * (y_transformed - y_transformed.min()) / (y_transformed.max() - y_transformed.min()) - 1,
2 * (y_hat_transformed - y_hat_transformed.min()) / (y_hat_transformed.max() - y_hat_transformed.min()) - 1).flatten()
lpips_loss = self.lpips(
2 * (y_transformed - y_transformed.min()) / (y_transformed.max() - y_transformed.min()) - 1,
2 * (y_hat_transformed - y_hat_transformed.min()) / (
y_hat_transformed.max() - y_hat_transformed.min()) - 1).flatten()
else:
lpips_loss = self.lpips(2 * (y - y.min()) / (y.max() - y.min()) - 1,
2 * (y_hat - y_hat.min()) / (y_hat.max() - y_hat.min()) - 1).flatten()
@@ -275,6 +287,28 @@ def __getitem__(self, idx):
return self.noise_image, self.target_image


class ContrastAdjust(nn.Module):
def __init__(self, contrast_factor=0):
super().__init__()
self.contrast_factor = contrast_factor

def __call__(self, input_tensor):
assert len(input_tensor.shape) == 4, "Input tensor must be 4-dimensional"

batch_size, channels, height, width = input_tensor.shape

        # Apply the contrast adjustment to each image in the 4-dimensional input tensor.
output_tensor = torch.empty_like(input_tensor)
for i in range(batch_size):
img = input_tensor[i]
            # Compute the per-channel mean of the image.
mean = img.mean([1, 2], keepdim=True)
adjusted_img = (img - mean) * self.contrast_factor + mean
            output_tensor[i] = adjusted_img.clamp(0, 1)  # Clamp values to the [0, 1] range.

return output_tensor


if __name__ == '__main__':
pl.seed_everything(args.seed)
