Make padding shape traceable.
Summary:
Previously, `max_size` would be hardcoded in the traced ONNX model because part of the dataflow/computation happened outside PyTorch.

The chunk between `[Sub, Div]` (whitening) and `Pad` in the screenshot below shows how this shape inference works, with `size_divisibility == 32` for this backbone.
`ceil()` is replaced with `[Add, Div, Mul]` in the middle, i.e. `(n + (m - 1)) // m * m`, with `m - 1` constant-folded.
<img width="479" alt="image" src="https://user-images.githubusercontent.com/5203025/79173373-e9d48780-7e29-11ea-880b-dc8612e027e2.png">
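
For illustration only (not part of this change, with hypothetical helper names): a minimal sketch of why the rewrite matters for tracing. The Python-level `math.ceil` produces a plain int that gets baked into the trace as a constant, while the equivalent tensor arithmetic keeps the rounding in the graph as `[Add, Div, Mul]`.

```python
import math
import torch

def ceil_to_multiple_python(n: int, m: int) -> int:
    # Plain Python rounding: returns an int, so a traced graph records the
    # result as a hard-coded constant.
    return int(math.ceil(n / m) * m)

def ceil_to_multiple_traceable(n: torch.Tensor, m: int) -> torch.Tensor:
    # Tensor arithmetic: stays in the graph as Add/Div/Mul ops, with `m - 1`
    # constant-folded, so the rounded size follows the input shape.
    return (n + (m - 1)) // m * m

n = torch.tensor(100)
assert ceil_to_multiple_python(100, 32) == 128
assert int(ceil_to_multiple_traceable(n, 32)) == 128
```
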
Pull Request resolved: facebookresearch#1208

Reviewed By: rbgirshick

Differential Revision: D21076153

Pulled By: ppwwyyxx

fbshipit-source-id: c6d37114dd5d9cb0ec65f561fba7961ce26d74ec
xkszltl authored and facebook-github-bot committed Apr 21, 2020
1 parent 9ced00a commit 1e3bd77
Showing 2 changed files with 51 additions and 9 deletions.
27 changes: 19 additions & 8 deletions detectron2/structures/image_list.py
@@ -70,16 +70,26 @@ def from_tensors(
             assert isinstance(t, torch.Tensor), type(t)
             assert t.shape[1:-2] == tensors[0].shape[1:-2], t.shape
         # per dimension maximum (H, W) or (C_1, ..., C_K, H, W) where K >= 1 among all tensors
-        max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
+        max_size = (
+            # In tracing mode, x.shape[i] is Tensor, and should not be converted
+            # to int: this will cause the traced graph to have hard-coded shapes.
+            # Instead we should make max_size a Tensor that depends on these tensors.
+            # Using torch.stack twice seems to be the best way to convert
+            # list[list[ScalarTensor]] to a Tensor
+            torch.stack(
+                [
+                    torch.stack([torch.as_tensor(dim) for dim in size])
+                    for size in [tuple(img.shape) for img in tensors]
+                ]
+            )
+            .max(0)
+            .values
+        )
 
         if size_divisibility > 0:
-            import math
-
             stride = size_divisibility
-            max_size = list(max_size)  # type: ignore
-            max_size[-2] = int(math.ceil(max_size[-2] / stride) * stride)  # type: ignore
-            max_size[-1] = int(math.ceil(max_size[-1] / stride) * stride)  # type: ignore
-            max_size = tuple(max_size)
+            # the last two dims are H,W, both subject to divisibility requirement
+            max_size = torch.cat([max_size[:-2], (max_size[-2:] + (stride - 1)) // stride * stride])
 
         image_sizes = [tuple(im.shape[-2:]) for im in tensors]
 
@@ -94,7 +104,8 @@ def from_tensors(
             padded = F.pad(tensors[0], padding_size, value=pad_value)
             batched_imgs = padded.unsqueeze_(0)
         else:
-            batch_shape = (len(tensors),) + max_size
+            # max_size can be a tensor in tracing mode, therefore use tuple()
+            batch_shape = (len(tensors),) + tuple(max_size)
             batched_imgs = tensors[0].new_full(batch_shape, pad_value)
             for img, pad_img in zip(tensors, batched_imgs):
                 pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
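
For reference (not part of the diff): a standalone sketch, in plain PyTorch, of what the new `max_size` computation above does, using a made-up pair of image tensors and `stride = 32` as in the backbone mentioned in the summary.

```python
import torch

# Two images with different spatial sizes; stride plays the role of size_divisibility.
tensors = [torch.ones(3, 15, 20), torch.ones(3, 30, 17)]
stride = 32

# list[list[ScalarTensor]] -> Tensor via two torch.stack calls, then a per-dimension max.
max_size = torch.stack(
    [torch.stack([torch.as_tensor(dim) for dim in t.shape]) for t in tensors]
).max(0).values
print(max_size)  # tensor([ 3, 30, 20])

# Round the last two dims (H, W) up to a multiple of the stride, keeping everything
# as tensor ops so the padding shape stays dynamic when traced.
max_size = torch.cat([max_size[:-2], (max_size[-2:] + (stride - 1)) // stride * stride])
print(max_size)  # tensor([ 3, 32, 32])
```
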
33 changes: 32 additions & 1 deletion tests/test_data_loader.py
@@ -3,11 +3,13 @@
 import copy
 import numpy as np
 import unittest
+from typing import Sequence
+import pycocotools.mask as mask_util
 import torch
 
 from detectron2.data import detection_utils
 from detectron2.data import transforms as T
-from detectron2.structures import BitMasks, BoxMode
+from detectron2.structures import BitMasks, BoxMode, ImageList
 
 
 class TestTransformAnnotations(unittest.TestCase):
@@ -114,3 +116,32 @@ def test_gen_crop_outside_boxes(self):
         instance = {"bbox": [10, 10, 100, 100], "bbox_mode": BoxMode.XYXY_ABS}
         with self.assertRaises(AssertionError):
             detection_utils.gen_crop_transform_with_instance((10, 10), (15, 15), instance)
+
+    def test_imagelist_padding_shape(self):
+        class TensorToImageList(torch.nn.Module):
+            def forward(self, tensors: Sequence[torch.Tensor]):
+                return ImageList.from_tensors(tensors, 4).tensor
+
+        func = torch.jit.trace(
+            TensorToImageList(), ([torch.ones((3, 10, 10), dtype=torch.float32)],)
+        )
+        ret = func([torch.ones((3, 15, 20), dtype=torch.float32)])
+        self.assertEqual(list(ret.shape), [1, 3, 16, 20], str(ret.shape))
+
+        func = torch.jit.trace(
+            TensorToImageList(),
+            (
+                [
+                    torch.ones((3, 16, 10), dtype=torch.float32),
+                    torch.ones((3, 13, 11), dtype=torch.float32),
+                ],
+            ),
+        )
+        ret = func(
+            [
+                torch.ones((3, 25, 20), dtype=torch.float32),
+                torch.ones((3, 10, 10), dtype=torch.float32),
+            ]
+        )
+        # does not support calling with different #images
+        self.assertEqual(list(ret.shape), [2, 3, 28, 20], str(ret.shape))
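
For reference (a side calculation, not part of the test file): the expected shapes follow from rounding H and W up to the nearest multiple of `size_divisibility = 4`, applied to the per-dimension maximum of the batch.

```python
# Round n up to the nearest multiple of m, mirroring the padding rule above.
def ceil_to(n: int, m: int) -> int:
    return (n + m - 1) // m * m

assert (ceil_to(15, 4), ceil_to(20, 4)) == (16, 20)                    # -> [1, 3, 16, 20]
assert (ceil_to(max(25, 10), 4), ceil_to(max(20, 10), 4)) == (28, 20)  # -> [2, 3, 28, 20]
```
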
