Skip to content

Commit

Permalink
Make RPN scriptable.
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#1765

Reviewed By: rbgirshick

Differential Revision: D22527319

Pulled By: ppwwyyxx

fbshipit-source-id: f79b99b10563f6817dcc1ef99d0ce52632f8c0c4
  • Loading branch information
chenbohua3 authored and facebook-github-bot committed Jul 28, 2020
1 parent 8b0bb4b commit 68df102
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ run_unittests: &run_unittests
- run:
name: Run Unit Tests
command: |
pytest -n 4 -v tests
pytest -n 1 -v tests # parallel causes some random failures
# -------------------------------------------------------------------------------------
# Jobs to run
Expand Down
5 changes: 4 additions & 1 deletion detectron2/export/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@ Please see [documentation](https://detectron2.readthedocs.io/tutorials/deploymen

### Acknowledgements

Thanks to Mobile Vision team at Facebook for developing the conversion tools.
Thanks to Mobile Vision team at Facebook for developing the Caffe2 conversion tools.

Thanks to Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who
help export Detectron2 models to TorchScript.
9 changes: 6 additions & 3 deletions detectron2/export/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

from detectron2.config import CfgNode as CN

from .caffe2_export import export_caffe2_detection_model
from .caffe2_export import export_onnx_model as export_onnx_model_impl
from .caffe2_export import run_and_save_graph
from .caffe2_inference import ProtobufDetectionModel
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
Expand Down Expand Up @@ -108,6 +105,8 @@ def export_caffe2(self):
Returns:
Caffe2Model
"""
from .caffe2_export import export_caffe2_detection_model

model, inputs = self._get_traceable()
predict_net, init_net = export_caffe2_detection_model(model, inputs)
return Caffe2Model(predict_net, init_net)
Expand All @@ -122,6 +121,8 @@ def export_onnx(self):
Returns:
onnx.ModelProto: an onnx model.
"""
from .caffe2_export import export_onnx_model as export_onnx_model_impl

model, inputs = self._get_traceable()
return export_onnx_model_impl(model, (inputs,))

Expand Down Expand Up @@ -239,6 +240,8 @@ def save_graph(self, output_file, inputs=None):
shape of every tensor. The shape information will be
saved together with the graph.
"""
from .caffe2_export import run_and_save_graph

if inputs is None:
save_graph(self._predict_net, output_file, op_only=False)
else:
Expand Down
154 changes: 154 additions & 0 deletions detectron2/export/torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import importlib.util
import os
import sys
import tempfile
from contextlib import contextmanager
from typing import Dict
import torch

# fmt: off
from detectron2.modeling.proposal_generator import RPN
# need an explicit import due to https://github.com/pytorch/pytorch/issues/38964
from detectron2.structures import Boxes, Instances # noqa F401

# fmt: on

_counter = 0


def export_torchscript_with_instances(model, fields):
"""
Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since
attributes of :class:`Instances` are "dynamically" added in eager mode,it is difficult
for torchscript to support it out of the box. This function is made to support scripting
a model that uses :class:`Instances`. It does the following:
1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``,
but with all attributes been "static".
The attributes need to be statically declared in the ``fields`` argument.
2. Register ``new_Instances`` to torchscript, and force torchscript to
use it when trying to compile ``Instances``.
After this function, the process will be reverted. User should be able to script another model
using different fields.
Example:
Assume that ``Instances`` in the model consist of two attributes named
``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
:class:`Tensor` respectively during inference. You can call this function like:
::
fields = {"proposal_boxes": "Boxes", "objectness_logits": "Tensor"}
torchscipt_model = export_torchscript_with_instances(model, fields)
Args:
model (nn.Module): The input model to be exported to torchscript.
fields (Dict[str, str]): Attribute names and corresponding type annotations that
``Instances`` will use in the model. Note that all attributes used in ``Instances``
need to be added, regarldess of whether they are inputs/outputs of the model.
Custom data type is not supported for now.
Returns:
torch.jit.ScriptModule: the input model in torchscript format
"""
with patch_instances(fields):

# Also add some other hacks for torchscript:
# boolean as dictionary keys is unsupported:
# https://github.com/pytorch/pytorch/issues/41449
# We annotate it this way to let torchscript interpret them as integers.
RPN.__annotations__["pre_nms_topk"] = Dict[int, int]
RPN.__annotations__["post_nms_topk"] = Dict[int, int]

scripted_model = torch.jit.script(model)
return scripted_model


@contextmanager
def patch_instances(fields):
with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
) as f:
try:
cls_name, s = _gen_module(fields)
f.write(s)
f.flush()
f.close()

module = _import(f.name)
new_instances = getattr(module, cls_name)
_ = torch.jit.script(new_instances)

# let torchscript think Instances was scripted already
Instances.__torch_script_class__ = True
# let torchscript find new_instances when looking for the jit type of Instances
Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
yield new_instances
finally:
try:
del Instances.__torch_script_class__
del Instances._jit_override_qualname
except AttributeError:
pass
sys.modules.pop(module.__name__)


# TODO: find a more automatic way to enable import of other classes
def _gen_imports():
imports_str = """
import torch
from torch import Tensor
import typing
from typing import *
from detectron2.structures import Boxes
"""
return imports_str


def _gen_class(fields):
def indent(level, s):
return " " * 4 * level + s

lines = []

global _counter
_counter += 1

cls_name = "Instances_patched{}".format(_counter)

lines.append(
f"""
class {cls_name}:
def __init__(self, image_size: Tuple[int, int]):
self.image_size = image_size
"""
)

for name, type_ in fields.items():
lines.append(indent(2, f"self.{name} = torch.jit.annotate(Optional[{type_}], None)"))
# TODO add getter/setter when @property is supported

return cls_name, os.linesep.join(lines)


def _gen_module(fields):
s = ""
s += _gen_imports()
cls_name, cls_def = _gen_class(fields)
s += cls_def
return cls_name, s


def _import(path):
# https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(
"{}{}".format(sys.modules[__name__].__name__, _counter), path
)
module = importlib.util.module_from_spec(spec)
sys.modules[module.__name__] = module
spec.loader.exec_module(module)
return module
2 changes: 1 addition & 1 deletion detectron2/modeling/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class DefaultAnchorGenerator(nn.Module):
"Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks".
"""

box_dim: int = 4
box_dim: torch.jit.Final[int] = 4
"""
the dimension of each anchor box.
"""
Expand Down
9 changes: 3 additions & 6 deletions detectron2/modeling/proposal_generator/proposal_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import logging
import math
from typing import List, Tuple
Expand All @@ -18,7 +17,7 @@ def find_top_rpn_proposals(
nms_thresh: float,
pre_nms_topk: int,
post_nms_topk: int,
min_box_size: int,
min_box_size: float,
training: bool,
):
"""
Expand Down Expand Up @@ -57,9 +56,7 @@ def find_top_rpn_proposals(
topk_proposals = []
level_ids = [] # #lvl Tensor, each of shape (topk,)
batch_idx = torch.arange(num_images, device=device)
for level_id, proposals_i, logits_i in zip(
itertools.count(), proposals, pred_objectness_logits
):
for level_id, (proposals_i, logits_i) in enumerate(zip(proposals, pred_objectness_logits)):
Hi_Wi_A = logits_i.shape[1]
num_proposals_i = min(pre_nms_topk, Hi_Wi_A)

Expand All @@ -82,7 +79,7 @@ def find_top_rpn_proposals(
level_ids = cat(level_ids, dim=0)

# 3. For each image, run a per-level NMS, and choose topk results.
results = []
results: List[Instances] = []
for n, image_size in enumerate(image_sizes):
boxes = Boxes(topk_proposals[n])
scores_per_img = topk_scores[n]
Expand Down
35 changes: 22 additions & 13 deletions detectron2/modeling/proposal_generator/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
self.pre_nms_topk = {True: pre_nms_topk[0], False: pre_nms_topk[1]}
self.post_nms_topk = {True: post_nms_topk[0], False: post_nms_topk[1]}
self.nms_thresh = nms_thresh
self.min_box_size = min_box_size
self.min_box_size = float(min_box_size)
self.anchor_boundary_thresh = anchor_boundary_thresh
if isinstance(loss_weight, float):
loss_weight = {"loss_rpn_cls": loss_weight, "loss_rpn_loc": loss_weight}
Expand Down Expand Up @@ -264,8 +264,11 @@ def _subsample_labels(self, label):
label.scatter_(0, neg_idx, 0)
return label

@torch.jit.unused
@torch.no_grad()
def label_and_sample_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]):
def label_and_sample_anchors(
self, anchors: List[Boxes], gt_instances: List[Instances]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Args:
anchors (list[Boxes]): anchors for each feature map.
Expand Down Expand Up @@ -321,14 +324,15 @@ def label_and_sample_anchors(self, anchors: List[Boxes], gt_instances: List[Inst
matched_gt_boxes.append(matched_gt_boxes_i)
return gt_labels, matched_gt_boxes

@torch.jit.unused
def losses(
self,
anchors,
anchors: List[Boxes],
pred_objectness_logits: List[torch.Tensor],
gt_labels: List[torch.Tensor],
pred_anchor_deltas: List[torch.Tensor],
gt_boxes: List[torch.Tensor],
):
) -> Dict[str, torch.Tensor]:
"""
Return the losses from a set of RPN predictions and their associated ground-truth.
Expand Down Expand Up @@ -388,16 +392,18 @@ def losses(
reduction="sum",
)
normalizer = self.batch_size_per_image * num_images
return {
losses = {
"loss_rpn_cls": objectness_loss / normalizer,
"loss_rpn_loc": localization_loss / normalizer,
}
losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
return losses

def forward(
self,
images: ImageList,
features: Dict[str, torch.Tensor],
gt_instances: Optional[Instances] = None,
gt_instances: Optional[List[Instances]] = None,
):
"""
Args:
Expand Down Expand Up @@ -432,23 +438,23 @@ def forward(
]

if self.training:
assert gt_instances is not None, "RPN requires gt_instances in training!"
gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
losses = self.losses(
anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
)
losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}
else:
losses = {}

proposals = self.predict_proposals(
anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
)
return proposals, losses

@torch.no_grad()
# TODO: use torch.no_grad when torchscript supports it.
# https://github.com/pytorch/pytorch/pull/41371
def predict_proposals(
self,
anchors,
anchors: List[Boxes],
pred_objectness_logits: List[torch.Tensor],
pred_anchor_deltas: List[torch.Tensor],
image_sizes: List[Tuple[int, int]],
Expand All @@ -465,19 +471,22 @@ def predict_proposals(
# The proposals are treated as fixed for approximate joint training with roi heads.
# This approach ignores the derivative w.r.t. the proposal boxes’ coordinates that
# are also network responses, so is approximate.
pred_objectness_logits = [t.detach() for t in pred_objectness_logits]
pred_anchor_deltas = [t.detach() for t in pred_anchor_deltas]
pred_proposals = self._decode_proposals(anchors, pred_anchor_deltas)
return find_top_rpn_proposals(
pred_proposals,
pred_objectness_logits,
image_sizes,
self.nms_thresh,
self.pre_nms_topk[self.training],
self.post_nms_topk[self.training],
# https://github.com/pytorch/pytorch/issues/41449
self.pre_nms_topk[int(self.training)],
self.post_nms_topk[int(self.training)],
self.min_box_size,
self.training,
)

def _decode_proposals(self, anchors, pred_anchor_deltas: List[torch.Tensor]):
def _decode_proposals(self, anchors: List[Boxes], pred_anchor_deltas: List[torch.Tensor]):
"""
Transform anchors into proposals by applying the predicted anchor deltas.
Expand Down
3 changes: 1 addition & 2 deletions detectron2/structures/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
keep = (widths > threshold) & (heights > threshold)
return keep

@torch.jit.unused
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]):
def __getitem__(self, item):
"""
Args:
item: int, slice, or a BoolTensor
Expand Down
Loading

0 comments on commit 68df102

Please sign in to comment.