Skip to content

Commit

Permalink
TransformGen -> Augmentation
Browse files Browse the repository at this point in the history
Reviewed By: rbgirshick

Differential Revision: D22177876

fbshipit-source-id: 6e9a04dc1738c6056a95a50534093f1e8481f4f9
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Jun 24, 2020
1 parent e7182d5 commit c71fbd8
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 79 deletions.
20 changes: 10 additions & 10 deletions detectron2/data/dataset_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class DatasetMapper:

def __init__(self, cfg, is_train=True):
if cfg.INPUT.CROP.ENABLED and is_train:
self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
self.crop = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop))
else:
self.crop_gen = None
self.crop = None

self.tfm_gens = utils.build_transform_gen(cfg, is_train)
self.augmentation = utils.build_augmentation(cfg, is_train)

# fmt: off
self.img_format = cfg.INPUT.FORMAT
Expand Down Expand Up @@ -77,20 +77,20 @@ def __call__(self, dataset_dict):

if not dataset_dict.get("annotations", []):
image, transforms = T.apply_transform_gens(
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
([self.crop] if self.crop else []) + self.augmentation, image
)
else:
# Crop around an instance if there are instances in the image.
# USER: Remove if you don't use cropping
if self.crop_gen:
if self.crop:
crop_tfm = utils.gen_crop_transform_with_instance(
self.crop_gen.get_crop_size(image.shape[:2]),
self.crop.get_crop_size(image.shape[:2]),
image.shape[:2],
np.random.choice(dataset_dict["annotations"]),
)
image = crop_tfm.apply_image(image)
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.crop_gen:
image, transforms = T.apply_transform_gens(self.augmentation, image)
if self.crop:
transforms = crop_tfm + transforms

image_shape = image.shape[:2] # h, w
Expand Down Expand Up @@ -142,7 +142,7 @@ def __call__(self, dataset_dict):
# [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
# bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
# the intersection of original bounding box and the cropping box.
if self.crop_gen and instances.has("gt_masks"):
if self.crop and instances.has("gt_masks"):
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
dataset_dict["instances"] = utils.filter_empty_instances(instances)

Expand Down
22 changes: 14 additions & 8 deletions detectron2/data/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,13 @@ def check_metadata_consistency(key, dataset_names):
raise ValueError("Datasets have different metadata '{}'!".format(key))


def build_transform_gen(cfg, is_train):
def build_augmentation(cfg, is_train):
"""
Create a list of :class:`TransformGen` from config.
Create a list of :class:`Augmentation` from config.
Now it includes resizing and flipping.
Returns:
list[TransformGen]
list[Augmentation]
"""
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
Expand All @@ -521,9 +521,15 @@ def build_transform_gen(cfg, is_train):
)

logger = logging.getLogger(__name__)
tfm_gens = []
tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
augmentation = []
augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
if is_train:
tfm_gens.append(T.RandomFlip())
logger.info("TransformGens used in training: " + str(tfm_gens))
return tfm_gens
augmentation.append(T.RandomFlip())
logger.info("Augmentations used in training: " + str(augmentation))
return augmentation


build_transform_gen = build_augmentation
"""
Alias for backward-compatibility.
"""
42 changes: 27 additions & 15 deletions detectron2/data/transforms/transform_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@
from abc import ABCMeta, abstractmethod
from fvcore.transforms.transform import Transform, TransformList

__all__ = ["TransformGen", "apply_transform_gens"]
__all__ = ["Augmentation", "TransformGen", "apply_transform_gens", "apply_augmentations"]


def check_dtype(img):
assert isinstance(img, np.ndarray), "[TransformGen] Needs an numpy array, but got a {}!".format(
assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format(
type(img)
)
assert not isinstance(img.dtype, np.integer) or (
img.dtype == np.uint8
), "[TransformGen] Got image of type {}, use uint8 or floating points instead!".format(
), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format(
img.dtype
)
assert img.ndim in [2, 3], img.ndim


class TransformGen(metaclass=ABCMeta):
class Augmentation(metaclass=ABCMeta):
"""
TransformGen takes an image of type uint8 in range [0, 255], or
Augmentation takes an image of type uint8 in range [0, 255], or
floating point in range [0, 1] or [0, 255] as input.
It creates a :class:`Transform` based on the given image, sometimes with randomness.
Expand All @@ -35,7 +35,7 @@ class TransformGen(metaclass=ABCMeta):
is that the image itself is sufficient to instantiate a transform.
When this assumption is not true, you need to create the transforms by your own.
A list of `TransformGen` can be applied with :func:`apply_transform_gens`.
A list of `Augmentation` can be applied with :func:`apply_augmentations`.
"""

def _init(self, params=None):
Expand All @@ -61,7 +61,7 @@ def _rand_range(self, low=1.0, high=None, size=None):
def __repr__(self):
"""
Produce something like:
"MyTransformGen(field1={self.field1}, field2={self.field2})"
"MyAugmentation(field1={self.field1}, field2={self.field2})"
"""
try:
sig = inspect.signature(self.__init__)
Expand All @@ -87,35 +87,47 @@ def __repr__(self):
__str__ = __repr__


def apply_transform_gens(transform_gens, img):
TransformGen = Augmentation
"""
Alias for Augmentation, since it is something that generates :class:`Transform`s
"""


def apply_augmentations(augmentations, img):
"""
Apply a list of :class:`TransformGen` or :class:`Transform` on the input image, and
Apply a list of :class:`Augmentation` or :class:`Transform` on the input image, and
returns the transformed image and a list of transforms.
We cannot simply create and return all transforms without
applying it to the image, because a subsequent transform may
need the output of the previous one.
Args:
transform_gens (list): list of :class:`TransformGen` or :class:`Transform` instance to
augmentations (list): list of :class:`Augmentation` or :class:`Transform` instance to
be applied.
img (ndarray): uint8 or floating point images with 1 or 3 channels.
Returns:
ndarray: the transformed image
TransformList: contain the transforms that's used.
"""
for g in transform_gens:
assert isinstance(g, (Transform, TransformGen)), g
for aug in augmentations:
assert isinstance(aug, (Transform, Augmentation)), aug

check_dtype(img)

tfms = []
for g in transform_gens:
tfm = g.get_transform(img) if isinstance(g, TransformGen) else g
for aug in augmentations:
tfm = aug.get_transform(img) if isinstance(aug, Augmentation) else aug
assert isinstance(
tfm, Transform
), "TransformGen {} must return an instance of Transform! Got {} instead".format(g, tfm)
), f"Augmentation {aug} must return an instance of Transform! Got {tfm} instead."
img = tfm.apply_image(img)
tfms.append(tfm)
return img, TransformList(tfms)


apply_transform_gens = apply_augmentations
"""
Alias for backward-compatibility.
"""
36 changes: 18 additions & 18 deletions detectron2/data/transforms/transform_gen_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Implement many useful :class:`TransformGen`.
Implement many useful :class:`Augmentation`.
"""
import numpy as np
import sys
Expand All @@ -16,7 +16,7 @@
from PIL import Image

from .transform import ExtentTransform, ResizeTransform, RotationTransform
from .transform_gen import TransformGen
from .transform_gen import Augmentation

__all__ = [
"RandomApply",
Expand All @@ -33,23 +33,23 @@
]


class RandomApply(TransformGen):
class RandomApply(Augmentation):
"""
Randomly apply the wrapper transformation with a given probability.
"""

def __init__(self, transform, prob=0.5):
"""
Args:
transform (Transform, TransformGen): the transform to be wrapped
transform (Transform, Augmentation): the transform to be wrapped
by the `RandomApply`. The `transform` can either be a
`Transform` or `TransformGen` instance.
`Transform` or `Augmentation` instance.
prob (float): probability between 0.0 and 1.0 that
the wrapper transformation is applied
"""
super().__init__()
assert isinstance(transform, (Transform, TransformGen)), (
f"The given transform must either be a Transform or TransformGen instance. "
assert isinstance(transform, (Transform, Augmentation)), (
f"The given transform must either be a Transform or Augmentation instance. "
f"Not {type(transform)}"
)
assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
Expand All @@ -59,15 +59,15 @@ def __init__(self, transform, prob=0.5):
def get_transform(self, img):
do = self._rand_range() < self.prob
if do:
if isinstance(self.transform, TransformGen):
if isinstance(self.transform, Augmentation):
return self.transform.get_transform(img)
else:
return self.transform
else:
return NoOpTransform()


class RandomFlip(TransformGen):
class RandomFlip(Augmentation):
"""
Flip the image horizontally or vertically with the given probability.
"""
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_transform(self, img):
return NoOpTransform()


class Resize(TransformGen):
class Resize(Augmentation):
""" Resize image to a target size"""

def __init__(self, shape, interp=Image.BILINEAR):
Expand All @@ -119,7 +119,7 @@ def get_transform(self, img):
)


class ResizeShortestEdge(TransformGen):
class ResizeShortestEdge(Augmentation):
"""
Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
Expand Down Expand Up @@ -168,7 +168,7 @@ def get_transform(self, img):
return ResizeTransform(h, w, newh, neww, self.interp)


class RandomRotation(TransformGen):
class RandomRotation(Augmentation):
"""
This method returns a copy of this image, rotated the given
number of degrees counter clockwise around the given center.
Expand Down Expand Up @@ -222,7 +222,7 @@ def get_transform(self, img):
return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)


class RandomCrop(TransformGen):
class RandomCrop(Augmentation):
"""
Randomly crop a subimage out of an image.
"""
Expand Down Expand Up @@ -274,7 +274,7 @@ def get_crop_size(self, image_size):
NotImplementedError("Unknown crop type {}".format(self.crop_type))


class RandomExtent(TransformGen):
class RandomExtent(Augmentation):
"""
Outputs an image by cropping a random "subrect" of the source image.
Expand Down Expand Up @@ -319,7 +319,7 @@ def get_transform(self, img):
)


class RandomContrast(TransformGen):
class RandomContrast(Augmentation):
"""
Randomly transforms image contrast.
Expand All @@ -345,7 +345,7 @@ def get_transform(self, img):
return BlendTransform(src_image=img.mean(), src_weight=1 - w, dst_weight=w)


class RandomBrightness(TransformGen):
class RandomBrightness(Augmentation):
"""
Randomly transforms image brightness.
Expand All @@ -371,7 +371,7 @@ def get_transform(self, img):
return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)


class RandomSaturation(TransformGen):
class RandomSaturation(Augmentation):
"""
Randomly transforms saturation of an RGB image.
Input images are assumed to have 'RGB' channel order.
Expand Down Expand Up @@ -400,7 +400,7 @@ def get_transform(self, img):
return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)


class RandomLighting(TransformGen):
class RandomLighting(Augmentation):
"""
The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
Input images are assumed to have 'RGB' channel order.
Expand Down
4 changes: 2 additions & 2 deletions detectron2/engine/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __init__(self, cfg):
checkpointer = DetectionCheckpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)

self.transform_gen = T.ResizeShortestEdge(
self.aug = T.ResizeShortestEdge(
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)

Expand All @@ -209,7 +209,7 @@ def __call__(self, original_image):
# whether the model expects BGR inputs or RGB
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
image = self.transform_gen.get_transform(original_image).apply_image(original_image)
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

inputs = {"image": image, "height": height, "width": width}
Expand Down
10 changes: 5 additions & 5 deletions detectron2/modeling/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ def __call__(self, dataset_dict):
pre_tfm = NoOpTransform()

# Create all combinations of augmentations to use
tfm_gen_candidates = [] # each element is a list[TransformGen]
aug_candidates = [] # each element is a list[Augmentation]
for min_size in self.min_sizes:
resize = ResizeShortestEdge(min_size, self.max_size)
tfm_gen_candidates.append([resize]) # resize only
aug_candidates.append([resize]) # resize only
if self.flip:
flip = RandomFlip(prob=1.0)
tfm_gen_candidates.append([resize, flip]) # resize + flip
aug_candidates.append([resize, flip]) # resize + flip

# Apply all the augmentations
ret = []
for tfm_gen in tfm_gen_candidates:
new_image, tfms = apply_transform_gens(tfm_gen, np.copy(numpy_image))
for aug in aug_candidates:
new_image, tfms = apply_transform_gens(aug, np.copy(numpy_image))
torch_image = torch.from_numpy(np.ascontiguousarray(new_image.transpose(2, 0, 1)))

dic = copy.deepcopy(dataset_dict)
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
return True

# Hide some that are deprecated or not intended to be used
_DEPRECATED = {"ResNetBlockBase", "GroupedBatchSampler"}
_DEPRECATED = {"ResNetBlockBase", "GroupedBatchSampler", "build_transform_gen"}
try:
if obj.__doc__.lower().strip().startswith("deprecated") or name in _DEPRECATED:
print("Skipping deprecated object: {}".format(name))
Expand Down
Loading

0 comments on commit c71fbd8

Please sign in to comment.