Merge pull request google#223 from google/sync
Sync
mingxingtan authored Apr 13, 2020
2 parents 0eeee67 + 523778f commit e596888
Showing 7 changed files with 159 additions and 113 deletions.
40 changes: 23 additions & 17 deletions efficientdet/anchors.py
@@ -184,7 +184,7 @@ def _generate_anchor_configs(feat_sizes, min_level, max_level, num_scales,
   A configuration is a tuple of (num_anchors, scale, aspect_ratio).

   Args:
-    feat_sizes: list of integer numbers of feature map sizes.
+    feat_sizes: list of dict of integer numbers of feature map sizes.
     min_level: integer number of minimum level of the output feature pyramid.
     max_level: integer number of maximum level of the output feature pyramid.
     num_scales: integer number representing intermediate scales added
@@ -203,17 +203,18 @@ def _generate_anchor_configs(feat_sizes, min_level, max_level, num_scales,
     anchor_configs[level] = []
     for scale_octave in range(num_scales):
       for aspect in aspect_ratios:
-        anchor_configs[level].append((feat_sizes[0] / float(feat_sizes[level]),
-                                      scale_octave / float(num_scales), aspect))
+        anchor_configs[level].append(
+            ((feat_sizes[0]['height'] / float(feat_sizes[level]['height']),
+              feat_sizes[0]['width'] / float(feat_sizes[level]['width'])),
+             scale_octave / float(num_scales), aspect))
   return anchor_configs
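
For a concrete sense of what the new configs hold, here is a small sketch (values assumed, not from the commit): feat_sizes[0] is the input resolution and each level's entry is its feature-map size, so the anchor stride becomes a (stride_y, stride_x) pair instead of a single number.

    # Assumed feature sizes for a hypothetical 512x640 input, halved per level.
    feat_sizes = [{'height': 512, 'width': 640},
                  {'height': 256, 'width': 320},
                  {'height': 128, 'width': 160},
                  {'height': 64, 'width': 80}]
    level = 3
    stride = (feat_sizes[0]['height'] / float(feat_sizes[level]['height']),
              feat_sizes[0]['width'] / float(feat_sizes[level]['width']))
    print(stride)  # (8.0, 8.0) here; the entries differ when sizes don't divide evenly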


 def _generate_anchor_boxes(image_size, anchor_scale, anchor_configs):
   """Generates multiscale anchor boxes.

   Args:
-    image_size: integer number of input image size. The input image has the
-      same dimension for width and height.
+    image_size: tuple of integer numbers of input image size.
     anchor_scale: float number representing the scale of size of the base
       anchor to the feature stride 2^level.
     anchor_configs: a dictionary with keys as the levels of anchors and
@@ -230,12 +231,13 @@ def _generate_anchor_boxes(image_size, anchor_scale, anchor_configs):
     boxes_level = []
     for config in configs:
       stride, octave_scale, aspect = config
-      base_anchor_size = anchor_scale * stride * 2**octave_scale
-      anchor_size_x_2 = base_anchor_size * aspect[0] / 2.0
-      anchor_size_y_2 = base_anchor_size * aspect[1] / 2.0
+      base_anchor_size_x = anchor_scale * stride[1] * 2**octave_scale
+      base_anchor_size_y = anchor_scale * stride[0] * 2**octave_scale
+      anchor_size_x_2 = base_anchor_size_x * aspect[0] / 2.0
+      anchor_size_y_2 = base_anchor_size_y * aspect[1] / 2.0

-      x = np.arange(stride / 2, image_size, stride)
-      y = np.arange(stride / 2, image_size, stride)
+      x = np.arange(stride[1] / 2, image_size[1], stride[1])
+      y = np.arange(stride[0] / 2, image_size[0], stride[0])
       xv, yv = np.meshgrid(x, y)
       xv = xv.reshape(-1)
       yv = yv.reshape(-1)
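
A minimal NumPy sketch of the rectangular center grid this produces, with assumed sizes:

    import numpy as np

    stride = (8.0, 8.0)      # (stride_y, stride_x), assumed
    image_size = (256, 320)  # (height, width), assumed
    x = np.arange(stride[1] / 2, image_size[1], stride[1])
    y = np.arange(stride[0] / 2, image_size[0], stride[0])
    xv, yv = np.meshgrid(x, y)
    print(xv.shape)  # (32, 40): one anchor center per 8x8 cell of the 256x320 image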
@@ -438,10 +440,10 @@ def _generate_dummy_detections(number):
     n = max(MAX_DETECTIONS_PER_IMAGE - len(detections), 0)
     detections_dummy = _generate_dummy_detections(n)
     detections = np.vstack([detections, detections_dummy])
-    detections[:, 1:5] *= image_scale
   else:
     detections = _generate_dummy_detections(MAX_DETECTIONS_PER_IMAGE)
-    detections[:, 1:5] *= image_scale

+  detections[:, 1:5] *= image_scale

   return detections
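
This hunk also de-duplicates the rescaling: both branches used to multiply the box columns by image_scale, and the multiply now runs exactly once after the if/else. A toy check with made-up values (the [image_id, four box coordinates, score, class] column layout is an assumption here):

    import numpy as np

    detections = np.array([[0., 10., 20., 30., 40., 0.9, 1.]])  # made-up row
    image_scale = 2.0  # scale from network input back to the original image
    detections[:, 1:5] *= image_scale
    print(detections[0, 1:5])  # [20. 40. 60. 80.]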

@@ -464,15 +466,17 @@ def __init__(self, min_level, max_level, num_scales, aspect_ratios,
         [(1, 1), (1.4, 0.7), (0.7, 1.4)] adds three anchors on each level.
       anchor_scale: float number representing the scale of size of the base
         anchor to the feature stride 2^level.
-      image_size: integer number of input image size. The input image has the
-        same dimension for width and height.
+      image_size: integer number or tuple of integer number of input image size.
     """
     self.min_level = min_level
     self.max_level = max_level
     self.num_scales = num_scales
     self.aspect_ratios = aspect_ratios
     self.anchor_scale = anchor_scale
-    self.image_size = image_size
+    if isinstance(image_size, int):
+      self.image_size = (image_size, image_size)
+    else:
+      self.image_size = image_size
     self.feat_sizes = utils.get_feat_sizes(image_size, max_level)
     self.config = self._generate_configs()
     self.boxes = self._generate_boxes()
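
With this normalization the constructor accepts either form; a hypothetical usage sketch (argument values are illustrative only):

    kwargs = dict(min_level=3, max_level=7, num_scales=3,
                  aspect_ratios=[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)],
                  anchor_scale=4.0)
    square = Anchors(image_size=640, **kwargs)       # stored as (640, 640)
    rect = Anchors(image_size=(512, 640), **kwargs)  # kept as (height, width)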
@@ -527,11 +531,13 @@ def _unpack_labels(self, labels):
     count = 0
     for level in range(anchors.min_level, anchors.max_level + 1):
       feat_size = anchors.feat_sizes[level]
-      steps = feat_size**2 * anchors.get_anchors_per_location()
+      steps = feat_size['height'] * feat_size[
+          'width'] * anchors.get_anchors_per_location()
       indices = tf.range(count, count + steps)
       count += steps
       labels_unpacked[level] = tf.reshape(
-          tf.gather(labels, indices), [feat_size, feat_size, -1])
+          tf.gather(labels, indices),
+          [feat_size['height'], feat_size['width'], -1])
     return labels_unpacked

   def label_anchors(self, gt_boxes, gt_labels):
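
The per-level step count is now height * width * anchors rather than size**2 * anchors; a quick check of the bookkeeping with assumed numbers:

    feat_size = {'height': 64, 'width': 80}  # assumed rectangular level
    anchors_per_location = 9                 # e.g. 3 scales x 3 aspect ratios
    steps = feat_size['height'] * feat_size['width'] * anchors_per_location
    print(steps)  # 46080 flat entries, reshaped to [64, 80, -1] for this level
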
34 changes: 20 additions & 14 deletions efficientdet/dataloader.py
@@ -39,7 +39,10 @@ def __init__(self, image, output_size):
         function.
     """
     self._image = image
-    self._output_size = output_size
+    if isinstance(output_size, int):
+      self._output_size = (output_size, output_size)
+    else:
+      self._output_size = output_size
     # Parameters to control rescaling and shifting during preprocessing.
     # Image scale defines scale from original image to scaled image.
     self._image_scale = tf.constant(1.0)
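
This is the same int-or-tuple normalization as in Anchors.__init__ above; written as a standalone helper (hypothetical, not part of the commit), the pattern is:

    def _as_height_width(size):
      """Normalize an int or (height, width) sequence to a (h, w) tuple."""
      if isinstance(size, int):
        return (size, size)
      return tuple(size)

    assert _as_height_width(640) == (640, 640)
    assert _as_height_width((512, 640)) == (512, 640)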
@@ -68,20 +71,22 @@ def set_training_random_scale_factors(self, scale_min, scale_max):
     """Set the parameters for multiscale training."""
     # Select a random scale factor.
     random_scale_factor = tf.random_uniform([], scale_min, scale_max)
-    scaled_size = tf.to_int32(random_scale_factor * self._output_size)
+    scaled_size_y = tf.to_int32(random_scale_factor * self._output_size[0])
+    scaled_size_x = tf.to_int32(random_scale_factor * self._output_size[1])

     # Recompute the accurate scale_factor using rounded scaled image size.
     height = tf.shape(self._image)[0]
     width = tf.shape(self._image)[1]
-    max_image_size = tf.to_float(tf.maximum(height, width))
-    image_scale = tf.to_float(scaled_size) / max_image_size
+    image_scale_y = tf.to_float(scaled_size_y) / tf.to_float(height)
+    image_scale_x = tf.to_float(scaled_size_x) / tf.to_float(width)
+    image_scale = tf.minimum(image_scale_x, image_scale_y)

     # Select non-zero random offset (x, y) if scaled image is larger than
     # self._output_size.
     scaled_height = tf.to_int32(tf.to_float(height) * image_scale)
     scaled_width = tf.to_int32(tf.to_float(width) * image_scale)
-    offset_y = tf.to_float(scaled_height - self._output_size)
-    offset_x = tf.to_float(scaled_width - self._output_size)
+    offset_y = tf.to_float(scaled_height - self._output_size[0])
+    offset_x = tf.to_float(scaled_width - self._output_size[1])
     offset_y = tf.maximum(0.0, offset_y) * tf.random_uniform([], 0, 1)
     offset_x = tf.maximum(0.0, offset_x) * tf.random_uniform([], 0, 1)
     offset_y = tf.to_int32(offset_y)
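
Worked through with assumed numbers (a 600x960 image, a (512, 640) output, and a drawn random_scale_factor of 1.25), the y/x scales are computed separately and the smaller one wins, which preserves aspect ratio:

    scaled_size_y = int(1.25 * 512)          # 640
    scaled_size_x = int(1.25 * 640)          # 800
    image_scale = min(640 / 600, 800 / 960)  # min(1.066..., 0.833...) = 0.833...
    scaled_height = int(600 * image_scale)   # 500
    scaled_width = int(960 * image_scale)    # 800
    # offset_y stays 0 (500 <= 512); offset_x is drawn from [0, 160), so the
    # random crop slides horizontally across the 800-pixel-wide scaled image.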
@@ -97,8 +102,9 @@ def set_scale_factors_to_output_size(self):
     # Compute the scale_factor using rounded scaled image size.
     height = tf.shape(self._image)[0]
     width = tf.shape(self._image)[1]
-    max_image_size = tf.to_float(tf.maximum(height, width))
-    image_scale = tf.to_float(self._output_size) / max_image_size
+    image_scale_y = tf.to_float(self._output_size[0]) / tf.to_float(height)
+    image_scale_x = tf.to_float(self._output_size[1]) / tf.to_float(width)
+    image_scale = tf.minimum(image_scale_x, image_scale_y)
     scaled_height = tf.to_int32(tf.to_float(height) * image_scale)
     scaled_width = tf.to_int32(tf.to_float(width) * image_scale)
     self._image_scale = image_scale
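
The eval-time path applies the same min rule deterministically. With assumed numbers, a 600x1280 image into a (512, 640) output:

    image_scale = min(512 / 600, 640 / 1280)  # min(0.853..., 0.5) = 0.5
    scaled_height = int(600 * image_scale)    # 300
    scaled_width = int(1280 * image_scale)    # 640
    # 300x640 keeps the aspect ratio; padding later fills the remaining
    # 212 rows to reach the fixed 512x640 output.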
@@ -110,10 +116,10 @@ def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
     scaled_image = tf.image.resize_images(
         self._image, [self._scaled_height, self._scaled_width], method=method)
     scaled_image = scaled_image[
-        self._crop_offset_y:self._crop_offset_y + self._output_size,
-        self._crop_offset_x:self._crop_offset_x + self._output_size, :]
+        self._crop_offset_y:self._crop_offset_y + self._output_size[0],
+        self._crop_offset_x:self._crop_offset_x + self._output_size[1], :]
     output_image = tf.image.pad_to_bounding_box(
-        scaled_image, 0, 0, self._output_size, self._output_size)
+        scaled_image, 0, 0, self._output_size[0], self._output_size[1])
     return output_image
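
A shape-only sketch of the crop-then-pad geometry, with assumed values (a 500x800 scaled image, a (512, 640) output, crop offsets (0, 37)):

    scaled_shape = (500, 800)  # (height, width) after resize
    output_size = (512, 640)
    offset_y, offset_x = 0, 37
    crop_h = min(scaled_shape[0] - offset_y, output_size[0])  # 500
    crop_w = min(scaled_shape[1] - offset_x, output_size[1])  # 640
    # The slice yields a 500x640 window; pad_to_bounding_box then zero-fills
    # 12 rows at the bottom to reach the fixed 512x640 output.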


@@ -133,8 +139,8 @@ def random_horizontal_flip(self):
   def clip_boxes(self, boxes):
     """Clip boxes to fit in an image."""
     boxes = tf.where(tf.less(boxes, 0), tf.zeros_like(boxes), boxes)
-    boxes = tf.where(tf.greater(boxes, self._output_size - 1),
-                     (self._output_size - 1) * tf.ones_like(boxes), boxes)
+    boxes = tf.where(tf.greater(boxes, self._output_size[0] - 1),
+                     (self._output_size[1] - 1) * tf.ones_like(boxes), boxes)
     return boxes

   def resize_and_crop_boxes(self):
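
Note that the committed clip compares against self._output_size[0] but fills with self._output_size[1]; the two only coincide for square outputs. A per-coordinate variant (a sketch, not the commit's code; it assumes boxes are laid out as [ymin, xmin, ymax, xmax]) would clip y against the height and x against the width:

    import tensorflow as tf  # TF 1.x style, as used in this repo

    def clip_boxes_hw(boxes, output_size):
      """Clip [ymin, xmin, ymax, xmax] boxes to a (height, width) image."""
      h, w = output_size
      ymin = tf.clip_by_value(boxes[:, 0], 0.0, h - 1.0)
      xmin = tf.clip_by_value(boxes[:, 1], 0.0, w - 1.0)
      ymax = tf.clip_by_value(boxes[:, 2], 0.0, h - 1.0)
      xmax = tf.clip_by_value(boxes[:, 3], 0.0, w - 1.0)
      return tf.stack([ymin, xmin, ymax, xmax], axis=1)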
@@ -224,7 +230,7 @@ def _dataset_parser(value):
     Returns:
       image: Image tensor that is preprocessed to have normalized value and
-        fixed dimension [image_size, image_size, 3]
+        fixed dimension [image_height, image_width, 3]
      cls_targets_dict: ordered dictionary with keys
        [min_level, min_level+1, ..., max_level]. The values are tensor with
        shape [height_l, width_l, num_anchors]. The height_l and width_l
(Diffs for the remaining five changed files are not shown.)
