RCNN w/ FPN, Sync BN & Disable static alloc for RCNN (dmlc#494)
* number of gpus

* syncbn

* pylint

* resnetv1c

* merge

* indent

* unit test

* trigger

* norm args

* indent

* resnet v1d +0.5%

* style

* update docs

* fix args

* trigger build

* add test

* resolve conflict

* Add FPN model

* Add FPN train scripts

* Fix FPN error, stay tuned

Training on VOC is still going on; I will report the result and log later.

* Revert "Sync from dmlc/master"

* Revert "Revert "Sync from dmlc/master""

* Update gluoncv/model_zoo/fpn/fpn.py

* Fix `FPN` Bugs

mAP on VOC07 is 58%; stay tuned.

* add faster_rcnn_fpn_resnet50_v1b model

* Update gluoncv/model_zoo/fpn/fpn.py

* Update gluoncv/model_zoo/fpn/fpn.py

* Create Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Update Readme.md

* Rename Readme.md to README.md

* Update Train and Eval script, Support Eval VOC12 Test.

* Update scripts/detection/fpn/eval_fpn_voc12.py

* Update scripts/detection/fpn/eval_fpn_voc12.py

* Update README.md

* Update README.md

* Update fpn.py

* Update gluoncv/model_zoo/model_zoo.py

* Update gluoncv/model_zoo/model_zoo.py

* default to not using static alloc, to save memory; speed is not significantly impacted.
added dilated faster_rcnn_resnet50_v1b
added mask_rcnn_resnet101_v1b
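
For reference, a minimal sketch of what the static-alloc switch means on the user side, assuming Gluon's standard `hybridize` flags (the model name is illustrative):

    from gluoncv import model_zoo

    # Illustrative model name; any hybridizable GluonCV detector behaves the same.
    net = model_zoo.get_model('faster_rcnn_resnet50_v1b_voc', pretrained=False)

    # static_alloc=True caches and pre-allocates memory for the hybridized graph
    # (faster steady state, larger footprint); the new default leaves it off.
    net.hybridize(static_alloc=False)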

* fix missing args

* small fix

* docs

* rm unneeded file

* rm debug log

* Faster RCNN with FPN

* rm unnecessary files

pylint

rm Non-ASCII

fix syntax

lint

rm from .fpn import *

stride => strides

rm syncbn in rpn

rm syncbn arg

mask rcnn arg fix

missing "s"

* rm 's' in anchor_generators

* old model compatibility

fix

* not using RPNHead to keep backward compatibility with old models

* _strides

* mask rcnn compatibility

* docs

* rm dilated faster rcnn

* mask rcnn w/ fpn

* rm undefined functions

* change default roi mode to 'align'
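
Roughly, 'align' replaces the coordinate snapping of ROI pooling with bilinear sampling at sub-pixel locations, which improves localization. A toy sketch with the MXNet operators (assuming `mx.nd.contrib.ROIAlign` is available, as in MXNet >= 1.3):

    import mxnet as mx

    feat = mx.nd.arange(64).reshape((1, 1, 8, 8))   # toy 8x8 feature map
    rois = mx.nd.array([[0, 1.2, 1.2, 5.7, 5.7]])   # [batch_idx, x1, y1, x2, y2]

    # 'pool' mode: ROI boundaries are snapped to the feature grid before pooling
    pooled = mx.nd.ROIPooling(feat, rois, pooled_size=(2, 2), spatial_scale=1.0)
    # 'align' mode: bilinear sampling at exact sub-pixel locations, no snapping
    aligned = mx.nd.contrib.ROIAlign(feat, rois, pooled_size=(2, 2),
                                     spatial_scale=1.0, sample_ratio=2)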

* trigger build

* change name of the fpn networks

* model store update

* Fix typo (dmlc#622)

* Improve custom coco compatible detection dataset (dmlc#624)

* coco det improve for custom datasets

* allow flexible image path parser

* fix pycocotools _isArrayLike

* better comment

* clean

* Add assertions for invalid class names for VOCDetection (dmlc#614)

* Add assertions for invalid class names

* Add assertions for invalid class names (revision1)

* Add assertions/warnings for invalid class names (revision2)

* Add assertions/warnings for invalid class names (revision3)

* Add assertions/warnings for invalid class names (revision4)

* add detection paper (dmlc#628)

* add bibtex

* rephrase

* update bibtex

* Update PSP Params (dmlc#629)

* update psp params

* update with pin-device_id (dmlc#630)

* sync bn faster rcnn
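
A hedged sketch of the synchronized batch normalization enabled here, using Gluon's `SyncBatchNorm` (the device count is illustrative):

    from mxnet.gluon.contrib.nn import SyncBatchNorm

    # BN moments are averaged across devices each step, so the small per-GPU
    # batches typical of R-CNN training still yield stable statistics.
    norm_layer = SyncBatchNorm
    norm_kwargs = {'num_devices': 8}  # illustrative: one machine with 8 GPUs
    bn = norm_layer(in_channels=256, **norm_kwargs)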

* pylint

* change roi from 7 to 14, since the last fpn model we trained used 14

* add pretrained faster rcnn fpn bn

* Update model_zoo.py

* Update model_zoo.py
Jerryzcn authored and zhreshold committed Mar 19, 2019
1 parent 960a107 commit 4c624de
Showing 16 changed files with 1,161 additions and 252 deletions.
111 changes: 89 additions & 22 deletions gluoncv/data/transforms/presets/rcnn.py
@@ -1,7 +1,11 @@
"""Transforms for RCNN series."""
from __future__ import absolute_import

import copy
from random import randint

import mxnet as mx

from .. import bbox as tbbox
from .. import image as timage
from .. import mask as tmask
Expand All @@ -10,6 +14,7 @@
'FasterRCNNDefaultTrainTransform', 'FasterRCNNDefaultValTransform',
'MaskRCNNDefaultTrainTransform', 'MaskRCNNDefaultValTransform']


def transform_test(imgs, short=600, max_size=1000, mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)):
"""A util function to transform all images to tensors as network input by applying
Expand Down Expand Up @@ -56,6 +61,7 @@ def transform_test(imgs, short=600, max_size=1000, mean=(0.485, 0.456, 0.406),
return tensors[0], origs[0]
return tensors, origs


def load_test(filenames, short=600, max_size=1000, mean=(0.485, 0.456, 0.406),
              std=(0.229, 0.224, 0.225)):
    """A util function to load all images, transform them to tensor by applying
@@ -95,8 +101,9 @@ class FasterRCNNDefaultTrainTransform(object):
    Parameters
    ----------
    short : int, default is 600
    short : int/tuple, default is 600
        Resize image shorter side to ``short``.
        Resize the shorter side of the image randomly within the given range, if it is a tuple.
    max_size : int, default is 1000
        Make sure image longer side is smaller than ``max_size``.
    net : mxnet.gluon.HybridBlock, optional
@@ -128,28 +135,40 @@ class FasterRCNNDefaultTrainTransform(object):
    flip_p : float, default is 0.5
        Probability to flip horizontally, by default is 0.5 for random horizontal flip.
        You may set it to 0 to disable random flip or 1 to force flip.
    ashape : int, default is 128
        Defines the shape of pre-generated anchors for target generation.
    multi_stage : boolean, default is False
        Whether the network outputs multi-stage features.
    """

    def __init__(self, short=600, max_size=1000, net=None, mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225), box_norm=(1., 1., 1., 1.),
                 num_sample=256, pos_iou_thresh=0.7, neg_iou_thresh=0.3,
                 pos_ratio=0.5, flip_p=0.5, **kwargs):
                 pos_ratio=0.5, flip_p=0.5, ashape=128, multi_stage=False, **kwargs):
        self._short = short
        self._max_size = max_size
        self._mean = mean
        self._std = std
        self._anchors = None
        self._multi_stage = multi_stage
        self._random_resize = isinstance(self._short, (tuple, list))
        self._flip_p = flip_p
        if net is None:
            return

        # use fake data to generate fixed anchors for target generation
        ashape = 128
        anchors = []  # [P2, P3, P4, P5]
        # in case network has reset_ctx to gpu
        anchor_generator = copy.deepcopy(net.rpn.anchor_generator)
        anchor_generator.collect_params().reset_ctx(None)
        anchors = anchor_generator(
            mx.nd.zeros((1, 3, ashape, ashape))).reshape((1, 1, ashape, ashape, -1))
        if self._multi_stage:
            for ag in anchor_generator:
                anchor = ag(mx.nd.zeros((1, 3, ashape, ashape))).reshape((1, 1, ashape, ashape, -1))
                ashape = max(ashape // 2, 16)
                anchors.append(anchor)
        else:
            anchors = anchor_generator(
                mx.nd.zeros((1, 3, ashape, ashape))).reshape((1, 1, ashape, ashape, -1))
        self._anchors = anchors
        # record feature extractor for infer_shape
        if not hasattr(net, 'features'):
@@ -165,7 +184,11 @@ def __call__(self, src, label):
"""Apply transform to training image/label."""
# resize shorter side but keep in max_size
h, w, _ = src.shape
img = timage.resize_short_within(src, self._short, self._max_size, interp=1)
if self._random_resize:
short = randint(self._short[0], self._short[1])
else:
short = self._short
img = timage.resize_short_within(src, short, self._max_size, interp=1)
bbox = tbbox.resize(label, (w, h), (img.shape[1], img.shape[0]))

# random horizontal flip
Expand All @@ -182,11 +205,23 @@ def __call__(self, src, label):

        # generate RPN target so cpu workers can help reduce the workload
        # feat_h, feat_w = (img.shape[1] // self._stride, img.shape[2] // self._stride)
        oshape = self._feat_sym.infer_shape(data=(1, 3, img.shape[1], img.shape[2]))[1][0]
        anchor = self._anchors[:, :, :oshape[2], :oshape[3], :].reshape((-1, 4))
        gt_bboxes = mx.nd.array(bbox[:, :4])
        cls_target, box_target, box_mask = self._target_generator(
            gt_bboxes, anchor, img.shape[2], img.shape[1])
        if self._multi_stage:
            oshapes = []
            anchor_targets = []
            for feat_sym in self._feat_sym:
                oshapes.append(feat_sym.infer_shape(data=(1, 3, img.shape[1], img.shape[2]))[1][0])
            for anchor, oshape in zip(self._anchors, oshapes):
                anchor = anchor[:, :, :oshape[2], :oshape[3], :].reshape((-1, 4))
                anchor_targets.append(anchor)
            anchor_targets = mx.nd.concat(*anchor_targets, dim=0)
            cls_target, box_target, box_mask = self._target_generator(
                gt_bboxes, anchor_targets, img.shape[2], img.shape[1])
        else:
            oshape = self._feat_sym.infer_shape(data=(1, 3, img.shape[1], img.shape[2]))[1][0]
            anchor = self._anchors[:, :, :oshape[2], :oshape[3], :].reshape((-1, 4))
            cls_target, box_target, box_mask = self._target_generator(
                gt_bboxes, anchor, img.shape[2], img.shape[1])
        return img, bbox.astype(img.dtype), cls_target, box_target, box_mask
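
In the multi-stage branch above, each pre-generated per-level anchor grid is cropped to the actual feature shape and the levels are concatenated, so a single RPN target generator matches ground truth against anchors from every FPN level at once. A minimal sketch with hypothetical shapes:

    import mxnet as mx

    # hypothetical pre-generated grids for two FPN levels, 3 anchors per cell
    anchors = [mx.nd.zeros((1, 1, 128, 128, 12)), mx.nd.zeros((1, 1, 64, 64, 12))]
    # hypothetical (N, C, H, W) shapes returned by infer_shape for each level
    oshapes = [(1, 256, 100, 120), (1, 256, 50, 60)]

    levels = []
    for anchor, oshape in zip(anchors, oshapes):
        # crop each grid to the real feature map, then flatten to (num_anchors, 4)
        levels.append(anchor[:, :, :oshape[2], :oshape[3], :].reshape((-1, 4)))
    all_anchors = mx.nd.concat(*levels, dim=0)  # (100*120*3 + 50*60*3, 4)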


@@ -205,6 +240,7 @@ class FasterRCNNDefaultValTransform(object):
        Standard deviation to be divided from image. Default is [0.229, 0.224, 0.225].
    """

    def __init__(self, short=600, max_size=1000,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self._mean = mean
@@ -231,8 +267,9 @@ class MaskRCNNDefaultTrainTransform(object):
    Parameters
    ----------
    short : int, default is 600
    short : int/tuple, default is 600
        Resize image shorter side to ``short``.
        Resize the shorter side of the image randomly within the given range, if it is a tuple.
    max_size : int, default is 1000
        Make sure image longer side is smaller than ``max_size``.
    net : mxnet.gluon.HybridBlock, optional
@@ -261,27 +298,39 @@ class MaskRCNNDefaultTrainTransform(object):
    pos_ratio : float, default is 0.5
        ``pos_ratio`` defines how many positive samples (``pos_ratio * num_sample``) are
        to be sampled.
    ashape : int, default is 128
        Defines the shape of pre-generated anchors for target generation.
    multi_stage : boolean, default is False
        Whether the network outputs multi-stage features.
    """

    def __init__(self, short=600, max_size=1000, net=None, mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225), box_norm=(1., 1., 1., 1.),
                 num_sample=256, pos_iou_thresh=0.7, neg_iou_thresh=0.3,
                 pos_ratio=0.5, **kwargs):
                 pos_ratio=0.5, ashape=128, multi_stage=False, **kwargs):
        self._short = short
        self._max_size = max_size
        self._mean = mean
        self._std = std
        self._anchors = None
        self._multi_stage = multi_stage
        self._random_resize = isinstance(self._short, (tuple, list))
        if net is None:
            return

        # use fake data to generate fixed anchors for target generation
        ashape = 128
        anchors = []  # [P2, P3, P4, P5]
        # in case network has reset_ctx to gpu
        anchor_generator = copy.deepcopy(net.rpn.anchor_generator)
        anchor_generator.collect_params().reset_ctx(None)
        anchors = anchor_generator(
            mx.nd.zeros((1, 3, ashape, ashape))).reshape((1, 1, ashape, ashape, -1))
        if self._multi_stage:
            for ag in anchor_generator:
                anchor = ag(mx.nd.zeros((1, 3, ashape, ashape))).reshape((1, 1, ashape, ashape, -1))
                ashape = max(ashape // 2, 16)
                anchors.append(anchor)
        else:
            anchors = anchor_generator(
                mx.nd.zeros((1, 3, ashape, ashape))).reshape((1, 1, ashape, ashape, -1))
        self._anchors = anchors
        # record feature extractor for infer_shape
        if not hasattr(net, 'features'):
@@ -297,7 +346,11 @@ def __call__(self, src, label, segm):
"""Apply transform to training image/label."""
# resize shorter side but keep in max_size
h, w, _ = src.shape
img = timage.resize_short_within(src, self._short, self._max_size, interp=1)
if self._random_resize:
short = randint(self._short[0], self._short[1])
else:
short = self._short
img = timage.resize_short_within(src, short, self._max_size, interp=1)
bbox = tbbox.resize(label, (w, h), (img.shape[1], img.shape[0]))
segm = [tmask.resize(polys, (w, h), (img.shape[1], img.shape[0])) for polys in segm]

Expand All @@ -321,11 +374,24 @@ def __call__(self, src, label, segm):

        # generate RPN target so cpu workers can help reduce the workload
        # feat_h, feat_w = (img.shape[1] // self._stride, img.shape[2] // self._stride)
        oshape = self._feat_sym.infer_shape(data=(1, 3, img.shape[1], img.shape[2]))[1][0]
        anchor = self._anchors[:, :, :oshape[2], :oshape[3], :].reshape((-1, 4))
        gt_bboxes = mx.nd.array(bbox[:, :4])
        cls_target, box_target, box_mask = self._target_generator(
            gt_bboxes, anchor, img.shape[2], img.shape[1])
        if self._multi_stage:
            oshapes = []
            anchor_targets = []
            for feat_sym in self._feat_sym:
                oshapes.append(feat_sym.infer_shape(data=(1, 3, img.shape[1], img.shape[2]))[1][0])
            for anchor, oshape in zip(self._anchors, oshapes):
                anchor = anchor[:, :, :oshape[2], :oshape[3], :].reshape((-1, 4))
                anchor_targets.append(anchor)
            anchor_targets = mx.nd.concat(*anchor_targets, dim=0)
            cls_target, box_target, box_mask = self._target_generator(
                gt_bboxes, anchor_targets, img.shape[2], img.shape[1])
        else:
            oshape = self._feat_sym.infer_shape(data=(1, 3, img.shape[1], img.shape[2]))[1][0]
            anchor = self._anchors[:, :, :oshape[2], :oshape[3], :].reshape((-1, 4))

            cls_target, box_target, box_mask = self._target_generator(
                gt_bboxes, anchor, img.shape[2], img.shape[1])
        return img, bbox.astype(img.dtype), masks, cls_target, box_target, box_mask


@@ -344,6 +410,7 @@ class MaskRCNNDefaultValTransform(object):
        Standard deviation to be divided from image. Default is [0.229, 0.224, 0.225].
    """

    def __init__(self, short=600, max_size=1000,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self._mean = mean
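
Putting the new transform options together, a hedged usage sketch (the model name is assumed from this commit's model zoo additions):

    from gluoncv import model_zoo
    from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultTrainTransform

    # assumed name added by this commit; any FPN-based detector would do
    net = model_zoo.get_model('faster_rcnn_fpn_resnet50_v1b_coco', pretrained_base=False)

    # short=(600, 800): the shorter side is jittered per image within the range;
    # multi_stage=True: RPN targets are generated per FPN level as shown above
    transform = FasterRCNNDefaultTrainTransform(short=(600, 800), max_size=1000,
                                                net=net, multi_stage=True)
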
1 change: 1 addition & 0 deletions gluoncv/model_zoo/faster_rcnn/__init__.py
@@ -3,3 +3,4 @@
from __future__ import absolute_import

from .faster_rcnn import *
from .rcnn_target import RCNNTargetGenerator, RCNNTargetSampler
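
The new re-export makes the RCNN target utilities importable from the package, which the FPN models can then reuse; a one-line usage sketch, assuming the export above:

    from gluoncv.model_zoo.faster_rcnn import RCNNTargetGenerator, RCNNTargetSampler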