Skip to content

Commit

Permalink
copy uvo det code
Browse files Browse the repository at this point in the history
  • Loading branch information
jozhang97 committed Oct 29, 2021
1 parent db256a1 commit e6443c0
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
checkpoints
models
data
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
1 change: 1 addition & 0 deletions configs_uvo
4 changes: 3 additions & 1 deletion mmdet/core/bbox/assigners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from .region_assigner import RegionAssigner
from .sim_ota_assigner import SimOTAAssigner
from .uniform_assigner import UniformAssigner
from .rpn_sim_ota_assigner import RPN_SimOTAAssigner

__all__ = [
'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner'
'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
'RPN_SimOTAAssigner',
]
262 changes: 262 additions & 0 deletions mmdet/core/bbox/assigners/rpn_sim_ota_assigner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn.functional as F

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import bbox_overlaps
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class RPN_SimOTAAssigner(BaseAssigner):
    """SimOTA-style assigner for class-agnostic (RPN) predictions.

    Compared with a per-class SimOTA assignment, this variant

    * treats every ground truth as the same single class (the labels passed
      in are replaced by a constant ``1`` inside :meth:`_assign`), and
    * sizes the "center" region of each ground truth from the ground-truth
      box's own width/height instead of from a prior stride.

    Args:
        center_radius (int | float, optional): Multiplier applied to the
            width/height of each gt box to obtain the half-extent of its
            center region. Default 2.5.
        candidate_topk (int, optional): The candidate top-k which used to
            get top-k ious to calculate dynamic-k. Default 10.
        iou_weight (int | float, optional): The scale factor for regression
            iou cost. Default 3.0.
        cls_weight (int | float, optional): The scale factor for
            classification cost. Default 1.0.
    """

    def __init__(self,
                 center_radius=2.5,
                 candidate_topk=10,
                 iou_weight=3.0,
                 cls_weight=1.0):
        self.center_radius = center_radius
        self.candidate_topk = candidate_topk
        self.iou_weight = iou_weight
        self.cls_weight = cls_weight

    def assign(self,
               pred_scores,
               bboxes,
               num_level_bboxes,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None,
               eps=1e-7,
               use_sqrt=True):
        """Assign gt to predicted boxes using SimOTA, retrying on CPU.

        The assignment is first attempted on the device the inputs live on.
        If that raises a ``RuntimeError`` (assumed to be a GPU out-of-memory
        error, although any ``RuntimeError`` takes this path), the inputs are
        moved to CPU, the assignment is recomputed there, and the resulting
        tensors are moved back to the original device.

        Args:
            pred_scores (Tensor): Classification scores of one image,
                a 2D-Tensor with shape [num_priors, num_classes].
            bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
            num_level_bboxes: Accepted for interface compatibility; not
                used by this assigner.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO. Forwarded
                to :meth:`_assign`, which does not use it.
            gt_labels (Tensor, optional): Ground truth labels of one image.
                Only whether it is ``None`` matters (see :meth:`_assign`).
            eps (float): A value added to the IoU before taking the log,
                for numerical stability. Default 1e-7.
            use_sqrt (bool): If True the scores are square-rooted before the
                classification cost is computed, otherwise the cube root is
                used. Default True.

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        try:
            return self._assign(pred_scores, bboxes, gt_bboxes, gt_labels,
                                gt_bboxes_ignore, eps, use_sqrt)
        except RuntimeError:
            origin_device = pred_scores.device
            warnings.warn('OOM RuntimeError is raised due to the huge memory '
                          'cost during label assignment. CPU mode is applied '
                          'in this batch. If you want to avoid this issue, '
                          'try to reduce the batch size or image size.')
            torch.cuda.empty_cache()

            pred_scores = pred_scores.cpu()
            bboxes = bboxes.cpu()
            gt_bboxes = gt_bboxes.cpu().float()
            # gt_labels is deliberately left on its device: _assign only
            # checks it against None and never reads its values.

            assign_result = self._assign(pred_scores, bboxes, gt_bboxes,
                                         gt_labels, gt_bboxes_ignore, eps,
                                         use_sqrt)
            assign_result.gt_inds = assign_result.gt_inds.to(origin_device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(
                origin_device)
            # labels is None when _assign takes its early-return branch with
            # gt_labels=None; guard so the fallback does not crash on it.
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(origin_device)

            return assign_result

    def _assign(self,
                pred_scores,
                bboxes,
                gt_bboxes,
                gt_labels,
                gt_bboxes_ignore=None,
                eps=1e-7,
                use_sqrt=True):
        """Assign gt to predicted boxes using SimOTA.

        Args:
            pred_scores (Tensor): Classification scores of one image. The
                cost computation below compares them against a target of
                shape [num_valid, num_gt, 1], so a [num_priors, 1] tensor of
                probabilities in [0, 1] is presumably expected -- TODO
                confirm against the calling head.
            bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor | None): Only inspected for ``None`` in the
                empty-input branch; otherwise it is replaced by an all-ones
                label vector (class-agnostic assignment).
            gt_bboxes_ignore (Tensor, optional): Unused.
            eps (float): A value added to the IoU before taking the log,
                for numerical stability. Default 1e-7.
            use_sqrt (bool): Square root (True) or cube root (False) of the
                scores is used for the classification cost. Default True.

        Returns:
            :obj:`AssignResult`: The assigned result. ``gt_inds`` is 0 for
            unassigned boxes and ``matched_gt_index + 1`` otherwise.
        """
        INF = 100000000
        num_gt = gt_bboxes.size(0)
        num_bboxes = bboxes.size(0)

        # assign 0 by default
        assigned_gt_inds = bboxes.new_full((num_bboxes, ),
                                           0,
                                           dtype=torch.long)
        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes: everything stays background (0).
            max_overlaps = bboxes.new_zeros((num_bboxes, ))
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = bboxes.new_full((num_bboxes, ),
                                                  -1,
                                                  dtype=torch.long)
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

        valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
            bboxes, gt_bboxes)

        valid_bbox = bboxes[valid_mask]
        valid_pred_scores = pred_scores[valid_mask]
        num_valid = valid_bbox.size(0)

        pairwise_ious = bbox_overlaps(valid_bbox, gt_bboxes)
        iou_cost = -torch.log(pairwise_ious + eps)

        # Class-agnostic: every gt gets label 1 and the BCE target is all
        # ones, whatever gt_labels the caller supplied.
        gt_labels = pred_scores.new_full((num_gt, ), 1).long()
        gt_onehot_label = pred_scores.new_full((num_valid, num_gt, 1),
                                               1).float()

        # Detach first so the root is computed outside the autograd graph;
        # the values are the same as rooting and then detaching.
        valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(
            1, num_gt, 1).detach()
        if use_sqrt:
            soft_scores = valid_pred_scores.sqrt()
        else:
            soft_scores = valid_pred_scores.pow(1 / 3)
        cls_cost = F.binary_cross_entropy(
            soft_scores, gt_onehot_label, reduction='none').sum(-1)

        # Pairs that are not inside both the gt box and its center region
        # get a prohibitive cost so they are picked only as a last resort.
        cost_matrix = (
            cls_cost * self.cls_weight + iou_cost * self.iou_weight +
            (~is_in_boxes_and_center) * INF)

        # NOTE: dynamic_k_matching narrows valid_mask in place to the boxes
        # that actually received a gt.
        matched_pred_ious, matched_gt_inds = \
            self.dynamic_k_matching(
                cost_matrix, pairwise_ious, num_gt, valid_mask)

        # convert to AssignResult format
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
                                                 -INF,
                                                 dtype=torch.float32)
        max_overlaps[valid_mask] = matched_pred_ious
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

    def get_in_gt_and_in_center_info(self, bboxes, gt_bboxes):
        """Find which box centers lie in gt boxes and gt center regions.

        The center region of a gt is an axis-aligned box around the gt
        center whose half-width/half-height is ``center_radius`` times the
        gt's width/height (self-adapted to the gt size rather than to a
        prior stride).

        Args:
            bboxes (Tensor): Boxes indexed as [:, 0..3] =
                [tl_x, tl_y, br_x, br_y].
            gt_bboxes (Tensor): Ground truth boxes in the same layout.

        Returns:
            tuple[Tensor, Tensor]:
                - Boolean mask over boxes, True where the box center is in
                  at least one gt box or at least one gt center region.
                - Boolean matrix [num_masked_boxes, num_gt] restricted to
                  the boxes selected by the first mask, True where the box
                  center is in both the gt box and its center region.
        """
        # Box centers as column vectors so they broadcast against the
        # per-gt row vectors below, giving [num_boxes, num_gt] results.
        centers_x = ((bboxes[:, 0] + bboxes[:, 2]) / 2.0).unsqueeze(1)
        centers_y = ((bboxes[:, 1] + bboxes[:, 3]) / 2.0).unsqueeze(1)

        # is box center inside the gt box, shape: [num_boxes, num_gt]
        l_ = centers_x - gt_bboxes[:, 0]
        t_ = centers_y - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - centers_x
        b_ = gt_bboxes[:, 3] - centers_y

        deltas = torch.stack([l_, t_, r_, b_], dim=1)
        is_in_gts = deltas.min(dim=1).values > 0
        is_in_gts_all = is_in_gts.sum(dim=1) > 0

        # is box center inside the gt center region
        gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        half_w = self.center_radius * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
        half_h = self.center_radius * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
        ct_box_l = gt_cxs - half_w
        ct_box_t = gt_cys - half_h
        ct_box_r = gt_cxs + half_w
        ct_box_b = gt_cys + half_h

        cl_ = centers_x - ct_box_l
        ct_ = centers_y - ct_box_t
        cr_ = ct_box_r - centers_x
        cb_ = ct_box_b - centers_y

        ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
        is_in_cts = ct_deltas.min(dim=1).values > 0
        is_in_cts_all = is_in_cts.sum(dim=1) > 0

        # in boxes or in centers, shape: [num_boxes]
        is_in_gts_or_centers = is_in_gts_all | is_in_cts_all

        # both in boxes and centers, shape: [num_fg, num_gt]
        is_in_boxes_and_centers = (
            is_in_gts[is_in_gts_or_centers, :]
            & is_in_cts[is_in_gts_or_centers, :])
        return is_in_gts_or_centers, is_in_boxes_and_centers

    def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
        """Pick a dynamic number of lowest-cost candidates for every gt.

        For each gt, k is the (truncated, at least 1) sum of its top
        ``candidate_topk`` IoUs, and the k lowest-cost candidates are
        matched to it. A candidate matched to several gts keeps only the
        gt with the lowest cost.

        Args:
            cost (Tensor): Cost matrix, shape [num_valid, num_gt].
            pairwise_ious (Tensor): IoU matrix, shape [num_valid, num_gt].
            num_gt (int): Number of ground truths.
            valid_mask (Tensor): Boolean mask over all boxes with
                ``num_valid`` True entries. **Modified in place**: on return
                only boxes that were matched to a gt remain True.

        Returns:
            tuple[Tensor, Tensor]: IoU with the matched gt, and index of the
            matched gt, for every matched candidate.
        """
        matching_matrix = torch.zeros_like(cost)
        num_candidates = cost.size(0)
        # select candidate topk ious for dynamic-k calculation
        topk = min(self.candidate_topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, topk, dim=0)
        # calculate dynamic k for each gt; convert to Python ints once
        # instead of calling .item() (a host sync) per gt inside the loop
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1).tolist()
        for gt_idx in range(num_gt):
            # k can never exceed the number of available candidates
            tmp_k = min(num_candidates, dynamic_ks[gt_idx])
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=tmp_k, largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1.0

        # pos_idx is intentionally not deleted here: it is unbound when
        # num_gt == 0 and deleting it would raise.
        del topk_ious, dynamic_ks

        # resolve candidates claimed by more than one gt
        prior_match_gt_mask = matching_matrix.sum(1) > 1
        if prior_match_gt_mask.sum() > 0:
            _, cost_argmin = torch.min(
                cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] *= 0.0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0
        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0.0
        valid_mask[valid_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        matched_pred_ious = (matching_matrix *
                             pairwise_ious).sum(1)[fg_mask_inboxes]
        return matched_pred_ious, matched_gt_inds
7 changes: 4 additions & 3 deletions mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .autoassign_head import AutoAssignHead
from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
from .cascade_rpn_head import CascadeRPNHead #, StageCascadeRPNHead
from .centernet_head import CenterNetHead
from .centripetal_head import CentripetalHead
from .corner_head import CornerHead
Expand Down Expand Up @@ -35,6 +35,7 @@
from .yolo_head import YOLOV3Head
from .yolof_head import YOLOFHead
from .yolox_head import YOLOXHead
from .uvo_rpn_head import UVORPNHead

__all__ = [
'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
Expand All @@ -43,9 +44,9 @@
'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', # 'StageCascadeRPNHead',
'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
'DecoupledSOLOLightHead'
'DecoupledSOLOLightHead', 'UVORPNHead',
]
2 changes: 1 addition & 1 deletion mmdet/models/dense_heads/cascade_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def forward(self, x, offset):
return x


@HEADS.register_module()
# @HEADS.register_module()
class StageCascadeRPNHead(RPNHead):
"""Stage of CascadeRPNHead.
Expand Down
16 changes: 16 additions & 0 deletions tools/verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Smoke test: build the UVO detector from its config and run the demo image."""
from mmdet.apis import init_detector, inference_detector

# Stock alternative for a quick sanity check of the installation:
# config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
config_file = 'configs_uvo/uvo/swin_l_carafe_simota_focal_giou_iouhead_tower_dcn_coco_384_uvo_finetune.py'

# download the checkpoint from model zoo and put it in `checkpoints/`
# url: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
# checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# NOTE(review): with None, init_detector presumably builds the model without
# loading any weights, so this only checks that the config can be built and
# run end to end -- confirm against the mmdet API docs.
checkpoint_file = None
device = 'cuda:0'
# init a detector
model = init_detector(config_file, checkpoint_file, device=device)
# inference the demo image
# The hard-coded ipdb breakpoint that used to sit here was removed: it made
# the script depend on ipdb and hang in non-interactive runs. To step through
# interactively, run `python -m pdb tools/verify.py` instead.
ret = inference_detector(model, 'demo/demo.jpg')
print(type(ret))

0 comments on commit e6443c0

Please sign in to comment.