Merge pull request tensorflow#3846 from pkulzc/master
Internal changes for object detection
pkulzc authored Apr 3, 2018
2 parents c3b2660 + 143464d commit abd5042
Showing 40 changed files with 1,021 additions and 240 deletions.
11 changes: 11 additions & 0 deletions research/object_detection/README.md
@@ -29,6 +29,7 @@ https://scholar.googleusercontent.com/scholar.bib?q=info:l291WsrB-hQJ:scholar.go

* Jonathan Huang, github: [jch1](https://github.com/jch1)
* Vivek Rathod, github: [tombstone](https://github.com/tombstone)
* Ronny Votel, github: [ronnyvotel](https://github.com/ronnyvotel)
* Derek Chow, github: [derekjchow](https://github.com/derekjchow)
* Chen Sun, github: [jesu9](https://github.com/jesu9)
* Menglong Zhu, github: [dreamdragon](https://github.com/dreamdragon)
@@ -89,6 +90,16 @@ reporting an issue.

## Release information

### April 2, 2018

Supercharge your mobile phones with the next-generation mobile object detector!
We are adding support for MobileNet V2 with SSDLite, presented in
[MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381).
This model is 35% faster than MobileNet V1 SSD on a Google Pixel phone CPU (200ms vs. 270ms) at the same accuracy.
Along with the model definition, we are also releasing a model checkpoint trained on the COCO dataset.

<b>Thanks to contributors</b>: Menglong Zhu, Mark Sandler, Zhichao Lu, Vivek Rathod, Jonathan Huang

### February 9, 2018

We now support instance segmentation!! In this API update we support a number of instance segmentation models similar to those discussed in the [Mask R-CNN paper](https://arxiv.org/abs/1703.06870). For further details refer to
35 changes: 29 additions & 6 deletions research/object_detection/builders/model_builder.py
@@ -30,6 +30,7 @@
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
@@ -55,6 +56,8 @@
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
'faster_rcnn_nas':
frcnn_nas.FasterRCNNNASFeatureExtractor,
'faster_rcnn_pnas':
frcnn_pnas.FasterRCNNPNASFeatureExtractor,
'faster_rcnn_inception_resnet_v2':
frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
'faster_rcnn_inception_v2':
@@ -95,13 +98,19 @@ def build(model_config, is_training, add_summaries=True):


def _build_ssd_feature_extractor(feature_extractor_config, is_training,
reuse_weights=None):
reuse_weights=None,
inplace_batchnorm_update=False):
"""Builds a ssd_meta_arch.SSDFeatureExtractor based on config.
Args:
feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
is_training: True if this feature extractor is being built for training.
reuse_weights: if the feature extractor should reuse weights.
inplace_batchnorm_update: Whether to update batch_norm inplace during
training. This is required for batch norm to work correctly on TPUs. When
this is false, the user must add a control dependency on
tf.GraphKeys.UPDATE_OPS for the train/loss op in order to update the batch
norm moving average parameters.
Returns:
ssd_meta_arch.SSDFeatureExtractor based on config.
@@ -126,7 +135,8 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
return feature_extractor_class(is_training, depth_multiplier, min_depth,
pad_to_multiple, conv_hyperparams,
batch_norm_trainable, reuse_weights,
use_explicit_padding, use_depthwise)
use_explicit_padding, use_depthwise,
inplace_batchnorm_update)


def _build_ssd_model(ssd_config, is_training, add_summaries):
@@ -140,15 +150,18 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
Returns:
SSDMetaArch based on the config.
Raises:
ValueError: If ssd_config.type is not recognized (i.e. not registered in
model_class_map).
"""
num_classes = ssd_config.num_classes

# Feature extractor
feature_extractor = _build_ssd_feature_extractor(ssd_config.feature_extractor,
is_training)
feature_extractor = _build_ssd_feature_extractor(
feature_extractor_config=ssd_config.feature_extractor,
is_training=is_training,
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update)

box_coder = box_coder_builder.build(ssd_config.box_coder)
matcher = matcher_builder.build(ssd_config.matcher)
@@ -194,21 +207,29 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):


def _build_faster_rcnn_feature_extractor(
feature_extractor_config, is_training, reuse_weights=None):
feature_extractor_config, is_training, reuse_weights=None,
inplace_batchnorm_update=False):
"""Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.
Args:
feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
faster_rcnn.proto.
is_training: True if this feature extractor is being built for training.
reuse_weights: if the feature extractor should reuse weights.
inplace_batchnorm_update: Whether to update batch_norm inplace during
training. This is required for batch norm to work correctly on TPUs. When
this is false, the user must add a control dependency on
tf.GraphKeys.UPDATE_OPS for the train/loss op in order to update the batch
norm moving average parameters.
Returns:
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.
Raises:
ValueError: On invalid feature extractor type.
"""
if inplace_batchnorm_update:
raise ValueError('inplace batchnorm updates not supported.')
feature_type = feature_extractor_config.type
first_stage_features_stride = (
feature_extractor_config.first_stage_features_stride)
@@ -238,6 +259,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
Returns:
FasterRCNNMetaArch based on the config.
Raises:
ValueError: If frcnn_config.type is not recognized (i.e. not registered in
model_class_map).
@@ -246,7 +268,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

feature_extractor = _build_faster_rcnn_feature_extractor(
frcnn_config.feature_extractor, is_training)
frcnn_config.feature_extractor, is_training,
frcnn_config.inplace_batchnorm_update)

number_of_stages = frcnn_config.number_of_stages
first_stage_anchor_generator = anchor_generator_builder.build(
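A note on the new `inplace_batchnorm_update` flag threaded through the builders above: when it is left at its default of `False`, the batch norm moving averages are only refreshed if the training op depends on `tf.GraphKeys.UPDATE_OPS`, as the docstring describes. A minimal sketch of that pattern (the helper name `make_train_op` is illustrative, not part of this change):

```python
import tensorflow as tf

def make_train_op(total_loss, optimizer):
  # Collect the moving-average update ops that batch norm layers add to the
  # graph, and make the minimize step depend on them so they run every step.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss)
  return train_op
```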
70 changes: 70 additions & 0 deletions research/object_detection/builders/model_builder_test.py
@@ -25,6 +25,7 @@
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
@@ -297,6 +298,7 @@ def test_create_ssd_resnet_v1_fpn_model_from_config(self):
def test_create_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """
ssd {
inplace_batchnorm_update: true
feature_extractor {
type: 'ssd_mobilenet_v1'
conv_hyperparams {
@@ -519,6 +521,7 @@ def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
def test_create_faster_rcnn_resnet_v1_models_from_config(self):
model_text_proto = """
faster_rcnn {
inplace_batchnorm_update: true
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
@@ -726,6 +729,73 @@ def test_create_faster_rcnn_nas_model_from_config(self):
model._feature_extractor,
frcnn_nas.FasterRCNNNASFeatureExtractor)

def test_create_faster_rcnn_pnas_model_from_config(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_pnas'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 17
maxpool_kernel_size: 1
maxpool_stride: 1
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
self.assertIsInstance(
model._feature_extractor,
frcnn_pnas.FasterRCNNPNASFeatureExtractor)

def test_create_faster_rcnn_inception_resnet_v2_model_from_config(self):
model_text_proto = """
faster_rcnn {
11 changes: 8 additions & 3 deletions research/object_detection/core/box_list_ops_test.py
@@ -17,6 +17,7 @@
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops

from object_detection.core import box_list
from object_detection.core import box_list_ops
@@ -509,9 +510,13 @@ def test_sort_by_field_invalid_inputs(self):
with self.assertRaises(ValueError):
box_list_ops.sort_by_field(boxes, 'misc')

with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
'Incorrect field size'):
sess.run(box_list_ops.sort_by_field(boxes, 'weights').get())
if ops._USE_C_API:
with self.assertRaises(ValueError):
box_list_ops.sort_by_field(boxes, 'weights')
else:
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
'Incorrect field size'):
sess.run(box_list_ops.sort_by_field(boxes, 'weights').get())

def test_visualize_boxes_in_image(self):
image = tf.zeros((6, 4, 3))
6 changes: 5 additions & 1 deletion research/object_detection/core/preprocessor.py
@@ -2279,7 +2279,11 @@ def resize_masks_branch():
return new_masks

def reshape_masks_branch():
new_masks = tf.reshape(masks, [0, new_size[0], new_size[1]])
# The shape function will be computed for both branches of the
# condition, regardless of which branch is actually taken. Make sure
# that we don't trigger an assertion in the shape function when trying
# to reshape a non empty tensor into an empty one.
new_masks = tf.reshape(masks, [-1, new_size[0], new_size[1]])
return new_masks

masks = tf.cond(num_instances > 0, resize_masks_branch,
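For context on the `preprocessor.py` fix: `tf.cond` builds (and shape-checks) both branch functions at graph-construction time, so the reshape in the empty-masks branch must not assert against a masks tensor that may in fact be non-empty. A small illustration, with made-up sizes, of why the `-1` leading dimension is safe:

```python
import tensorflow as tf

# An empty masks tensor: zero instances, each nominally 20x20.
empty_masks = tf.zeros([0, 20, 20])

# With -1 the leading dimension is inferred (here it comes out as 0), whereas
# a hard-coded 0 would make the reshape's shape function reject any masks
# tensor that is statically known to contain elements.
reshaped = tf.reshape(empty_masks, [-1, 50, 50])

with tf.Session() as sess:
  print(sess.run(tf.shape(reshaped)))  # [0 50 50]
```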
@@ -64,7 +64,7 @@ cd ${SCRATCH_DIR}
# Download the images.
BASE_IMAGE_URL="http://images.cocodataset.org/zips"

# TRAIN_IMAGE_FILE="train2017.zip"
TRAIN_IMAGE_FILE="train2017.zip"
download_and_unzip ${BASE_IMAGE_URL} ${TRAIN_IMAGE_FILE}
TRAIN_IMAGE_DIR="${SCRATCH_DIR}/train2017"

@@ -91,7 +91,7 @@ download_and_unzip ${BASE_IMAGE_INFO_URL} ${IMAGE_INFO_FILE}

TESTDEV_ANNOTATIONS_FILE="${SCRATCH_DIR}/annotations/image_info_test-dev2017.json"

# # Build TFRecords of the image data.
# Build TFRecords of the image data.
cd "${CURRENT_DIR}"
python object_detection/dataset_tools/create_coco_tf_record.py \
--logtostderr \
4 changes: 3 additions & 1 deletion research/object_detection/eval_util.py
@@ -79,7 +79,7 @@ def visualize_detection_results(result_dict,
data corresponding to each image being evaluated. The following keys
are required:
'original_image': a numpy array representing the image with shape
[1, height, width, 3]
[1, height, width, 3] or [1, height, width, 1]
'detection_boxes': a numpy array of shape [N, 4]
'detection_scores': a numpy array of shape [N]
'detection_classes': a numpy array of shape [N]
@@ -133,6 +133,8 @@ def visualize_detection_results(result_dict,
category_index = label_map_util.create_category_index(categories)

image = np.squeeze(result_dict[input_fields.original_image], axis=0)
if image.shape[2] == 1: # If one channel image, repeat in RGB.
image = np.tile(image, [1, 1, 3])
detection_boxes = result_dict[detection_fields.detection_boxes]
detection_scores = result_dict[detection_fields.detection_scores]
detection_classes = np.int32((result_dict[
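The `eval_util.py` change above handles single-channel inputs by repeating the channel so the visualization code always sees an RGB-shaped array; the numpy behaviour it relies on is simply:

```python
import numpy as np

# Illustrative sizes only: a [height, width, 1] grayscale image becomes a
# [height, width, 3] array by repeating it along the channel axis.
gray = np.zeros((4, 6, 1), dtype=np.uint8)
rgb = np.tile(gray, [1, 1, 3])
assert rgb.shape == (4, 6, 3)
```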
14 changes: 12 additions & 2 deletions research/object_detection/evaluator.py
@@ -94,14 +94,24 @@ def _extract_predictions_and_losses(model,
if fields.InputDataFields.groundtruth_group_of in input_dict:
groundtruth[fields.InputDataFields.groundtruth_group_of] = (
input_dict[fields.InputDataFields.groundtruth_group_of])
groundtruth_masks_list = None
if fields.DetectionResultFields.detection_masks in detections:
groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
input_dict[fields.InputDataFields.groundtruth_instance_masks])
groundtruth_masks_list = [
input_dict[fields.InputDataFields.groundtruth_instance_masks]]
groundtruth_keypoints_list = None
if fields.DetectionResultFields.detection_keypoints in detections:
groundtruth[fields.InputDataFields.groundtruth_keypoints] = (
input_dict[fields.InputDataFields.groundtruth_keypoints])
groundtruth_keypoints_list = [
input_dict[fields.InputDataFields.groundtruth_keypoints]]
label_id_offset = 1
model.provide_groundtruth(
[input_dict[fields.InputDataFields.groundtruth_boxes]],
[tf.one_hot(input_dict[fields.InputDataFields.groundtruth_classes]
- label_id_offset, depth=model.num_classes)])
- label_id_offset, depth=model.num_classes)],
groundtruth_masks_list, groundtruth_keypoints_list)
losses_dict.update(model.loss(prediction_dict, true_image_shapes))

result_dict = eval_util.result_dict_for_single_example(
@@ -205,7 +215,7 @@ def _process_batch(tensor_dict, sess, batch_index, counters,
except tf.errors.InvalidArgumentError:
logging.info('Skipping image')
counters['skipped'] += 1
return {}
return {}, {}
global_step = tf.train.global_step(sess, tf.train.get_global_step())
if batch_index < eval_config.num_visualizations:
tag = 'image-{}'.format(batch_index)
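In `evaluator.py`, groundtruth instance masks and keypoints are now forwarded to the model whenever the corresponding detection fields are present. A sketch of the resulting call, assuming a `DetectionModel` built earlier (e.g. via `model_builder.build`) and the positional-argument order used in the diff; the tensors are illustrative placeholders:

```python
import tensorflow as tf

# `model` is assumed to be an already-built DetectionModel instance.
groundtruth_boxes = tf.constant([[0.1, 0.1, 0.5, 0.5]])   # [N, 4]
groundtruth_classes = tf.constant([[0.0, 1.0, 0.0]])      # [N, num_classes], one-hot
groundtruth_masks = tf.zeros([1, 100, 100])                # [N, height, width]

# Keypoints are passed the same way when detection_keypoints are produced;
# either optional list may be None when the model does not predict that output.
model.provide_groundtruth(
    [groundtruth_boxes],
    [groundtruth_classes],
    [groundtruth_masks],
    None)
```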
5 changes: 4 additions & 1 deletion research/object_detection/g3doc/detection_model_zoo.md
@@ -19,7 +19,9 @@ In the table below, we list each such pre-trained model including:
aware that these timings depend highly on one's specific hardware
configuration (these timings were performed using an Nvidia
GeForce GTX TITAN X card) and should be treated more as relative timings in
many cases.
many cases. Also note that desktop GPU timing does not always reflect mobile
run time. For example, MobileNet V2 is faster on mobile devices than MobileNet
V1, but is slightly slower on a desktop GPU.
* detector performance on subset of the COCO validation set or Open Images test split as measured by the dataset-specific mAP measure.
Here, higher is better, and we only report bounding box mAP rounded to the
nearest integer.
@@ -68,6 +70,7 @@ Some remarks on frozen inference graphs:
| Model name | Speed (ms) | COCO mAP[^1] | Outputs |
| ------------ | :--------------: | :--------------: | :-------------: |
| [ssd_mobilenet_v1_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz) | 30 | 21 | Boxes |
| [ssd_mobilenet_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz) | 31 | 22 | Boxes |
| [ssd_inception_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz) | 42 | 24 | Boxes |
| [faster_rcnn_inception_v2_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz) | 58 | 28 | Boxes |
| [faster_rcnn_resnet50_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz) | 89 | 30 | Boxes |
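For readers landing on the model zoo table above: each linked tarball contains a frozen inference graph. A minimal loading sketch (the path below is illustrative and assumes the `ssd_mobilenet_v2_coco` archive has been downloaded and extracted):

```python
import tensorflow as tf

# Illustrative path into an extracted model archive.
PATH_TO_FROZEN_GRAPH = 'ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb'

detection_graph = tf.Graph()
with detection_graph.as_default():
  graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
    graph_def.ParseFromString(fid.read())
  # Import the frozen weights and ops into the default graph.
  tf.import_graph_def(graph_def, name='')
```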
2 changes: 1 addition & 1 deletion research/object_detection/g3doc/running_pets.md
@@ -37,7 +37,7 @@ environment variable below:
export YOUR_GCS_BUCKET=${YOUR_GCS_BUCKET}
```

It is also possible to run locally by following
It is also possible to run locally by following
[the running locally instructions](running_locally.md).

## Installing Tensorflow and the Tensorflow Object Detection API