Skip to content

Commit 205c5e0

Browse files
author
Jonathan Huang
authored
Merge pull request tensorflow#2725 from tombstone/fix_difficult_list_eval
Fix tensorflow#2713
2 parents f88def2 + 9742585 commit 205c5e0

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

research/object_detection/utils/object_detection_evaluation.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,26 @@ def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
166166
groundtruth_classes = groundtruth_dict[
167167
standard_fields.InputDataFields.groundtruth_classes]
168168
groundtruth_classes -= self._label_id_offset
169+
# If the key is not present in the groundtruth_dict or the array is empty
170+
# (unless there are no annotations for the groundtruth on this image)
171+
# use values from the dictionary or insert None otherwise.
172+
if (standard_fields.InputDataFields.groundtruth_difficult in
173+
groundtruth_dict.keys() and
174+
(groundtruth_dict[standard_fields.InputDataFields.groundtruth_difficult]
175+
.size or not groundtruth_classes.size)):
176+
groundtruth_difficult = groundtruth_dict[
177+
standard_fields.InputDataFields.groundtruth_difficult]
178+
else:
179+
groundtruth_difficult = None
180+
if not len(self._image_ids) % 1000:
181+
logging.warn(
182+
'image %s does not have groundtruth difficult flag specified',
183+
image_id)
169184
self._evaluation.add_single_ground_truth_image_info(
170185
image_id,
171186
groundtruth_dict[standard_fields.InputDataFields.groundtruth_boxes],
172187
groundtruth_classes,
173-
groundtruth_dict.get(
174-
standard_fields.InputDataFields.groundtruth_difficult, None))
188+
groundtruth_is_difficult_list=groundtruth_difficult)
175189
self._image_ids.update([image_id])
176190

177191
def add_single_detected_image_info(self, image_id, detections_dict):
@@ -337,14 +351,27 @@ def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
337351
groundtruth_classes = groundtruth_dict[
338352
standard_fields.InputDataFields.groundtruth_classes]
339353
groundtruth_classes -= self._label_id_offset
340-
354+
# If the key is not present in the groundtruth_dict or the array is empty
355+
# (unless there are no annotations for the groundtruth on this image)
356+
# use values from the dictionary or insert None otherwise.
357+
if (standard_fields.InputDataFields.groundtruth_group_of in
358+
groundtruth_dict.keys() and
359+
(groundtruth_dict[standard_fields.InputDataFields.groundtruth_group_of]
360+
.size or not groundtruth_classes.size)):
361+
groundtruth_group_of = groundtruth_dict[
362+
standard_fields.InputDataFields.groundtruth_group_of]
363+
else:
364+
groundtruth_group_of = None
365+
if not len(self._image_ids) % 1000:
366+
logging.warn(
367+
'image %s does not have groundtruth group_of flag specified',
368+
image_id)
341369
self._evaluation.add_single_ground_truth_image_info(
342370
image_id,
343371
groundtruth_dict[standard_fields.InputDataFields.groundtruth_boxes],
344372
groundtruth_classes,
345373
groundtruth_is_difficult_list=None,
346-
groundtruth_is_group_of_list=groundtruth_dict.get(
347-
standard_fields.InputDataFields.groundtruth_group_of, None))
374+
groundtruth_is_group_of_list=groundtruth_group_of)
348375
self._image_ids.update([image_id])
349376

350377

research/object_detection/utils/object_detection_evaluation_test.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def test_returns_correct_metric_values(self):
4646
standard_fields.InputDataFields.groundtruth_boxes:
4747
groundtruth_boxes1,
4848
standard_fields.InputDataFields.groundtruth_classes:
49-
groundtruth_class_labels1
49+
groundtruth_class_labels1,
50+
standard_fields.InputDataFields.groundtruth_group_of:
51+
np.array([], dtype=bool)
5052
})
5153
image_key2 = 'img2'
5254
groundtruth_boxes2 = np.array(
@@ -115,7 +117,9 @@ def test_returns_correct_metric_values(self):
115117
image_key1,
116118
{standard_fields.InputDataFields.groundtruth_boxes: groundtruth_boxes1,
117119
standard_fields.InputDataFields.groundtruth_classes:
118-
groundtruth_class_labels1})
120+
groundtruth_class_labels1,
121+
standard_fields.InputDataFields.groundtruth_difficult:
122+
np.array([], dtype=bool)})
119123
image_key2 = 'img2'
120124
groundtruth_boxes2 = np.array([[10, 10, 11, 11], [500, 500, 510, 510],
121125
[10, 10, 12, 12]], dtype=float)

0 commit comments

Comments
 (0)