forked from facebookresearch/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Adding rotations to transform.py Reviewed By: ppwwyyxx Differential Revision: D20488988 fbshipit-source-id: 39cb9ebf9aaa79b284b044d726dce9a55a8f7561
- Loading branch information
1 parent
9369382
commit 63ea900
Showing
2 changed files
with
142 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import numpy as np | ||
import unittest | ||
|
||
from detectron2.data.transforms.transform import RotationTransform | ||
|
||
|
||
class TestRotationTransform(unittest.TestCase): | ||
def assertEqualsArrays(self, a1, a2): | ||
self.assertTrue(np.allclose(a1, a2)) | ||
|
||
def randomData(self, h=5, w=5): | ||
image = np.random.rand(h, w) | ||
coords = np.array([[i, j] for j in range(h + 1) for i in range(w + 1)], dtype=float) | ||
return image, coords, h, w | ||
|
||
def test180(self): | ||
image, coords, h, w = self.randomData(6, 6) | ||
rot = RotationTransform(h, w, 180, expand=False, center=None) | ||
self.assertEqualsArrays(rot.apply_image(image), image[::-1, ::-1]) | ||
rotated_coords = [[w - c[0], h - c[1]] for c in coords] | ||
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) | ||
|
||
def test45_coords(self): | ||
_, coords, h, w = self.randomData(4, 6) | ||
rot = RotationTransform(h, w, 45, expand=False, center=None) | ||
rotated_coords = [ | ||
[(x + y - (h + w) / 2) / np.sqrt(2) + w / 2, h / 2 + (y + (w - h) / 2 - x) / np.sqrt(2)] | ||
for (x, y) in coords | ||
] | ||
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) | ||
|
||
def test90(self): | ||
image, coords, h, w = self.randomData() | ||
rot = RotationTransform(h, w, 90, expand=False, center=None) | ||
self.assertEqualsArrays(rot.apply_image(image), image.T[::-1]) | ||
rotated_coords = [[c[1], w - c[0]] for c in coords] | ||
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) | ||
|
||
def test90_expand(self): # non-square image | ||
image, coords, h, w = self.randomData(h=5, w=8) | ||
rot = RotationTransform(h, w, 90, expand=True, center=None) | ||
self.assertEqualsArrays(rot.apply_image(image), image.T[::-1]) | ||
rotated_coords = [[c[1], w - c[0]] for c in coords] | ||
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords) | ||
|
||
def test_center_expand(self): | ||
# center has no effect if expand=True because it only affects shifting | ||
image, coords, h, w = self.randomData(h=5, w=8) | ||
angle = np.random.randint(360) | ||
rot1 = RotationTransform(h, w, angle, expand=True, center=None) | ||
rot2 = RotationTransform(h, w, angle, expand=True, center=(0, 0)) | ||
rot3 = RotationTransform(h, w, angle, expand=True, center=(h, w)) | ||
rot4 = RotationTransform(h, w, angle, expand=True, center=(2, 5)) | ||
for r1 in [rot1, rot2, rot3, rot4]: | ||
for r2 in [rot1, rot2, rot3, rot4]: | ||
self.assertEqualsArrays(r1.apply_image(image), r2.apply_image(image)) | ||
self.assertEqualsArrays(r1.apply_coords(coords), r2.apply_coords(coords)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |