Skip to content

Commit

Permalink
Add LimitLong Transform (PaddlePaddle#803)
Browse files Browse the repository at this point in the history
  • Loading branch information
wuyefeilin authored Jan 27, 2021
1 parent 6973146 commit f850954
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
17 changes: 17 additions & 0 deletions paddleseg/core/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ def get_reverse_list(ori_shape, transforms):
if op.__class__.__name__ in ['Padding']:
reverse_list.append(('padding', (h, w)))
w, h = op.target_size[0], op.target_size[1]
if op.__class__.__name__ in ['LimitLong']:
long_edge = max(h, w)
short_edge = min(h, w)
if ((op.max_long is not None) and (long_edge > op.max_long)):
reverse_list.append(('resize', (h, w)))
long_edge = op.max_long
short_edge = int(round(short_edge * op.max_long / long_edge))
elif ((op.min_long is not None) and (long_edge < op.min_long)):
reverse_list.append(('resize', (h, w)))
long_edge = op.min_long
short_edge = int(round(short_edge * op.min_long / long_edge))
if h > w:
h = long_edge
w = short_edge
else:
w = long_edge
h = short_edge
return reverse_list


Expand Down
6 changes: 4 additions & 2 deletions paddleseg/core/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def evaluate(model,
intersect_area_list = []
pred_area_list = []
label_area_list = []
paddle.distributed.all_gather(intersect_area_list, intersect_area)
paddle.distributed.all_gather(intersect_area_list,
intersect_area)
paddle.distributed.all_gather(pred_area_list, pred_area)
paddle.distributed.all_gather(label_area_list, label_area)

Expand All @@ -135,7 +136,8 @@ def evaluate(model,
label_area_list = label_area_list[:valid]

for i in range(len(intersect_area_list)):
intersect_area_all = intersect_area_all + intersect_area_list[i]
intersect_area_all = intersect_area_all + intersect_area_list[
i]
pred_area_all = pred_area_all + pred_area_list[i]
label_area_all = label_area_all + label_area_list[i]
else:
Expand Down
65 changes: 65 additions & 0 deletions paddleseg/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,71 @@ def __call__(self, im, label=None):
return (im, label)


@manager.TRANSFORMS.add_component
class LimitLong:
"""
Limit the long edge of image.
If the long edge is larger than max_long, resize the long edge
to max_long, while scale the short edge proportionally.
If the long edge is smaller than min_long, resize the long edge
to min_long, while scale the short edge proportionally.
Args:
max_long (int, optional): If the long edge of image is larger than max_long,
it will be resize to max_long. Default: None.
min_long (int, optional): If the long edge of image is smaller than min_long,
it will be resize to min_long. Default: None.
"""

def __init__(self, max_long=None, min_long=None):
if max_long is not None:
if not isinstance(max_long, int):
raise TypeError(
"Type of `max_long` is invalid. It should be int, but it is {}"
.format(type(max_long)))
if min_long is not None:
if not isinstance(min_long, int):
raise TypeError(
"Type of `min_long` is invalid. It should be int, but it is {}"
.format(type(min_long)))
if (max_long is not None) and (min_long is not None):
if min_long > max_long:
raise ValueError(
'`max_long should not smaller than min_long, but they are {} and {}'
.format(max_long, min_long))
self.max_long = max_long
self.min_long = min_long

def __call__(self, im, label=None):
"""
Args:
im (np.ndarray): The Image data.
label (np.ndarray, optional): The label data. Default: None.
Returns:
(tuple). When label is None, it returns (im, ), otherwise it returns (im, label).
"""
h, w = im.shape[0], im.shape[1]
long_edge = max(h, w)
target = long_edge
if (self.max_long is not None) and (long_edge > self.max_long):
target = self.max_long
elif (self.min_long is not None) and (long_edge < self.min_long):
target = self.min_long

if target != long_edge:
im = functional.resize_long(im, target)
if label is not None:
label = functional.resize_long(label, target, cv2.INTER_NEAREST)

if label is None:
return (im, )
else:
return (im, label)


@manager.TRANSFORMS.add_component
class ResizeRangeScaling:
"""
Expand Down

0 comments on commit f850954

Please sign in to comment.