Commit

Improved utilities, adds examples, tests (pytorch#3594)
* start adding tests

* add return type and doc

* adds tests

* add no fill tests

* add rgb test

* check inplace

* bug fix

* bug fix

* rewrite make grid

* add plotting demos

* rename file

* remove

* updt

* Add viz

* updt

* update readme, add links

* complete bounding boxes

* Complete the examples!

* link fix

* link fixed

Co-authored-by: Francisco Massa <[email protected]>
oke-aditya and fmassa authored Mar 30, 2021
1 parent 20a771e commit 66d777e
Showing 5 changed files with 768 additions and 8 deletions.
4 changes: 4 additions & 0 deletions examples/python/README.md
@@ -4,6 +4,8 @@
[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/video_api.ipynb)
[Example of VideoAPI](https://github.com/pytorch/vision/blob/master/examples/python/video_api.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)
[Example of Visualization Utils](https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb)


Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to
@@ -16,3 +18,5 @@ features:
- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)

Furthermore, previously we used to provide a very high-level API for video decoding which left little control to the user. We're now expanding that API (and replacing it in the future) with a lower-level API that allows the user a frame-based access to a video.

Torchvision also provides utilities to visualize results. You can make a grid of images and plot bounding boxes as well as segmentation masks. These utilities work standalone as well as with torchvision models for detection and segmentation.
683 changes: 683 additions & 0 deletions examples/python/visualization_utils.ipynb

Large diffs are not rendered by default.

Binary file added test/assets/fakedata/draw_boxes_vanilla.png
58 changes: 56 additions & 2 deletions test/test_utils.py
@@ -9,6 +9,9 @@
import torchvision.transforms.functional as F
from PIL import Image

boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

masks = torch.tensor([
[
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
@@ -106,8 +109,8 @@ def test_save_image_single_pixel_file_object(self):

def test_draw_boxes(self):
img = torch.full((3, 100, 100), 255, dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
[10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
img_cp = img.clone()
boxes_cp = boxes.clone()
labels = ["a", "b", "c", "d"]
colors = ["green", "#FF00FF", (0, 255, 0), "red"]
result = utils.draw_bounding_boxes(img, boxes, labels=labels, colors=colors, fill=True)
@@ -119,9 +122,41 @@ def test_draw_boxes(self):

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())

    def test_draw_boxes_vanilla(self):
        img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
        img_cp = img.clone()
        boxes_cp = boxes.clone()
        result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)

        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
        if not os.path.exists(path):
            res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
            res.save(path)

        expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))
        # Check if modification is not in place
        self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item())
        self.assertTrue(torch.all(torch.eq(img, img_cp)).item())

    def test_draw_invalid_boxes(self):
        img_tp = ((1, 1, 1), (1, 2, 3))
        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
        boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0],
                              [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        self.assertRaises(TypeError, utils.draw_bounding_boxes, img_tp, boxes)
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)

def test_draw_segmentation_masks_colors(self):
img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
colors = ["#FF00FF", (0, 255, 0), "red"]
result = utils.draw_segmentation_masks(img, masks, colors=colors)

@@ -134,9 +169,14 @@ def test_draw_segmentation_masks_colors(self):

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())

def test_draw_segmentation_masks_no_colors(self):
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
img_cp = img.clone()
masks_cp = masks.clone()
result = utils.draw_segmentation_masks(img, masks, colors=None)

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@@ -148,6 +188,20 @@ def test_draw_segmentation_masks_no_colors(self):

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
# Check if modification is not in place
self.assertTrue(torch.all(torch.eq(img, img_cp)).item())
self.assertTrue(torch.all(torch.eq(masks, masks_cp)).item())

    def test_draw_invalid_masks(self):
        img_tp = ((1, 1, 1), (1, 2, 3))
        img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
        img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
        img_wrong3 = torch.full((4, 5, 5), 255, dtype=torch.uint8)

        self.assertRaises(TypeError, utils.draw_segmentation_masks, img_tp, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong1, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong2, masks)
        self.assertRaises(ValueError, utils.draw_segmentation_masks, img_wrong3, masks)


if __name__ == '__main__':
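
The tests in this file repeat the same clone-and-compare steps to check that the drawing functions do not modify their inputs in place. As a sketch only, that pattern could be factored into a helper; `assert_inputs_unchanged`, `scale_copy`, and `scale_inplace` below are hypothetical names that appear neither in this commit nor in torchvision, and the example assumes only `torch`.

```python
import torch


def assert_inputs_unchanged(fn, *tensors, **kwargs):
    """Call fn(*tensors, **kwargs) and check that no input tensor was mutated.

    Hypothetical helper for illustration; not part of torchvision.
    """
    # Snapshot every input before the call.
    snapshots = [t.clone() for t in tensors]
    result = fn(*tensors, **kwargs)
    # torch.equal compares shape and values in a single call.
    for before, after in zip(snapshots, tensors):
        assert torch.equal(before, after), "input tensor was modified in place"
    return result


def scale_copy(x, factor=2):
    # Out-of-place: returns a new tensor and leaves x alone.
    return x * factor


def scale_inplace(x, factor=2):
    # In-place: mutates its argument.
    return x.mul_(factor)


x = torch.ones(3)
out = assert_inputs_unchanged(scale_copy, x)

# The in-place variant should trip the check.
try:
    assert_inputs_unchanged(scale_inplace, x)
    caught = False
except AssertionError:
    caught = True
```

The commit itself uses `torch.all(torch.eq(a, b)).item()` for the comparison; `torch.equal(a, b)` is used here as an equivalent shorthand for tensors of the same shape.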
31 changes: 25 additions & 6 deletions torchvision/utils.py
@@ -20,7 +20,8 @@ def make_grid(
pad_value: int = 0,
**kwargs
) -> torch.Tensor:
"""Make a grid of images.
"""
Make a grid of images.
Args:
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
@@ -37,9 +38,12 @@ def make_grid(
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
Example:
See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
Returns:
grid (Tensor): the tensor containing grid of images.
Example:
See this notebook
`here <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
@@ -117,7 +121,8 @@ def save_image(
format: Optional[str] = None,
**kwargs
) -> None:
"""Save a given Tensor into an image file.
"""
Save a given Tensor into an image file.
Args:
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
Expand Down Expand Up @@ -150,7 +155,7 @@ def draw_bounding_boxes(
"""
Draws bounding boxes on given image.
The values of the input image should be uint8 between 0 and 255.
If filled, Resulting Tensor should be saved as PNG image.
If fill is True, the resulting Tensor should be saved as a PNG image.
Args:
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
@@ -166,6 +171,13 @@ def draw_bounding_boxes(
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
font_size (int): The requested font size in points.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
Example:
See this notebook
`linked <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
"""

if not isinstance(image, torch.Tensor):
Expand Down Expand Up @@ -209,7 +221,7 @@ def draw_bounding_boxes(
if labels is not None:
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)


@torch.no_grad()
@@ -230,6 +242,13 @@ def draw_segmentation_masks(
alpha (float): Float number between 0 and 1 denoting the factor of transparency of the masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
be represented as `str` or `Tuple[int, int, int]`.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with segmentation masks plotted.
Example:
See this notebook
`attached <https://github.com/pytorch/vision/blob/master/examples/python/visualization_utils.ipynb>`_
"""

if not isinstance(image, torch.Tensor):
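
The `make_grid` docstring in this file states that the function returns a tensor containing the grid but does not spell out its size. The sketch below works through the arithmetic, assuming the layout rule `make_grid` appears to follow: images are placed row-major with at most `nrow` per row and `padding` pixels between and around them. `grid_shape` is a hypothetical helper written for illustration, not a torchvision function.

```python
import math


def grid_shape(num_images, channels, height, width, nrow=8, padding=2):
    """Expected (C, H, W) of a grid built from equal-sized images.

    Hypothetical helper mirroring the layout arithmetic; not part of torchvision.
    """
    # Single-channel images are expanded to 3 channels before tiling.
    if channels == 1:
        channels = 3
    # A single image is returned as-is, without any padding.
    if num_images == 1:
        return (channels, height, width)
    # Number of columns is capped by the number of images available.
    cols = min(nrow, num_images)
    rows = math.ceil(num_images / cols)
    # Each cell carries one padding strip; one extra strip closes the far edge.
    grid_height = rows * (height + padding) + padding
    grid_width = cols * (width + padding) + padding
    return (channels, grid_height, grid_width)


print(grid_shape(16, 3, 32, 32))           # (3, 70, 274)
print(grid_shape(2, 3, 100, 100, nrow=2))  # (3, 104, 206)
```

Note that `nrow` is the number of images per row, so it controls the number of columns in the grid rather than the number of rows.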
