Add files via upload
aassxun authored Nov 11, 2022
1 parent af3c984 commit ee401bc
Showing 2 changed files with 388 additions and 0 deletions.
65 changes: 65 additions & 0 deletions EBV_Dim_Experiment/presets.py
@@ -0,0 +1,65 @@
import torch
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode


class ClassificationPresetTrain:
def __init__(
self,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
random_erase_prob=0.0,
):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment(interpolation=interpolation))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
trans.extend(
[
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
trans.append(transforms.RandomErasing(p=random_erase_prob))

self.transforms = transforms.Compose(trans)

def __call__(self, img):
return self.transforms(img)


class ClassificationPresetEval:
def __init__(
self,
crop_size,
resize_size=256,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
):

self.transforms = transforms.Compose(
[
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
)

def __call__(self, img):
return self.transforms(img)
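For reference, here is a minimal sketch of how these presets might be wired into a standard ImageFolder/DataLoader pipeline. The dataset paths, crop/resize sizes, and loader settings below are illustrative assumptions, not part of this commit:

import torch
from torchvision import datasets
from presets import ClassificationPresetTrain, ClassificationPresetEval

# Hypothetical data locations, chosen only for illustration.
train_set = datasets.ImageFolder(
    "data/train",
    transform=ClassificationPresetTrain(crop_size=224, auto_augment_policy="ta_wide", random_erase_prob=0.1),
)
val_set = datasets.ImageFolder(
    "data/val",
    transform=ClassificationPresetEval(crop_size=224, resize_size=256),
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)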
323 changes: 323 additions & 0 deletions EBV_Dim_Experiment/transforms.py
@@ -0,0 +1,323 @@
import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for mixup.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
"""

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."

self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace

def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and mixed one-hot target.
"""
if batch.ndim != 4:
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

if not self.inplace:
batch = batch.clone()
target = target.clone()

if target.ndim == 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

if torch.rand(1).item() >= self.p:
return batch, target

# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)

# Implemented as on mixup paper, page 3.
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
batch_rolled.mul_(1.0 - lambda_param)
batch.mul_(lambda_param).add_(batch_rolled)

target_rolled.mul_(1.0 - lambda_param)
target.mul_(lambda_param).add_(target_rolled)

return batch, target

def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_classes={num_classes}"
s += ", p={p}"
s += ", alpha={alpha}"
s += ", inplace={inplace}"
s += ")"
return s.format(**self.__dict__)


class RandomCutmix(torch.nn.Module):
"""Randomly apply Cutmix to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
<https://arxiv.org/abs/1905.04899>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for cutmix.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
"""

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."

self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace

def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and mixed one-hot target.
"""
if batch.ndim != 4:
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

if not self.inplace:
batch = batch.clone()
target = target.clone()

if target.ndim == 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

if torch.rand(1).item() >= self.p:
return batch, target

# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)

# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
W, H = F.get_image_size(batch)

r_x = torch.randint(W, (1,))
r_y = torch.randint(H, (1,))

r = 0.5 * math.sqrt(1.0 - lambda_param)
r_w_half = int(r * W)
r_h_half = int(r * H)

x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))

batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

target_rolled.mul_(1.0 - lambda_param)
target.mul_(lambda_param).add_(target_rolled)

return batch, target

def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_classes={num_classes}"
s += ", p={p}"
s += ", alpha={alpha}"
s += ", inplace={inplace}"
s += ")"
return s.format(**self.__dict__)
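# Usage note (a hedged sketch, not part of the classes above): in the torchvision reference
# training script these batch-level transforms are applied through a custom collate_fn,
# for example:
#
#     from torch.utils.data.dataloader import default_collate
#     from torchvision.transforms import RandomChoice
#
#     num_classes = 1000  # illustrative assumption
#     mixupcutmix = RandomChoice([
#         RandomMixup(num_classes, p=1.0, alpha=0.2),
#         RandomCutmix(num_classes, p=1.0, alpha=1.0),
#     ])
#     collate_fn = lambda batch: mixupcutmix(*default_collate(batch))
#     loader = torch.utils.data.DataLoader(dataset, batch_size=64, collate_fn=collate_fn)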



class RandomMixup2(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for mixup.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
"""

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."

self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, float]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
            Tuple[Tensor, Tensor, Tensor, float]: Mixed batch, original targets, rolled targets, and the mixing coefficient lambda.
"""
if batch.ndim != 4:
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

if not self.inplace:
batch = batch.clone()
target = target.clone()

if torch.rand(1).item() >= self.p:
            # Return a consistent 4-tuple even when no mixing is applied (lambda = 1.0).
            return batch, target, target, 1.0

# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)

# Implemented as on mixup paper, page 3.
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
batch_rolled.mul_(1.0 - lambda_param)
batch.mul_(lambda_param).add_(batch_rolled)

        # Unlike RandomMixup, the targets are intentionally left unmixed here; the original and
        # rolled targets are returned together with lambda so the caller can mix the losses instead.
        # target_rolled.mul_(1.0 - lambda_param)
        # target.mul_(lambda_param).add_(target_rolled)

return batch, target, target_rolled, lambda_param

def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_classes={num_classes}"
s += ", p={p}"
s += ", alpha={alpha}"
s += ", inplace={inplace}"
s += ")"
return s.format(**self.__dict__)


class RandomCutmix2(torch.nn.Module):
"""Randomly apply Cutmix to the provided batch and targets.
The class implements the data augmentations as described in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
<https://arxiv.org/abs/1905.04899>`_.
Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 0.5.
alpha (float): hyperparameter of the Beta distribution used for cutmix.
Default value is 1.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
"""

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."

self.num_classes = num_classes
self.p = p
self.alpha = alpha
self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, float]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )
Returns:
            Tuple[Tensor, Tensor, Tensor, float]: Mixed batch, original targets, rolled targets, and the mixing coefficient lambda.
"""
if batch.ndim != 4:
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

if not self.inplace:
batch = batch.clone()
target = target.clone()

if torch.rand(1).item() >= self.p:
            # Return a consistent 4-tuple even when no mixing is applied (lambda = 1.0).
            return batch, target, target, 1.0

# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1, 0)

# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
W, H = F.get_image_size(batch)

r_x = torch.randint(W, (1,))
r_y = torch.randint(H, (1,))

r = 0.5 * math.sqrt(1.0 - lambda_param)
r_w_half = int(r * W)
r_h_half = int(r * H)

x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))

batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        # Unlike RandomCutmix, the targets are intentionally left unmixed here; the original and
        # rolled targets are returned together with lambda so the caller can mix the losses instead.
        # target_rolled.mul_(1.0 - lambda_param)
        # target.mul_(lambda_param).add_(target_rolled)

return batch, target, target_rolled, lambda_param

def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_classes={num_classes}"
s += ", p={p}"
s += ", alpha={alpha}"
s += ", inplace={inplace}"
s += ")"
return s.format(**self.__dict__)
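Since RandomMixup2 and RandomCutmix2 hand back the unmixed targets plus lambda, the mixing happens in the loss rather than in the labels. A minimal sketch of a training step under that convention follows; model, optimizer, and train_loader are placeholders and not part of this commit:

import torch.nn.functional as nnF
from transforms import RandomMixup2

mixup = RandomMixup2(num_classes=1000, p=0.5, alpha=0.2)  # num_classes is illustrative

for images, labels in train_loader:  # hypothetical DataLoader yielding (float images, int64 labels)
    images, target_a, target_b, lam = mixup(images, labels)
    output = model(images)
    # Combine the two cross-entropy terms with lambda instead of mixing one-hot targets.
    loss = lam * nnF.cross_entropy(output, target_a) + (1.0 - lam) * nnF.cross_entropy(output, target_b)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()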
