Skip to content

Commit

Permalink
Adding rotations to transform.py
Browse files Browse the repository at this point in the history
Summary: Adding rotations to transform.py

Reviewed By: ppwwyyxx

Differential Revision: D20488988

fbshipit-source-id: 39cb9ebf9aaa79b284b044d726dce9a55a8f7561
  • Loading branch information
MarcSzafraniec authored and facebook-github-bot committed Apr 9, 2020
1 parent 9369382 commit 63ea900
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 1 deletion.
82 changes: 81 additions & 1 deletion detectron2/data/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from fvcore.transforms.transform import HFlipTransform, NoOpTransform, Transform
from PIL import Image

__all__ = ["ExtentTransform", "ResizeTransform"]
try:
import cv2 # noqa
except ImportError:
# OpenCV is an optional dependency at the moment
pass

__all__ = ["ExtentTransform", "ResizeTransform", "RotationTransform"]


class ExtentTransform(Transform):
Expand Down Expand Up @@ -94,6 +100,80 @@ def apply_segmentation(self, segmentation):
return segmentation


class RotationTransform(Transform):
"""
This method returns a copy of this image, rotated the given
number of degrees counter clockwise around its center.
"""

def __init__(self, h, w, angle, expand=True, center=None, interp=None):
"""
Args:
h, w (int): original image size
angle (float): degrees for rotation
expand (bool): choose if the image should be resized to fit the whole
rotated image (default), or simply cropped
center (tuple (width, height)): coordinates of the rotation center
if left to None, the center will be fit to the center of each image
center has no effect if expand=True because it only affects shifting
interp: cv2 interpolation method, default cv2.INTER_LINEAR
"""
super().__init__()
image_center = np.array((w / 2, h / 2))
if center is None:
center = image_center
if interp is None:
interp = cv2.INTER_LINEAR
abs_cos, abs_sin = abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle)))
if expand:
# find the new width and height bounds
bound_w, bound_h = np.rint(
[h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
).astype(int)
else:
bound_w, bound_h = w, h

self._set_attributes(locals())
self.rm_coords = self.create_rotation_matrix()
# Needed because of this problem https://github.com/opencv/opencv/issues/11784
self.rm_image = self.create_rotation_matrix(offset=-0.5)

def apply_image(self, img, interp=cv2.INTER_LINEAR):
"""
img should be a numpy array, formatted as Height * Width * Nchannels
"""
if len(img) == 0:
return img
assert img.shape[:2] == (self.h, self.w)
interp = interp if interp is not None else self.interp
return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)

def apply_coords(self, coords):
"""
coords should be a N * 2 array-like, containing N couples of (x, y) points
"""
if len(coords) == 0:
return coords
coords = np.asarray(coords, dtype=float)
return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]

def apply_segmentation(self, segmentation):
segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
return segmentation

def create_rotation_matrix(self, offset=0):
center = (self.center[0] + offset, self.center[1] + offset)
rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
if self.expand:
# Find the coordinates of the center of rotation in the new image
# The only point for which we know the future coordinates is the center of the image
rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
# shift the rotation center to the new coordinates
rm[:, 2] += new_center
return rm


def HFlip_rotated_box(transform, rotated_boxes):
"""
Apply the horizontal flip transform on rotated boxes.
Expand Down
61 changes: 61 additions & 0 deletions tests/test_rotation_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np
import unittest

from detectron2.data.transforms.transform import RotationTransform


class TestRotationTransform(unittest.TestCase):
def assertEqualsArrays(self, a1, a2):
self.assertTrue(np.allclose(a1, a2))

def randomData(self, h=5, w=5):
image = np.random.rand(h, w)
coords = np.array([[i, j] for j in range(h + 1) for i in range(w + 1)], dtype=float)
return image, coords, h, w

def test180(self):
image, coords, h, w = self.randomData(6, 6)
rot = RotationTransform(h, w, 180, expand=False, center=None)
self.assertEqualsArrays(rot.apply_image(image), image[::-1, ::-1])
rotated_coords = [[w - c[0], h - c[1]] for c in coords]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)

def test45_coords(self):
_, coords, h, w = self.randomData(4, 6)
rot = RotationTransform(h, w, 45, expand=False, center=None)
rotated_coords = [
[(x + y - (h + w) / 2) / np.sqrt(2) + w / 2, h / 2 + (y + (w - h) / 2 - x) / np.sqrt(2)]
for (x, y) in coords
]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)

def test90(self):
image, coords, h, w = self.randomData()
rot = RotationTransform(h, w, 90, expand=False, center=None)
self.assertEqualsArrays(rot.apply_image(image), image.T[::-1])
rotated_coords = [[c[1], w - c[0]] for c in coords]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)

def test90_expand(self): # non-square image
image, coords, h, w = self.randomData(h=5, w=8)
rot = RotationTransform(h, w, 90, expand=True, center=None)
self.assertEqualsArrays(rot.apply_image(image), image.T[::-1])
rotated_coords = [[c[1], w - c[0]] for c in coords]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)

def test_center_expand(self):
# center has no effect if expand=True because it only affects shifting
image, coords, h, w = self.randomData(h=5, w=8)
angle = np.random.randint(360)
rot1 = RotationTransform(h, w, angle, expand=True, center=None)
rot2 = RotationTransform(h, w, angle, expand=True, center=(0, 0))
rot3 = RotationTransform(h, w, angle, expand=True, center=(h, w))
rot4 = RotationTransform(h, w, angle, expand=True, center=(2, 5))
for r1 in [rot1, rot2, rot3, rot4]:
for r2 in [rot1, rot2, rot3, rot4]:
self.assertEqualsArrays(r1.apply_image(image), r2.apply_image(image))
self.assertEqualsArrays(r1.apply_coords(coords), r2.apply_coords(coords))


if __name__ == "__main__":
unittest.main()

0 comments on commit 63ea900

Please sign in to comment.