Skip to content

Commit

Permalink
Improve tracing support for Faster-RCNN
Browse files Browse the repository at this point in the history
Summary:
Small fixes to better preserve the origin of shape information during ONNX export. This resolves an issue where an ONNX export would generate fixed sized tensors for tensors that are actually of a dynamic shape.

Pull Request resolved: facebookresearch#2060

Reviewed By: theschnitz

Differential Revision: D24156111

Pulled By: ppwwyyxx

fbshipit-source-id: 34776d4b18be6915b9deea6bd2cbb7680ffc6488
  • Loading branch information
crisp-snakey authored and facebook-github-bot committed Oct 23, 2020
1 parent 5f7881d commit 631c33a
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 3 deletions.
4 changes: 2 additions & 2 deletions detectron2/modeling/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def assign_boxes_to_levels(


def _fmt_box_list(box_tensor, batch_index: int):
repeated_index = torch.full(
(len(box_tensor), 1), batch_index, dtype=box_tensor.dtype, device=box_tensor.device
repeated_index = torch.full_like(
box_tensor[:, :1], batch_index, dtype=box_tensor.dtype, device=box_tensor.device
)
return cat((repeated_index, box_tensor), dim=1)

Expand Down
14 changes: 14 additions & 0 deletions tests/modeling/test_box2box_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from detectron2.modeling.box_regression import Box2BoxTransform, Box2BoxTransformRotated
from detectron2.utils.env import TORCH_VERSION

logger = logging.getLogger(__name__)

Expand All @@ -29,6 +30,19 @@ def test_reconstruction(self):
dst_boxes_reconstructed = b2b_tfm.apply_deltas(deltas, src_boxes)
assert torch.allclose(dst_boxes, dst_boxes_reconstructed)

@unittest.skipIf(TORCH_VERSION < (1, 8), "Insufficient pytorch version")
def test_apply_deltas_tracing(self):
weights = (5, 5, 10, 10)
b2b_tfm = Box2BoxTransform(weights=weights)

with torch.no_grad():
func = torch.jit.trace(b2b_tfm.apply_deltas, (torch.randn(10, 20), torch.randn(10, 4)))

o = func(torch.randn(10, 20), torch.randn(10, 4))
self.assertEqual(o.shape, (10, 20))
o = func(torch.randn(5, 20), torch.randn(5, 4))
self.assertEqual(o.shape, (5, 20))


def random_rotated_boxes(mean_box, std_length, std_angle, N):
return torch.cat(
Expand Down
68 changes: 68 additions & 0 deletions tests/modeling/test_fast_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.modeling.roi_heads.rotated_fast_rcnn import RotatedFastRCNNOutputLayers
from detectron2.structures import Boxes, Instances, RotatedBoxes
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.events import EventStorage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -101,6 +102,73 @@ def test_fast_rcnn_rotated(self):
for name in expected_losses.keys():
assert torch.allclose(losses[name], expected_losses[name])

@unittest.skipIf(TORCH_VERSION < (1, 6), "Insufficient pytorch version")
def test_predict_boxes_tracing(self):
class Model(torch.nn.Module):
def __init__(self, output_layer):
super(Model, self).__init__()
self._output_layer = output_layer

def forward(self, proposal_deltas, proposal_boxes):
instances = Instances((10, 10))
instances.proposal_boxes = Boxes(proposal_boxes)
return self._output_layer.predict_boxes((None, proposal_deltas), [instances])

box_head_output_size = 8

box_predictor = FastRCNNOutputLayers(
ShapeSpec(channels=box_head_output_size),
box2box_transform=Box2BoxTransform(weights=(10, 10, 5, 5)),
num_classes=5,
)

model = Model(box_predictor)

from detectron2.export.torchscript_patch import patch_builtin_len

with torch.no_grad(), patch_builtin_len():
func = torch.jit.trace(model, (torch.randn(10, 20), torch.randn(10, 4)))

o = func(torch.randn(10, 20), torch.randn(10, 4))
self.assertEqual(o[0].shape, (10, 20))
o = func(torch.randn(5, 20), torch.randn(5, 4))
self.assertEqual(o[0].shape, (5, 20))
o = func(torch.randn(20, 20), torch.randn(20, 4))
self.assertEqual(o[0].shape, (20, 20))

@unittest.skipIf(TORCH_VERSION < (1, 6), "Insufficient pytorch version")
def test_predict_probs_tracing(self):
class Model(torch.nn.Module):
def __init__(self, output_layer):
super(Model, self).__init__()
self._output_layer = output_layer

def forward(self, scores, proposal_boxes):
instances = Instances((10, 10))
instances.proposal_boxes = Boxes(proposal_boxes)
return self._output_layer.predict_probs((scores, None), [instances])

box_head_output_size = 8

box_predictor = FastRCNNOutputLayers(
ShapeSpec(channels=box_head_output_size),
box2box_transform=Box2BoxTransform(weights=(10, 10, 5, 5)),
num_classes=5,
)

model = Model(box_predictor)

from detectron2.export.torchscript_patch import patch_builtin_len

with torch.no_grad(), patch_builtin_len():
func = torch.jit.trace(model, (torch.randn(10, 6), torch.rand(10, 4)))
o = func(torch.randn(10, 6), torch.randn(10, 4))
self.assertEqual(o[0].shape, (10, 6))
o = func(torch.randn(5, 6), torch.randn(5, 4))
self.assertEqual(o[0].shape, (5, 6))
o = func(torch.randn(20, 6), torch.randn(20, 4))
self.assertEqual(o[0].shape, (20, 6))


if __name__ == "__main__":
unittest.main()
64 changes: 63 additions & 1 deletion tests/modeling/test_roi_pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
import torch

from detectron2.modeling.poolers import ROIPooler
from detectron2.modeling.poolers import ROIPooler, _fmt_box_list
from detectron2.structures import Boxes, RotatedBoxes
from detectron2.utils.env import TORCH_VERSION

Expand Down Expand Up @@ -134,6 +134,68 @@ def test_no_images(self):
output = pooler.forward(features, [])
self.assertEqual(output.shape, (0, C, 14, 14))

@unittest.skipIf(TORCH_VERSION < (1, 6), "Insufficient pytorch version")
def test_fmt_box_list_onnx_export(self):
class Model(torch.nn.Module):
def forward(self, box_tensor):
return _fmt_box_list(box_tensor, 0)

with torch.no_grad():
func = torch.jit.trace(Model(), torch.ones(10, 4))

self.assertEqual(func(torch.ones(10, 4)).shape, (10, 5))
self.assertEqual(func(torch.ones(5, 4)).shape, (5, 5))
self.assertEqual(func(torch.ones(20, 4)).shape, (20, 5))

@unittest.skipIf(TORCH_VERSION < (1, 6), "Insufficient pytorch version")
def test_roi_pooler_onnx_export(self):
class Model(torch.nn.Module):
def __init__(self, roi):
super(Model, self).__init__()
self.roi = roi

def forward(self, x, boxes):
return self.roi([x], [Boxes(boxes)])

pooler_resolution = 14
canonical_level = 4
canonical_scale_factor = 2 ** canonical_level
pooler_scales = (1.0 / canonical_scale_factor,)
sampling_ratio = 0

N, C, H, W = 1, 4, 10, 8
N_rois = 10
std = 11
mean = 0
feature = (torch.rand(N, C, H, W) - 0.5) * 2 * std + mean

rois = self._rand_boxes(
num_boxes=N_rois, x_max=W * canonical_scale_factor, y_max=H * canonical_scale_factor
)

model = Model(
ROIPooler(
output_size=pooler_resolution,
scales=pooler_scales,
sampling_ratio=sampling_ratio,
pooler_type="ROIAlign",
)
)

with torch.no_grad():
func = torch.jit.trace(model, (feature, rois))
o = func(feature, rois)
self.assertEqual(o.shape, (10, 4, 14, 14))
o = func(feature, rois[:5])
self.assertEqual(o.shape, (5, 4, 14, 14))
o = func(
feature,
self._rand_boxes(
num_boxes=20, x_max=W * canonical_scale_factor, y_max=H * canonical_scale_factor
),
)
self.assertEqual(o.shape, (20, 4, 14, 14))


if __name__ == "__main__":
unittest.main()

0 comments on commit 631c33a

Please sign in to comment.