Skip to content

Commit

Permalink
add NASnet feature extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
tombstone committed Oct 29, 2017
1 parent c839310 commit 3237c08
Show file tree
Hide file tree
Showing 5 changed files with 501 additions and 0 deletions.
2 changes: 2 additions & 0 deletions research/object_detection/builders/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@

# A map of names to Faster R-CNN feature extractors.
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
'faster_rcnn_nas':
frcnn_nas.FasterRCNNNASFeatureExtractor,
'faster_rcnn_inception_resnet_v2':
frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
'faster_rcnn_inception_v2':
Expand Down
68 changes: 68 additions & 0 deletions research/object_detection/builders/model_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
Expand Down Expand Up @@ -412,6 +413,73 @@ def test_create_faster_rcnn_resnet101_with_mask_prediction_enabled(self):
model = model_builder.build(model_proto, is_training=True)
self.assertAlmostEqual(model._second_stage_mask_loss_weight, 3.0)

def test_create_faster_rcnn_nas_model_from_config(self):
model_text_proto = """
faster_rcnn {
num_classes: 3
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 600
max_dimension: 1024
}
}
feature_extractor {
type: 'faster_rcnn_nas'
}
first_stage_anchor_generator {
grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0]
aspect_ratios: [0.5, 1.0, 2.0]
height_stride: 16
width_stride: 16
}
}
first_stage_box_predictor_conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
initial_crop_size: 17
maxpool_kernel_size: 1
maxpool_stride: 1
second_stage_box_predictor {
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
second_stage_post_processing {
batch_non_max_suppression {
score_threshold: 0.01
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
score_converter: SOFTMAX
}
}"""
model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto)
model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
self.assertIsInstance(
model._feature_extractor,
frcnn_nas.FasterRCNNNASFeatureExtractor)

def test_create_faster_rcnn_inception_resnet_v2_model_from_config(self):
model_text_proto = """
faster_rcnn {
Expand Down
23 changes: 23 additions & 0 deletions research/object_detection/models/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,29 @@ py_test(
],
)

py_test(
name = "faster_rcnn_nas_feature_extractor_test",
srcs = [
"faster_rcnn_nas_feature_extractor_test.py",
],
deps = [
":faster_rcnn_nas_feature_extractor",
"//tensorflow",
],
)

py_library(
name = "faster_rcnn_nas_feature_extractor",
srcs = [
"faster_rcnn_nas_feature_extractor.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow_models/slim:nasnet",
],
)

py_library(
name = "faster_rcnn_inception_resnet_v2_feature_extractor",
srcs = [
Expand Down
Loading

0 comments on commit 3237c08

Please sign in to comment.