From 8984f44ccbdb5abdbddd323b9faf88634839c7b3 Mon Sep 17 00:00:00 2001 From: jizong Date: Wed, 24 Nov 2021 17:49:30 -0500 Subject: [PATCH] rename generic typing --- rising/transforms/__init__.py | 2 +- rising/transforms/abstract.py | 78 +++++++++++++++++---- rising/transforms/{_affine.py => affine.py} | 67 +++++++----------- rising/transforms/grid.py | 20 +++--- rising/transforms/intensity.py | 18 ++--- rising/transforms/kernel.py | 18 ++--- rising/transforms/pad.py | 6 +- rising/transforms/sitk.py | 6 +- rising/transforms/spatial.py | 18 ++--- rising/utils/mise.py | 12 ++-- tests/transforms/test_affine.py | 2 +- tests/transforms/test_affine_transform.py | 2 +- 12 files changed, 142 insertions(+), 107 deletions(-) rename rising/transforms/{_affine.py => affine.py} (95%) diff --git a/rising/transforms/__init__.py b/rising/transforms/__init__.py index 09c1508..9315217 100644 --- a/rising/transforms/__init__.py +++ b/rising/transforms/__init__.py @@ -19,7 +19,6 @@ * Painting Transforms """ -from rising.transforms._affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine from rising.transforms.abstract import ( BaseTransform, BaseTransformMixin, @@ -27,6 +26,7 @@ PerSampleTransformMixin, _AbstractTransform, ) +from rising.transforms.affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine from rising.transforms.channel import ArgMax, OneHot from rising.transforms.compose import Compose, DropoutCompose, OneOf from rising.transforms.crop import CenterCrop, RandomCrop diff --git a/rising/transforms/abstract.py b/rising/transforms/abstract.py index 135d790..3a070e9 100644 --- a/rising/transforms/abstract.py +++ b/rising/transforms/abstract.py @@ -7,21 +7,21 @@ from rising.random import AbstractParameter, DiscreteParameter from rising.utils.mise import fix_seed_cxm, ntuple, nullcxm +T = TypeVar("T") +TYPE_item_seq = Union[T, Sequence[T]] + +augment_callable = Callable[..., Any] +augment_axis_callable = 
Callable[[torch.Tensor, Union[float, Sequence]], Any] + __all__ = [ "_AbstractTransform", - "ITEM_or_SEQ", + "TYPE_item_seq", "BaseTransform", "PerChannelTransformMixin", "PerSampleTransformMixin", "BaseTransformMixin", ] -T = TypeVar("T") -ITEM_or_SEQ = Union[T, Sequence[T]] - -augment_callable = Callable[..., Any] -augment_axis_callable = Callable[[torch.Tensor, Union[float, Sequence]], Any] - class _AbstractTransform(nn.Module): """Base class for all transforms""" @@ -174,7 +174,7 @@ def __init__( keys: Sequence[str] = ("data",), grad: bool = False, paired_kw_names: Sequence[str] = (), - augment_fn_names: Sequence[str], + augment_fn_names: Sequence[str] = (), per_sample: bool = True, **kwargs, ): @@ -227,7 +227,7 @@ def sample_for_batch(self, name: str, batch_size: int) -> Optional[Union[Any, Se else: return elem # either a single scalar value or None - def register_paired_attribute(self, name: str, value: ITEM_or_SEQ[T]): + def register_paired_attribute(self, name: str, value: TYPE_item_seq[T]): if name in self._paired_kw_names: raise ValueError(f"{name} has been registered in self._pair_kwarg_names") if name not in self._augment_fn_names: @@ -367,7 +367,7 @@ class PerChannelTransformMixin(BaseTransformMixin): result in different augmentations per channel and key. 
""" - def __init__(self, *, per_channel: bool, **kwargs): + def __init__(self, *, per_channel: bool, p: float = 1, **kwargs): """ Args: per_channel:bool parameter to perform per_channel operation @@ -375,6 +375,7 @@ def __init__(self, *, per_channel: bool, **kwargs): """ super().__init__(**kwargs) self.per_channel = per_channel + self.p = p def forward(self, **data) -> dict: """ @@ -398,8 +399,61 @@ def forward(self, **data) -> dict: with self.random_cxm(seed + c): kwargs = {k: getattr(self, k) for k in self._augment_fn_names if k not in self._paired_kw_names} kwargs.update(self.get_pair_kwargs(key)) - - out.append(self.augment_fn(data[key][:, c].unsqueeze(1), **kwargs)) + if torch.rand(1).item() < self.p: + out.append(self.augment_fn(data[key][:, c].unsqueeze(1), **kwargs)) + else: + out.append(data[key][:, c].unsqueeze(1)) data[key] = torch.cat(out, dim=1) return data + + +class PerSamplePerChannelTransformMixin(BaseTransformMixin): + def __init__(self, *, per_channel: bool, p_channel: float = 1, per_sample: bool, p_sample: float = 1, **kwargs): + """ + Args: + per_channel:bool parameter to perform per_channel operation + kwargs: base parameters + """ + super().__init__(**kwargs) + self.per_channel = per_channel + self.p_channel = p_channel + + self.per_sample = per_sample + self.p_sample = p_sample + + def forward(self, **data) -> dict: + """ + Apply transformation + + Args: + data: dict with tensors + + Returns: + dict: dict with augmented data + """ + if not self.per_channel: + self.p = self.p_sample + return PerSampleTransformMixin.forward(self, **data) + if not self.per_sample: + self.p = self.p_channel + return PerChannelTransformMixin.forward(self, **data) + + seed = int(torch.randint(0, int(1e16), (1,))) + + for key in self.keys: + batch_size, channel_dim = data[key].shape[0:2] + for b in range(batch_size): + cur_data = data[key] + processed_batch = [] + for c in range(channel_dim): + with self.random_cxm(seed + b + c): + kwargs = {k: getattr(self, k) for 
k in self._augment_fn_names if k not in self._paired_kw_names}
+                        kwargs.update(self.get_pair_kwargs(key))
+                        if torch.rand(1).item() < self.p_channel:
+                            processed_batch.append(self.augment_fn(cur_data[b, c][None, None, ...], **kwargs))
+                        else:
+                            processed_batch.append(cur_data[b, c][None, None, ...])
+                data[key][b] = torch.cat(processed_batch, dim=1)[0]
+
+        return data
diff --git a/rising/transforms/_affine.py b/rising/transforms/affine.py
similarity index 95%
rename from rising/transforms/_affine.py
rename to rising/transforms/affine.py
index ec3fc45..409b181 100644
--- a/rising/transforms/_affine.py
+++ b/rising/transforms/affine.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from rising.transforms.abstract import BaseTransform, BaseTransformMixin, ITEM_or_SEQ
+from rising.transforms.abstract import BaseTransform, BaseTransformMixin, TYPE_item_seq
 from rising.transforms.functional.affine import AffineParamType, affine_image_transform, parametrize_matrix
 from rising.utils.affine import matrix_to_cartesian, matrix_to_homogeneous
 from rising.utils.checktype import check_scalar
@@ -30,12 +30,11 @@ def __init__(
         grad: bool = False,
         output_size: Optional[tuple] = None,
         adjust_size: bool = False,
-        interpolation_mode: ITEM_or_SEQ[str] = "bilinear",
-        padding_mode: ITEM_or_SEQ[str] = "zeros",
-        align_corners: ITEM_or_SEQ[bool] = False,
+        interpolation_mode: TYPE_item_seq[str] = "bilinear",
+        padding_mode: TYPE_item_seq[str] = "zeros",
+        align_corners: TYPE_item_seq[bool] = False,
         reverse_order: bool = False,
         per_sample: bool = True,
-        **kwargs,
     ):
         """
         Args:
@@ -67,19 +66,13 @@ def __init__(
                 batch order [(D,)H,W]
             per_sample: sample different values for each element in the batch.
                 The transform is still applied in a batched wise fashion.
- **kwargs: additional keyword arguments passed to the - affine transform """ super().__init__( augment_fn=affine_image_transform, keys=keys, per_sample=per_sample, - augment_fn_names=(), grad=grad, - seeded=True, - **kwargs, ) - self.kwargs = kwargs self.matrix = matrix self.register_sampler("output_size", output_size) self.adjust_size = adjust_size @@ -212,7 +205,6 @@ def __radd__(self, other) -> BaseTransform: interpolation_mode=self.interpolation_mode, padding_mode=self.padding_mode, align_corners=self.align_corners, - # **self.kwargs, ) return _StackedAffine( @@ -224,7 +216,6 @@ def __radd__(self, other) -> BaseTransform: interpolation_mode=other.interpolation_mode, padding_mode=other.padding_mode, align_corners=other.align_corners, - # **other.kwargs, ) @@ -241,11 +232,11 @@ def __init__( grad: bool = False, output_size: Optional[tuple] = None, adjust_size: bool = False, - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] = "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, reverse_order: bool = False, - **kwargs, + per_sample=True, ): """ Args: @@ -296,7 +287,7 @@ def __init__( padding_mode=padding_mode, align_corners=align_corners, reverse_order=reverse_order, - **kwargs, + per_sample=per_sample, ) self.transforms = transforms @@ -341,13 +332,12 @@ def __init__( grad: bool = False, output_size: Optional[tuple] = None, adjust_size: bool = False, - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[Optional[bool]] = False, + interpolation_mode: TYPE_item_seq[str] = "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[Optional[bool]] = False, reverse_order: bool = False, per_sample: bool = True, p: float = 1, - **kwargs, ): """ Args: @@ -408,8 +398,7 @@ def __init__( 
per_sample: sample different values for each element in the batch. The transform is still applied in a batched wise fashion. p: float, the probability of applying the transformation on batches. - **kwargs: additional keyword arguments passed to the - affine transform + """ super().__init__( keys=keys, @@ -422,7 +411,6 @@ def __init__( reverse_order=reverse_order, per_sample=per_sample, p=p, - **kwargs, ) self.register_sampler("scale", scale) self.register_sampler("rotation", rotation) @@ -447,6 +435,7 @@ def assemble_matrix(self, **data) -> torch.Tensor: ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim device = data[self.keys[0]].device dtype = data[self.keys[0]].dtype + seed = int(torch.randint(0, int(1e6), (1,))) scale = self.sample_for_batch_with_prob("scale", batch_size, default_value=1.0, seed=seed) @@ -502,13 +491,12 @@ def __init__( degree: bool = False, output_size: Optional[tuple] = None, adjust_size: bool = False, - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] = "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, reverse_order: bool = False, per_sample: bool = True, p: float = 1, - **kwargs, ): """ Args: @@ -544,8 +532,6 @@ def __init__( transformation to conform to the pytorch convention: transformation params order [W,H(,D)] and batch order [(D,)H,W] - **kwargs: additional keyword arguments passed to the - affine transform """ super().__init__( scale=None, @@ -563,7 +549,6 @@ def __init__( reverse_order=reverse_order, per_sample=per_sample, p=p, - **kwargs, ) @@ -582,9 +567,9 @@ def __init__( grad: bool = False, output_size: Optional[tuple] = None, adjust_size: bool = False, - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] 
= "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, unit: str = "pixel", reverse_order: bool = False, per_sample: bool = True, @@ -681,9 +666,9 @@ def __init__( grad: bool = False, output_size: Optional[tuple] = None, adjust_size: bool = False, - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] = "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, reverse_order: bool = False, per_sample: bool = True, p: float = 1, @@ -758,9 +743,9 @@ def __init__( size: Union[int, Tuple[int]], keys: Sequence[str] = ("data",), grad: bool = False, - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] = "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, reverse_order: bool = False, **kwargs, ): diff --git a/rising/transforms/grid.py b/rising/transforms/grid.py index f5f9109..fde5720 100644 --- a/rising/transforms/grid.py +++ b/rising/transforms/grid.py @@ -6,7 +6,7 @@ from torch.nn import functional as F from rising.random.utils import fix_random_seed_ctx -from rising.transforms.abstract import ITEM_or_SEQ, _AbstractTransform +from rising.transforms.abstract import TYPE_item_seq, _AbstractTransform from rising.transforms.functional import center_crop, random_crop from rising.transforms.kernel import GaussianSmoothing from rising.utils.affine import get_batched_eye, matrix_to_homogeneous @@ -30,9 +30,9 @@ class GridTransform(_AbstractTransform): def __init__( self, keys: Sequence[str] = ("data",), - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] 
= "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, grad: bool = False, **kwargs, ): @@ -179,9 +179,9 @@ def __init__( alpha: float, dim: int = 2, keys: Sequence[str] = ("data",), - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] = "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, grad: bool = False, per_sample: bool = True, **kwargs, @@ -226,9 +226,9 @@ def __init__( self, scale: Tuple[float, float, float], keys: Sequence[str] = ("data",), - interpolation_mode: ITEM_or_SEQ[str] = "bilinear", - padding_mode: ITEM_or_SEQ[str] = "zeros", - align_corners: ITEM_or_SEQ[bool] = False, + interpolation_mode: TYPE_item_seq[str] = "bilinear", + padding_mode: TYPE_item_seq[str] = "zeros", + align_corners: TYPE_item_seq[bool] = False, grad: bool = False, **kwargs, ): diff --git a/rising/transforms/intensity.py b/rising/transforms/intensity.py index 070faa3..32f244a 100644 --- a/rising/transforms/intensity.py +++ b/rising/transforms/intensity.py @@ -4,9 +4,9 @@ from rising.transforms.abstract import ( BaseTransform, BaseTransformMixin, - ITEM_or_SEQ, PerChannelTransformMixin, PerSampleTransformMixin, + TYPE_item_seq, augment_callable, ) from rising.transforms.functional.intensity import ( @@ -48,8 +48,8 @@ class Clamp(BaseTransformMixin, BaseTransform): def __init__( self, - min: ITEM_or_SEQ[Union[float, AbstractParameter]], - max: ITEM_or_SEQ[Union[float, AbstractParameter]], + min: TYPE_item_seq[Union[float, AbstractParameter]], + max: TYPE_item_seq[Union[float, AbstractParameter]], keys: Sequence = ("data",), grad: bool = False, ): @@ -74,8 +74,8 @@ def __init__( class NormRange(PerSampleTransformMixin, BaseTransform): def __init__( self, - min: ITEM_or_SEQ[Union[float, AbstractParameter]], - max: ITEM_or_SEQ[Union[float, 
AbstractParameter]], + min: TYPE_item_seq[Union[float, AbstractParameter]], + max: TYPE_item_seq[Union[float, AbstractParameter]], keys: Sequence = ("data",), per_channel: bool = True, per_sample=True, @@ -107,8 +107,8 @@ class NormPercentile(PerSampleTransformMixin, BaseTransform): def __init__( self, - min: ITEM_or_SEQ[Union[float, AbstractParameter]], - max: ITEM_or_SEQ[Union[float, AbstractParameter]], + min: TYPE_item_seq[Union[float, AbstractParameter]], + max: TYPE_item_seq[Union[float, AbstractParameter]], keys: Sequence[str] = ("data",), grad: bool = False, per_channel: bool = True, @@ -204,8 +204,8 @@ class NormMeanStd(PerSampleTransformMixin, BaseTransform): def __init__( self, - mean: ITEM_or_SEQ[Union[float, Sequence[float]]], - std: ITEM_or_SEQ[Union[float, Sequence[float]]], + mean: TYPE_item_seq[Union[float, Sequence[float]]], + std: TYPE_item_seq[Union[float, Sequence[float]]], keys: Sequence[str] = ("data",), per_channel: bool = True, per_sample=True, diff --git a/rising/transforms/kernel.py b/rising/transforms/kernel.py index 531a561..4badf10 100644 --- a/rising/transforms/kernel.py +++ b/rising/transforms/kernel.py @@ -7,7 +7,7 @@ from rising.utils import check_scalar from rising.utils.mise import ntuple -from .abstract import ITEM_or_SEQ, _AbstractTransform +from .abstract import TYPE_item_seq, _AbstractTransform __all__ = ["KernelTransform", "GaussianSmoothing"] @@ -21,10 +21,10 @@ class KernelTransform(_AbstractTransform): def __init__( self, in_channels: int, - kernel_size: ITEM_or_SEQ[int], + kernel_size: TYPE_item_seq[int], dim: int = 2, - stride: ITEM_or_SEQ[int] = 1, - padding: ITEM_or_SEQ[int] = 0, + stride: TYPE_item_seq[int] = 1, + padding: TYPE_item_seq[int] = 0, padding_mode: str = "zero", keys: Sequence[str] = ("data",), grad: bool = False, @@ -129,12 +129,12 @@ class GaussianSmoothing(KernelTransform): def __init__( self, in_channels: int, - kernel_size: ITEM_or_SEQ[int], - std: ITEM_or_SEQ[float], + kernel_size: 
TYPE_item_seq[int], + std: TYPE_item_seq[float], dim: int = 2, - stride: ITEM_or_SEQ[int] = 1, - padding: ITEM_or_SEQ[int] = 0, - padding_mode: ITEM_or_SEQ[str] = "constant", + stride: TYPE_item_seq[int] = 1, + padding: TYPE_item_seq[int] = 0, + padding_mode: TYPE_item_seq[str] = "constant", keys: Sequence[str] = ("data",), grad: bool = False, **kwargs diff --git a/rising/transforms/pad.py b/rising/transforms/pad.py index 70af350..dd414d6 100644 --- a/rising/transforms/pad.py +++ b/rising/transforms/pad.py @@ -2,7 +2,7 @@ import torch -from rising.transforms.abstract import BaseTransform, BaseTransformMixin, ITEM_or_SEQ +from rising.transforms.abstract import BaseTransform, BaseTransformMixin, TYPE_item_seq from rising.transforms.functional.pad import pad as _pad from rising.utils.mise import ntuple @@ -11,9 +11,9 @@ class Pad(BaseTransformMixin, BaseTransform): def __init__( self, *, - pad_size: ITEM_or_SEQ[int], + pad_size: TYPE_item_seq[int], mode: str = "constant", - pad_value: ITEM_or_SEQ[float], + pad_value: TYPE_item_seq[float], keys: Sequence[str] = ("data",), grad: bool = False, **kwargs diff --git a/rising/transforms/sitk.py b/rising/transforms/sitk.py index 93a0261..f0a2a32 100644 --- a/rising/transforms/sitk.py +++ b/rising/transforms/sitk.py @@ -3,7 +3,7 @@ import torch from rising.random.abstract import AbstractParameter -from rising.transforms.abstract import BaseTransform, ITEM_or_SEQ +from rising.transforms.abstract import BaseTransform, TYPE_item_seq from rising.transforms.functional.sitk import itk2tensor, itk_clip, itk_resample SpacingParamType = Union[ @@ -33,7 +33,7 @@ def __init__( *, pad_value: Union[int, float], keys: Sequence = ("data",), - interpolation: ITEM_or_SEQ[str] = "nearest", + interpolation: TYPE_item_seq[str] = "nearest", ): """ resample simpleitk image given new spacing and padding value @@ -93,7 +93,7 @@ def __init__( self, *, keys: Sequence = ("data",), - dtype: ITEM_or_SEQ[torch.dtype] = torch.float, + dtype: 
TYPE_item_seq[torch.dtype] = torch.float, insert_dim: int = None, grad: bool = False, **kwargs, diff --git a/rising/transforms/spatial.py b/rising/transforms/spatial.py index ce0e896..b143106 100644 --- a/rising/transforms/spatial.py +++ b/rising/transforms/spatial.py @@ -5,7 +5,7 @@ from torch.multiprocessing import Value from rising.random import AbstractParameter, DiscreteParameter -from rising.transforms.abstract import BaseTransform, BaseTransformMixin, ITEM_or_SEQ, PerSampleTransformMixin +from rising.transforms.abstract import BaseTransform, BaseTransformMixin, PerSampleTransformMixin, TYPE_item_seq from rising.transforms.functional import mirror, resize_native, rot90 __all__ = ["Mirror", "Rot90", "ResizeNative", "Zoom", "ProgressiveResize", "SizeStepScheduler"] @@ -21,7 +21,7 @@ class Mirror(PerSampleTransformMixin, BaseTransform): def __init__( self, *, - dims=ITEM_or_SEQ[Union[int, DiscreteParameter]], + dims=TYPE_item_seq[Union[int, DiscreteParameter]], p_sample: float = 0.5, keys: Sequence[str] = ("data",), grad: bool = False, @@ -51,7 +51,7 @@ class Rot90(PerSampleTransformMixin, BaseTransform): def __init__( self, - dims: ITEM_or_SEQ[Union[Sequence[int], DiscreteParameter]], + dims: TYPE_item_seq[Union[Sequence[int], DiscreteParameter]], keys: Sequence[str] = ("data",), num_rots: DiscreteParameter = DiscreteParameter((0, 1, 2, 3)), p_sample: float = 0.5, @@ -96,8 +96,8 @@ class ResizeNative(BaseTransformMixin, BaseTransform): def __init__( self, size: Union[int, Sequence[int]], - mode: ITEM_or_SEQ[FInterpolation] = FInterpolation.nearest, - align_corners: ITEM_or_SEQ[bool] = None, + mode: TYPE_item_seq[FInterpolation] = FInterpolation.nearest, + align_corners: TYPE_item_seq[bool] = None, preserve_range: bool = False, keys: Sequence[str] = ("data",), grad: bool = False, @@ -142,8 +142,8 @@ class Zoom(BaseTransformMixin, BaseTransform): def __init__( self, scale_factor: Union[Sequence, AbstractParameter] = (0.75, 1.25), - mode: 
ITEM_or_SEQ[FInterpolation] = FInterpolation.nearest, - align_corners: ITEM_or_SEQ[bool] = None, + mode: TYPE_item_seq[FInterpolation] = FInterpolation.nearest, + align_corners: TYPE_item_seq[bool] = None, preserve_range: bool = False, keys: Sequence[str] = ("data",), grad: bool = False, @@ -190,8 +190,8 @@ class ProgressiveResize(ResizeNative): def __init__( self, scheduler: scheduler_type, - mode: ITEM_or_SEQ[FInterpolation] = FInterpolation.nearest, - align_corners: ITEM_or_SEQ[Optional[bool]] = None, + mode: TYPE_item_seq[FInterpolation] = FInterpolation.nearest, + align_corners: TYPE_item_seq[Optional[bool]] = None, preserve_range: bool = False, keys: Sequence = ("data",), grad: bool = False, diff --git a/rising/utils/mise.py b/rising/utils/mise.py index 09e3954..9c364fc 100644 --- a/rising/utils/mise.py +++ b/rising/utils/mise.py @@ -109,11 +109,7 @@ def on_main_process(): @contextmanager -def fix_seed_cxm(seed: int = 10, **kwargs): - cuda = on_main_process() - with fixed_torch_seed(seed=seed, cuda=cuda): - with fixed_random_seed(seed=seed): - try: - yield - finally: - pass +def fix_seed_cxm(seed: int = 10): + cuda = on_main_process() and torch.cuda.is_available() + with fixed_torch_seed(seed=seed, cuda=cuda), fixed_random_seed(seed=seed): + yield diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index af25941..b1680ad 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -3,7 +3,7 @@ import torch from rising.random import UniformParameter -from rising.transforms._affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine +from rising.transforms.affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine from rising.utils.affine import matrix_to_cartesian, matrix_to_homogeneous diff --git a/tests/transforms/test_affine_transform.py b/tests/transforms/test_affine_transform.py index ed1645a..2c43bcc 100644 --- a/tests/transforms/test_affine_transform.py 
+++ b/tests/transforms/test_affine_transform.py @@ -5,7 +5,7 @@ import torch from rising.random import UniformParameter -from rising.transforms._affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine +from rising.transforms.affine import BaseAffine, Resize, Rotate, Scale, Translate, _Affine, _StackedAffine from rising.utils.affine import matrix_to_cartesian, matrix_to_homogeneous