Skip to content

Commit

Permalink
Variable annotation of exporting Densepose model to torchscript.
Browse files Browse the repository at this point in the history
Summary:
As described in facebookresearch#1444 , this pr contains variable annotation of export densepose model to torchscript.
Pull Request resolved: facebookresearch#1446

Reviewed By: bxiong1202

Differential Revision: D21631581

Pulled By: ppwwyyxx

fbshipit-source-id: 6a1bff126c1c9299924bc3bedc795aa035968b2d
  • Loading branch information
chenbohua3 authored and facebook-github-bot committed May 19, 2020
1 parent e5fb618 commit c9559e6
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 37 deletions.
4 changes: 3 additions & 1 deletion detectron2/layers/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from torchvision.ops import nms # BC-compat


def batched_nms(boxes, scores, idxs, iou_threshold):
def batched_nms(
boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
):
"""
Same as torchvision.ops.boxes.batched_nms, but safer.
"""
Expand Down
3 changes: 2 additions & 1 deletion detectron2/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
"""

import math
from typing import List
import torch
from torch.nn.modules.utils import _ntuple

TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])


def cat(tensors, dim=0):
def cat(tensors: List[torch.Tensor], dim: int = 0):
"""
Efficient version of torch.cat that avoids a copy if there is only a single element in a list
"""
Expand Down
2 changes: 1 addition & 1 deletion detectron2/modeling/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.
anchors.append([x0, y0, x1, y1])
return torch.tensor(anchors)

def forward(self, features):
def forward(self, features: List[torch.Tensor]):
"""
Args:
features (list[Tensor]): list of backbone feature maps on which to generate anchors.
Expand Down
7 changes: 5 additions & 2 deletions detectron2/modeling/poolers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import sys
from typing import List
import torch
from torch import nn
from torchvision.ops import RoIPool
Expand All @@ -10,7 +11,9 @@
__all__ = ["ROIPooler"]


def assign_boxes_to_levels(box_lists, min_level, max_level, canonical_box_size, canonical_level):
def assign_boxes_to_levels(
box_lists, min_level: int, max_level: int, canonical_box_size: int, canonical_level: int
):
"""
Map each box in `box_lists` to a feature map level index and return the assignment
vector.
Expand Down Expand Up @@ -173,7 +176,7 @@ def __init__(
assert canonical_box_size > 0
self.canonical_box_size = canonical_box_size

def forward(self, x, box_lists):
def forward(self, x: List[torch.Tensor], box_lists):
"""
Args:
x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
Expand Down
13 changes: 9 additions & 4 deletions detectron2/modeling/proposal_generator/rpn.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from torch import nn

from detectron2.config import configurable
from detectron2.layers import ShapeSpec
from detectron2.structures import Boxes, Instances, pairwise_iou
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.utils.registry import Registry

Expand Down Expand Up @@ -88,7 +88,7 @@ def from_config(cls, cfg, input_shape):
), "Each level must have the same number of anchors per spatial position"
return {"in_channels": in_channels, "num_anchors": num_anchors[0], "box_dim": box_dim}

def forward(self, features):
def forward(self, features: List[torch.Tensor]):
"""
Args:
features (list[Tensor]): list of feature maps
Expand Down Expand Up @@ -224,7 +224,12 @@ def label_and_sample_anchors(self, anchors: List[Boxes], gt_instances: List[Inst
matched_gt_boxes.append(matched_gt_boxes_i)
return gt_labels, matched_gt_boxes

def forward(self, images, features, gt_instances=None):
def forward(
self,
images: ImageList,
features: Dict[str, torch.Tensor],
gt_instances: Optional[Instances] = None,
):
"""
Args:
images (ImageList): input images of length `N`
Expand Down
38 changes: 20 additions & 18 deletions detectron2/modeling/proposal_generator/rpn_outputs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import logging
from typing import List, Optional
import torch
import torch.nn.functional as F
from fvcore.nn import smooth_l1_loss

from detectron2.layers import batched_nms, cat
from detectron2.structures import Boxes, Instances
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.structures import Boxes, ImageList, Instances
from detectron2.utils.events import get_event_storage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -46,14 +48,14 @@


def find_top_rpn_proposals(
proposals,
pred_objectness_logits,
images,
nms_thresh,
pre_nms_topk,
post_nms_topk,
min_box_side_len,
training,
proposals: List[torch.Tensor],
pred_objectness_logits: List[torch.Tensor],
images: ImageList,
nms_thresh: float,
pre_nms_topk: int,
post_nms_topk: int,
min_box_side_len: int,
training: bool,
):
"""
For each feature map, select the `pre_nms_topk` highest scoring proposals,
Expand Down Expand Up @@ -159,7 +161,7 @@ def find_top_rpn_proposals(


def rpn_losses(
gt_labels, gt_anchor_deltas, pred_objectness_logits, pred_anchor_deltas, smooth_l1_beta
gt_labels, gt_anchor_deltas, pred_objectness_logits, pred_anchor_deltas, smooth_l1_beta: float
):
"""
Args:
Expand Down Expand Up @@ -196,15 +198,15 @@ def rpn_losses(
class RPNOutputs(object):
def __init__(
self,
box2box_transform,
batch_size_per_image,
images,
pred_objectness_logits,
pred_anchor_deltas,
box2box_transform: Box2BoxTransform,
batch_size_per_image: int,
images: ImageList,
pred_objectness_logits: List[torch.Tensor],
pred_anchor_deltas: List[torch.Tensor],
anchors,
gt_labels=None,
gt_boxes=None,
smooth_l1_beta=0.0,
gt_labels: Optional[List[torch.Tensor]] = None,
gt_boxes: Optional[List[torch.Tensor]] = None,
smooth_l1_beta: float = 0.0,
):
"""
Args:
Expand Down
4 changes: 3 additions & 1 deletion detectron2/modeling/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
__all__ = ["subsample_labels"]


def subsample_labels(labels, num_samples, positive_fraction, bg_label):
def subsample_labels(
labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
"""
Return `num_samples` (or fewer, if not enough found)
random samples from `labels` which is a mixture of positives & negatives.
Expand Down
10 changes: 7 additions & 3 deletions projects/DensePose/densepose/densepose_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from enum import Enum
from typing import Iterable, Optional
from typing import Iterable, List, Optional, Tuple
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
Expand Down Expand Up @@ -130,7 +130,7 @@ def forward(self, features):
output = x
return output

def _get_layer_name(self, i):
def _get_layer_name(self, i: int):
layer_name = "body_conv_fcn{}".format(i + 1)
return layer_name

Expand Down Expand Up @@ -540,7 +540,11 @@ def build_densepose_data_filter(cfg):
return dp_filter


def densepose_inference(densepose_outputs, densepose_confidences, detections):
def densepose_inference(
densepose_outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
densepose_confidences: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
detections: List[Instances],
):
"""
Infer dense pose estimate based on outputs from the DensePose head
and detections. The estimate for each detection instance is stored in its
Expand Down
19 changes: 14 additions & 5 deletions projects/DensePose/densepose/roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
from typing import Dict
from typing import Dict, List, Optional
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
Expand All @@ -12,6 +12,7 @@
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.poolers import ROIPooler
from detectron2.modeling.roi_heads import select_foreground_proposals
from detectron2.structures import ImageList, Instances

from .densepose_head import (
build_densepose_data_filter,
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(self, cfg, input_shape: Dict[str, ShapeSpec], in_features):
self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0)
weight_init.c2_msra_fill(self.predictor)

def forward(self, features):
def forward(self, features: List[torch.Tensor]):
for i, _ in enumerate(self.in_features):
if i == 0:
x = self.scale_heads[i](features[i])
Expand Down Expand Up @@ -122,7 +123,7 @@ def _init_densepose_head(self, cfg, input_shape):
)
self.densepose_losses = build_densepose_losses(cfg)

def _forward_densepose(self, features, instances):
def _forward_densepose(self, features: List[torch.Tensor], instances: List[Instances]):
"""
Forward logic of the densepose prediction branch.
Expand Down Expand Up @@ -181,15 +182,23 @@ def _forward_densepose(self, features, instances):
densepose_inference(densepose_outputs, confidences, instances)
return instances

def forward(self, images, features, proposals, targets=None):
def forward(
self,
images: ImageList,
features: Dict[str, torch.Tensor],
proposals: List[Instances],
targets: Optional[List[Instances]] = None,
):
instances, losses = super().forward(images, features, proposals, targets)
del targets, images

if self.training:
losses.update(self._forward_densepose(features, instances))
return instances, losses

def forward_with_given_boxes(self, features, instances):
def forward_with_given_boxes(
self, features: Dict[str, torch.Tensor], instances: List[Instances]
):
"""
Use the given boxes in `instances` to produce other (non-box) per-ROI outputs.
Expand Down
1 change: 0 additions & 1 deletion tests/modeling/test_roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from detectron2.config import get_cfg
from detectron2.layers import ShapeSpec
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.proposal_generator.build import build_proposal_generator
from detectron2.modeling.roi_heads import StandardROIHeads, build_roi_heads
from detectron2.structures import BitMasks, Boxes, ImageList, Instances, RotatedBoxes
Expand Down

0 comments on commit c9559e6

Please sign in to comment.