Skip to content

Commit

Permalink
Rename generic typing alias ITEM_or_SEQ to TYPE_item_seq
Browse files Browse the repository at this point in the history
  • Loading branch information
jizong committed Nov 24, 2021
1 parent e46eb2b commit 8984f44
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 107 deletions.
2 changes: 1 addition & 1 deletion rising/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
* Painting Transforms
"""

from rising.transforms._affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine
from rising.transforms.abstract import (
BaseTransform,
BaseTransformMixin,
PerChannelTransformMixin,
PerSampleTransformMixin,
_AbstractTransform,
)
from rising.transforms.affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine
from rising.transforms.channel import ArgMax, OneHot
from rising.transforms.compose import Compose, DropoutCompose, OneOf
from rising.transforms.crop import CenterCrop, RandomCrop
Expand Down
78 changes: 66 additions & 12 deletions rising/transforms/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
from rising.random import AbstractParameter, DiscreteParameter
from rising.utils.mise import fix_seed_cxm, ntuple, nullcxm

T = TypeVar("T")
TYPE_item_seq = Union[T, Sequence[T]]

augment_callable = Callable[..., Any]
augment_axis_callable = Callable[[torch.Tensor, Union[float, Sequence]], Any]

__all__ = [
"_AbstractTransform",
"ITEM_or_SEQ",
"TYPE_item_seq",
"BaseTransform",
"PerChannelTransformMixin",
"PerSampleTransformMixin",
"BaseTransformMixin",
]

T = TypeVar("T")
ITEM_or_SEQ = Union[T, Sequence[T]]

augment_callable = Callable[..., Any]
augment_axis_callable = Callable[[torch.Tensor, Union[float, Sequence]], Any]


class _AbstractTransform(nn.Module):
"""Base class for all transforms"""
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(
keys: Sequence[str] = ("data",),
grad: bool = False,
paired_kw_names: Sequence[str] = (),
augment_fn_names: Sequence[str],
augment_fn_names: Sequence[str] = (),
per_sample: bool = True,
**kwargs,
):
Expand Down Expand Up @@ -227,7 +227,7 @@ def sample_for_batch(self, name: str, batch_size: int) -> Optional[Union[Any, Se
else:
return elem # either a single scalar value or None

def register_paired_attribute(self, name: str, value: ITEM_or_SEQ[T]):
def register_paired_attribute(self, name: str, value: TYPE_item_seq[T]):
if name in self._paired_kw_names:
raise ValueError(f"{name} has been registered in self._pair_kwarg_names")
if name not in self._augment_fn_names:
Expand Down Expand Up @@ -367,14 +367,15 @@ class PerChannelTransformMixin(BaseTransformMixin):
result in different augmentations per channel and key.
"""

def __init__(self, *, per_channel: bool, **kwargs):
def __init__(self, *, per_channel: bool, p: float = 1, **kwargs):
    """Mixin constructor adding per-channel augmentation control.

    Args:
        per_channel: if True, the augmentation is applied and its random
            parameters sampled independently for every channel.
        p: probability of augmenting any given channel; channels that are
            not selected are passed through unchanged.
        **kwargs: forwarded unchanged to the next class in the MRO
            (base transform parameters, e.g. ``keys`` and ``grad``).
    """
    super().__init__(**kwargs)
    # The two attribute writes below are independent of each other.
    self.p = p
    self.per_channel = per_channel

def forward(self, **data) -> dict:
"""
Expand All @@ -398,8 +399,61 @@ def forward(self, **data) -> dict:
with self.random_cxm(seed + c):
kwargs = {k: getattr(self, k) for k in self._augment_fn_names if k not in self._paired_kw_names}
kwargs.update(self.get_pair_kwargs(key))

out.append(self.augment_fn(data[key][:, c].unsqueeze(1), **kwargs))
if torch.rand(1).item() < self.p:
out.append(self.augment_fn(data[key][:, c].unsqueeze(1), **kwargs))
else:
out.append(data[key][:, c].unsqueeze(1))
data[key] = torch.cat(out, dim=1)

return data


class PerSamplePerChannelTransformMixin(BaseTransformMixin):
    """Mixin applying an augmentation independently per sample and per channel.

    Combines the behaviour of ``PerSampleTransformMixin`` and
    ``PerChannelTransformMixin``: when both ``per_sample`` and ``per_channel``
    are enabled, new augmentation parameters are sampled for every
    (sample, channel) pair. When only one axis is enabled, the call is
    delegated to the corresponding single-axis mixin.
    """

    def __init__(self, *, per_channel: bool, p_channel: float = 1, per_sample: bool, p_sample: float = 1, **kwargs):
        """
        Args:
            per_channel: sample/apply the augmentation independently per channel.
            p_channel: probability of augmenting a given channel.
            per_sample: sample/apply the augmentation independently per sample.
            p_sample: probability of augmenting a given sample.
            **kwargs: forwarded to the base transform (e.g. ``keys``, ``grad``).
        """
        super().__init__(**kwargs)
        self.per_channel = per_channel
        self.p_channel = p_channel

        self.per_sample = per_sample
        self.p_sample = p_sample

    def forward(self, **data) -> dict:
        """
        Apply the transformation.

        Args:
            data: dict with tensors of shape ``(batch, channel, ...)``

        Returns:
            dict: dict with augmented data
        """
        if not self.per_channel:
            # Only per-sample behaviour requested: delegate to the
            # per-sample mixin, which reads the probability from ``self.p``.
            self.p = self.p_sample
            return PerSampleTransformMixin.forward(self, **data)
        if not self.per_sample:
            # Only per-channel behaviour requested: delegate likewise.
            self.p = self.p_channel
            return PerChannelTransformMixin.forward(self, **data)

        seed = int(torch.randint(0, int(1e16), (1,)))

        for key in self.keys:
            batch_size, channel_dim = data[key].shape[0:2]
            for b in range(batch_size):
                cur_data = data[key]
                processed_channels = []
                for c in range(channel_dim):
                    # Distinct seed per (sample, channel) pair. The additive
                    # form ``seed + b + c`` collides, e.g. (b=1, c=0) and
                    # (b=0, c=1) would share random parameters.
                    with self.random_cxm(seed + b * channel_dim + c):
                        kwargs = {k: getattr(self, k) for k in self._augment_fn_names if k not in self._paired_kw_names}
                        kwargs.update(self.get_pair_kwargs(key))
                        # Use the per-channel probability; ``self.p`` is only
                        # set on this class by the delegation branches above.
                        if torch.rand(1).item() < self.p_channel:
                            processed_channels.append(self.augment_fn(cur_data[b, c][None, None, ...], **kwargs))
                        else:
                            # Keep the *current sample's* channel unchanged
                            # (shape (1, 1, ...)), matching the augmented
                            # branch; previously the whole batch's channel
                            # was appended with a mismatched shape.
                            processed_channels.append(cur_data[b, c][None, None, ...])
                data[key][b] = torch.cat(processed_channels, dim=1)[0]

        return data
67 changes: 26 additions & 41 deletions rising/transforms/_affine.py → rising/transforms/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from rising.transforms.abstract import BaseTransform, BaseTransformMixin, ITEM_or_SEQ
from rising.transforms.abstract import BaseTransform, BaseTransformMixin, TYPE_item_seq
from rising.transforms.functional.affine import AffineParamType, affine_image_transform, parametrize_matrix
from rising.utils.affine import matrix_to_cartesian, matrix_to_homogeneous
from rising.utils.checktype import check_scalar
Expand Down Expand Up @@ -30,12 +30,11 @@ def __init__(
grad: bool = False,
output_size: Optional[tuple] = None,
adjust_size: bool = False,
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
reverse_order: bool = False,
per_sample: bool = True,
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -67,19 +66,13 @@ def __init__(
batch order [(D,)H,W]
per_sample: sample different values for each element in the batch.
The transform is still applied in a batched wise fashion.
**kwargs: additional keyword arguments passed to the
affine transform
"""
super().__init__(
augment_fn=affine_image_transform,
keys=keys,
per_sample=per_sample,
augment_fn_names=(),
grad=grad,
seeded=True,
**kwargs,
)
self.kwargs = kwargs
self.matrix = matrix
self.register_sampler("output_size", output_size)
self.adjust_size = adjust_size
Expand Down Expand Up @@ -212,7 +205,6 @@ def __radd__(self, other) -> BaseTransform:
interpolation_mode=self.interpolation_mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
# **self.kwargs,
)

return _StackedAffine(
Expand All @@ -224,7 +216,6 @@ def __radd__(self, other) -> BaseTransform:
interpolation_mode=other.interpolation_mode,
padding_mode=other.padding_mode,
align_corners=other.align_corners,
# **other.kwargs,
)


Expand All @@ -241,11 +232,11 @@ def __init__(
grad: bool = False,
output_size: Optional[tuple] = None,
adjust_size: bool = False,
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
reverse_order: bool = False,
**kwargs,
per_sample=True,
):
"""
Args:
Expand Down Expand Up @@ -296,7 +287,7 @@ def __init__(
padding_mode=padding_mode,
align_corners=align_corners,
reverse_order=reverse_order,
**kwargs,
per_sample=per_sample,
)

self.transforms = transforms
Expand Down Expand Up @@ -341,13 +332,12 @@ def __init__(
grad: bool = False,
output_size: Optional[tuple] = None,
adjust_size: bool = False,
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[Optional[bool]] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[Optional[bool]] = False,
reverse_order: bool = False,
per_sample: bool = True,
p: float = 1,
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -408,8 +398,7 @@ def __init__(
per_sample: sample different values for each element in the batch.
The transform is still applied in a batched wise fashion.
p: float, the probability of applying the transformation on batches.
**kwargs: additional keyword arguments passed to the
affine transform
"""
super().__init__(
keys=keys,
Expand All @@ -422,7 +411,6 @@ def __init__(
reverse_order=reverse_order,
per_sample=per_sample,
p=p,
**kwargs,
)
self.register_sampler("scale", scale)
self.register_sampler("rotation", rotation)
Expand All @@ -447,6 +435,7 @@ def assemble_matrix(self, **data) -> torch.Tensor:
ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim
device = data[self.keys[0]].device
dtype = data[self.keys[0]].dtype

seed = int(torch.randint(0, int(1e6), (1,)))

scale = self.sample_for_batch_with_prob("scale", batch_size, default_value=1.0, seed=seed)
Expand Down Expand Up @@ -502,13 +491,12 @@ def __init__(
degree: bool = False,
output_size: Optional[tuple] = None,
adjust_size: bool = False,
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
reverse_order: bool = False,
per_sample: bool = True,
p: float = 1,
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -544,8 +532,6 @@ def __init__(
transformation to conform to the pytorch convention:
transformation params order [W,H(,D)] and
batch order [(D,)H,W]
**kwargs: additional keyword arguments passed to the
affine transform
"""
super().__init__(
scale=None,
Expand All @@ -563,7 +549,6 @@ def __init__(
reverse_order=reverse_order,
per_sample=per_sample,
p=p,
**kwargs,
)


Expand All @@ -582,9 +567,9 @@ def __init__(
grad: bool = False,
output_size: Optional[tuple] = None,
adjust_size: bool = False,
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
unit: str = "pixel",
reverse_order: bool = False,
per_sample: bool = True,
Expand Down Expand Up @@ -681,9 +666,9 @@ def __init__(
grad: bool = False,
output_size: Optional[tuple] = None,
adjust_size: bool = False,
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
reverse_order: bool = False,
per_sample: bool = True,
p: float = 1,
Expand Down Expand Up @@ -758,9 +743,9 @@ def __init__(
size: Union[int, Tuple[int]],
keys: Sequence[str] = ("data",),
grad: bool = False,
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
reverse_order: bool = False,
**kwargs,
):
Expand Down
20 changes: 10 additions & 10 deletions rising/transforms/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.nn import functional as F

from rising.random.utils import fix_random_seed_ctx
from rising.transforms.abstract import ITEM_or_SEQ, _AbstractTransform
from rising.transforms.abstract import TYPE_item_seq, _AbstractTransform
from rising.transforms.functional import center_crop, random_crop
from rising.transforms.kernel import GaussianSmoothing
from rising.utils.affine import get_batched_eye, matrix_to_homogeneous
Expand All @@ -30,9 +30,9 @@ class GridTransform(_AbstractTransform):
def __init__(
self,
keys: Sequence[str] = ("data",),
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
grad: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -179,9 +179,9 @@ def __init__(
alpha: float,
dim: int = 2,
keys: Sequence[str] = ("data",),
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
grad: bool = False,
per_sample: bool = True,
**kwargs,
Expand Down Expand Up @@ -226,9 +226,9 @@ def __init__(
self,
scale: Tuple[float, float, float],
keys: Sequence[str] = ("data",),
interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
padding_mode: ITEM_or_SEQ[str] = "zeros",
align_corners: ITEM_or_SEQ[bool] = False,
interpolation_mode: TYPE_item_seq[str] = "bilinear",
padding_mode: TYPE_item_seq[str] = "zeros",
align_corners: TYPE_item_seq[bool] = False,
grad: bool = False,
**kwargs,
):
Expand Down
Loading

0 comments on commit 8984f44

Please sign in to comment.