
Commit

Modify the ssd meta arch to allow the option of not adding an implicit background class.

PiperOrigin-RevId: 192529600
pkulzc committed Apr 13, 2018
1 parent a60dd98 commit eccae44
Showing 5 changed files with 100 additions and 21 deletions.
22 changes: 16 additions & 6 deletions research/object_detection/builders/model_builder.py
@@ -71,15 +71,19 @@
}


-def build(model_config, is_training, add_summaries=True):
+def build(model_config, is_training, add_summaries=True,
+          add_background_class=True):
"""Builds a DetectionModel based on the model config.
Args:
model_config: A model.proto object containing the config for the desired
DetectionModel.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tensorflow summaries in the model graph.
add_background_class: Whether to add an implicit background class to one-hot
encodings of groundtruth labels. Set to false if using groundtruth labels
with an explicit background class or using multiclass scores instead of
truth in the case of distillation. Ignored in the case of faster_rcnn.
Returns:
DetectionModel based on the config.
@@ -90,7 +94,8 @@ def build(model_config, is_training, add_summaries=True):
raise ValueError('model_config not of type model_pb2.DetectionModel.')
meta_architecture = model_config.WhichOneof('model')
if meta_architecture == 'ssd':
-    return _build_ssd_model(model_config.ssd, is_training, add_summaries)
+    return _build_ssd_model(model_config.ssd, is_training, add_summaries,
+                            add_background_class)
if meta_architecture == 'faster_rcnn':
return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
add_summaries)
@@ -133,15 +138,19 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
override_base_feature_extractor_hyperparams)


-def _build_ssd_model(ssd_config, is_training, add_summaries):
+def _build_ssd_model(ssd_config, is_training, add_summaries,
+                     add_background_class=True):
"""Builds an SSD detection model based on the model config.
Args:
ssd_config: A ssd.proto object containing the config for the desired
SSDMetaArch.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
add_background_class: Whether to add an implicit background class to one-hot
encodings of groundtruth labels. Set to false if using groundtruth labels
with an explicit background class or using multiclass scores instead of
truth in the case of distillation.
Returns:
SSDMetaArch based on the config.
@@ -198,7 +207,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
add_summaries=add_summaries,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=ssd_config.freeze_batchnorm,
-      inplace_batchnorm_update=ssd_config.inplace_batchnorm_update)
+      inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
+      add_background_class=add_background_class)


def _build_faster_rcnn_feature_extractor(
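For users of the builder API, the new flag is simply threaded through build(). A minimal sketch of how it might be called, assuming a pipeline config loaded with object_detection.utils.config_util; the config path and the choice to disable the implicit background class are illustrative, not part of this commit:

from object_detection.builders import model_builder
from object_detection.utils import config_util

# Load a pipeline config; 'pipeline.config' is a placeholder path.
configs = config_util.get_configs_from_pipeline_file('pipeline.config')

# Build the SSD model without the implicit background class, e.g. when the
# groundtruth one-hot labels already contain an explicit background column.
detection_model = model_builder.build(
    model_config=configs['model'],
    is_training=True,
    add_background_class=False)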
2 changes: 1 addition & 1 deletion research/object_detection/core/model.py
@@ -69,7 +69,7 @@ def __init__(self, num_classes):
Args:
num_classes: number of classes. Note that num_classes *does not* include
-        background categories that might be implicitly be predicted in various
+        background categories that might be implicitly predicted in various
implementations.
"""
self._num_classes = num_classes
20 changes: 15 additions & 5 deletions research/object_detection/meta_architectures/ssd_meta_arch.py
@@ -138,7 +138,8 @@ def __init__(self,
add_summaries=True,
normalize_loc_loss_by_codesize=False,
freeze_batchnorm=False,
-               inplace_batchnorm_update=False):
+               inplace_batchnorm_update=False,
+               add_background_class=True):
"""SSDMetaArch Constructor.
TODO(rathodv,jonathanhuang): group NMS parameters + score converter into
@@ -193,6 +194,10 @@ def __init__(self,
values inplace. When this is false train op must add a control
dependency on tf.graphkeys.UPDATE_OPS collection in order to update
batch norm statistics.
add_background_class: Whether to add an implicit background class to
one-hot encodings of groundtruth labels. Set to false if using
groundtruth labels with an explicit background class or using multiclass
scores instead of truth in the case of distillation.
"""
super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes)
self._is_training = is_training
@@ -210,6 +215,7 @@ def __init__(self,
self._feature_extractor = feature_extractor
self._matcher = matcher
self._region_similarity_calculator = region_similarity_calculator
self._add_background_class = add_background_class

# TODO(jonathanhuang): handle agnostic mode
# weights
@@ -636,10 +642,14 @@ def _assign_targets(self, groundtruth_boxes_list, groundtruth_classes_list,
groundtruth_boxlists = [
box_list.BoxList(boxes) for boxes in groundtruth_boxes_list
]
-    groundtruth_classes_with_background_list = [
-        tf.pad(one_hot_encoding, [[0, 0], [1, 0]], mode='CONSTANT')
-        for one_hot_encoding in groundtruth_classes_list
-    ]
+    if self._add_background_class:
+      groundtruth_classes_with_background_list = [
+          tf.pad(one_hot_encoding, [[0, 0], [1, 0]], mode='CONSTANT')
+          for one_hot_encoding in groundtruth_classes_list
+      ]
+    else:
+      groundtruth_classes_with_background_list = groundtruth_classes_list

if groundtruth_keypoints_list is not None:
for boxlist, keypoints in zip(
groundtruth_boxlists, groundtruth_keypoints_list):
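The _assign_targets change above is the core of the feature: when add_background_class is True, each one-hot groundtruth tensor is padded with a leading all-zero background column before target assignment; when it is False, the labels are used as-is and are expected to already have num_classes + 1 columns (or to carry multiclass scores). A small standalone sketch of the padding step, written in the TF 1.x session style this code base uses:

import tensorflow as tf

# One box labelled as the second of three foreground classes.
one_hot = tf.constant([[0., 1., 0.]])

# What _assign_targets does when add_background_class=True: prepend a
# zero-valued background column along the class axis.
with_background = tf.pad(one_hot, [[0, 0], [1, 0]], mode='CONSTANT')

with tf.Session() as sess:
  print(sess.run(with_background))  # background column prepended: [0., 0., 1., 0.]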
72 changes: 63 additions & 9 deletions research/object_detection/meta_architectures/ssd_meta_arch_test.py
@@ -80,8 +80,10 @@ def _get_value_for_matching_key(dictionary, suffix):

class SsdMetaArchTest(test_case.TestCase):

-  def _create_model(self, apply_hard_mining=True,
-                    normalize_loc_loss_by_codesize=False):
+  def _create_model(self,
+                    apply_hard_mining=True,
+                    normalize_loc_loss_by_codesize=False,
+                    add_background_class=True):
is_training = False
num_classes = 1
mock_anchor_generator = MockAnchorGenerator2x2()
@@ -117,14 +119,29 @@ def image_resizer_fn(image):

code_size = 4
model = ssd_meta_arch.SSDMetaArch(
-        is_training, mock_anchor_generator, mock_box_predictor, mock_box_coder,
-        fake_feature_extractor, mock_matcher, region_similarity_calculator,
-        encode_background_as_zeros, negative_class_weight, image_resizer_fn,
-        non_max_suppression_fn, tf.identity, classification_loss,
-        localization_loss, classification_loss_weight, localization_loss_weight,
-        normalize_loss_by_num_matches, hard_example_miner, add_summaries=False,
+        is_training,
+        mock_anchor_generator,
+        mock_box_predictor,
+        mock_box_coder,
+        fake_feature_extractor,
+        mock_matcher,
+        region_similarity_calculator,
+        encode_background_as_zeros,
+        negative_class_weight,
+        image_resizer_fn,
+        non_max_suppression_fn,
+        tf.identity,
+        classification_loss,
+        localization_loss,
+        classification_loss_weight,
+        localization_loss_weight,
+        normalize_loss_by_num_matches,
+        hard_example_miner,
+        add_summaries=False,
        normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
-        freeze_batchnorm=False, inplace_batchnorm_update=False)
+        freeze_batchnorm=False,
+        inplace_batchnorm_update=False,
+        add_background_class=add_background_class)
return model, num_classes, mock_anchor_generator.num_anchors(), code_size

def test_preprocess_preserves_shapes_with_dynamic_input_image(self):
@@ -365,6 +382,43 @@ def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2,
self.assertAllClose(localization_loss, expected_localization_loss)
self.assertAllClose(classification_loss, expected_classification_loss)

def test_loss_results_are_correct_without_add_background_class(self):

with tf.Graph().as_default():
_, num_classes, num_anchors, _ = self._create_model(
add_background_class=False)

def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2):
groundtruth_boxes_list = [groundtruth_boxes1, groundtruth_boxes2]
groundtruth_classes_list = [groundtruth_classes1, groundtruth_classes2]
model, _, _, _ = self._create_model(
apply_hard_mining=False, add_background_class=False)
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
prediction_dict = model.predict(
preprocessed_tensor, true_image_shapes=None)
loss_dict = model.loss(prediction_dict, true_image_shapes=None)
return (loss_dict['Loss/localization_loss'],
loss_dict['Loss/classification_loss'])

batch_size = 2
preprocessed_input = np.random.rand(batch_size, 2, 2, 3).astype(np.float32)
groundtruth_boxes1 = np.array([[0, 0, .5, .5]], dtype=np.float32)
groundtruth_boxes2 = np.array([[0, 0, .5, .5]], dtype=np.float32)
groundtruth_classes1 = np.array([[0, 1]], dtype=np.float32)
groundtruth_classes2 = np.array([[0, 1]], dtype=np.float32)
expected_localization_loss = 0.0
expected_classification_loss = (
batch_size * num_anchors * (num_classes + 1) * np.log(2.0))
(localization_loss, classification_loss) = self.execute(
graph_fn, [
preprocessed_input, groundtruth_boxes1, groundtruth_boxes2,
groundtruth_classes1, groundtruth_classes2
])
self.assertAllClose(localization_loss, expected_localization_loss)
self.assertAllClose(classification_loss, expected_classification_loss)

def test_restore_map_for_detection_ckpt(self):
model, _, _, _ = self._create_model()
model.predict(tf.constant(np.array([[[0, 0], [1, 1]], [[1, 0], [0, 1]]],
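As a sanity check on the new test above: the closed-form expectation batch_size * num_anchors * (num_classes + 1) * log(2) is consistent with every anchor contributing a sigmoid cross-entropy of log(2), i.e. a zero logit, for each of the num_classes + 1 label columns, where the background column now comes from the groundtruth itself rather than from padding. A rough numeric sketch, assuming the 2x2 mock anchor generator yields 4 anchors:

import numpy as np

batch_size = 2    # two images in the test batch
num_anchors = 4   # assumed output of MockAnchorGenerator2x2
num_classes = 1   # one foreground class; +1 for the explicit background column

expected_classification_loss = (
    batch_size * num_anchors * (num_classes + 1) * np.log(2.0))
print(expected_classification_loss)  # roughly 11.09, i.e. 16 terms of log(2)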
5 changes: 5 additions & 0 deletions research/object_detection/protos/train.proto
@@ -6,6 +6,7 @@ import "object_detection/protos/optimizer.proto";
import "object_detection/protos/preprocessor.proto";

// Message for configuring DetectionModel training jobs (train.py).
// Next id: 25
message TrainConfig {
// Effective batch size to use for training.
// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
@@ -80,6 +81,10 @@ message TrainConfig {
// Note that only Sigmoid classification losses should be used.
optional bool merge_multiple_label_boxes = 17 [default=false];

// If true, will use multiclass scores from object annotations as ground
// truth. Currently only compatible with annotated image inputs.
optional bool use_multiclass_scores = 24 [default = false];

// Whether to add regularization loss to `total_loss`. This is true by
// default and adds all regularization losses defined in the model to
// `total_loss`.
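The new use_multiclass_scores option pairs naturally with add_background_class=False on the model side: when groundtruth arrives as per-class scores rather than one-hot labels, the implicit background padding no longer applies. A minimal sketch of setting the field through the generated Python proto module, assuming train_pb2 (the module protoc emits for train.proto) is importable; in a pipeline config text file the same field would read use_multiclass_scores: true inside the train_config block:

from object_detection.protos import train_pb2

# Enable multiclass-score groundtruth on a TrainConfig message.
train_config = train_pb2.TrainConfig()
train_config.use_multiclass_scores = True

print(train_config)  # prints: use_multiclass_scores: true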
