forked from facebookresearch/detectron2
Cropping for semantic segmentation models

Summary: Cropping for semantic segmentation
Reviewed By: ppwwyyxx
Differential Revision: D19453097
fbshipit-source-id: a64d462e068d22b7e3be3703c16df3839474ea71
Commit 9763402 (1 parent: e7aa570)
Showing 5 changed files with 129 additions and 3 deletions.
@@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from fvcore.transforms.transform import CropTransform
from PIL import Image

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

"""
This file contains the mapping that is applied to "dataset dicts" for semantic segmentation models.
Unlike the default DatasetMapper, this mapper applies cropping as the last transformation.
"""

__all__ = ["SemSegDatasetMapper"]


class SemSegDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by semantic segmentation models.

    The callable currently does the following:
    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Finds and applies a suitable crop to the image and annotation
    4. Converts the image and annotation to Tensors
    """

    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
        else:
            self.crop_gen = None

        self.tfm_gens = utils.build_transform_gen(cfg, is_train)

        # fmt: off
        self.img_format               = cfg.INPUT.FORMAT
        self.single_category_max_area = cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA
        self.ignore_value             = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
        # fmt: on

        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)
        assert "sem_seg_file_name" in dataset_dict

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        if self.is_train:
            with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
                sem_seg_gt = Image.open(f)
                sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
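            # Cropping runs after all other geometric transforms, i.e. it is
            # the "last transformation" described in the module docstring.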
            if self.crop_gen:
                image, sem_seg_gt = crop_transform(
                    image,
                    sem_seg_gt,
                    self.crop_gen,
                    self.single_category_max_area,
                    self.ignore_value,
                )
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # PyTorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        return dataset_dict

def crop_transform(image, sem_seg, crop_gen, single_category_max_area, ignore_value):
    """
    Find a cropping window such that no single category occupies more than
    `single_category_max_area` in `sem_seg`. The function retries random cropping 10 times max.
    """
    if single_category_max_area >= 1.0:
        crop_tfm = crop_gen.get_transform(image)
        sem_seg_temp = crop_tfm.apply_segmentation(sem_seg)
    else:
        h, w = sem_seg.shape
        crop_size = crop_gen.get_crop_size((h, w))
        for _ in range(10):
            y0 = np.random.randint(h - crop_size[0] + 1)
            x0 = np.random.randint(w - crop_size[1] + 1)
            sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
            labels, cnt = np.unique(sem_seg_temp, return_counts=True)
            cnt = cnt[labels != ignore_value]
            if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < single_category_max_area:
                break
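        # If no window satisfies the constraint within 10 attempts, the last
        # sampled window is used as-is.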
        crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
    image = crop_tfm.apply_image(image)
    return image, sem_seg_temp
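
For orientation, here is a minimal usage sketch of the mapper in a standard Detectron2 training setup. The dataset name and crop values are illustrative placeholders, not taken from this commit, and it assumes INPUT.CROP.SINGLE_CATEGORY_MAX_AREA is registered in the default config (presumably one of the other changed files):

    # Minimal usage sketch (illustrative values; not part of the commit).
    from detectron2.config import get_cfg
    from detectron2.data import build_detection_train_loader

    cfg = get_cfg()
    cfg.INPUT.CROP.ENABLED = True
    cfg.INPUT.CROP.TYPE = "absolute"
    cfg.INPUT.CROP.SIZE = [512, 512]
    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 0.75  # retry crops dominated by one class
    cfg.DATASETS.TRAIN = ("my_sem_seg_train",)  # hypothetical registered dataset

    mapper = SemSegDatasetMapper(cfg, is_train=True)
    data_loader = build_detection_train_loader(cfg, mapper=mapper)

Since build_detection_train_loader accepts a custom mapper, the rest of the training loop needs no changes.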
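
And a small self-contained check of crop_transform's retry logic on a synthetic two-class mask. The sizes and the 0.9 threshold are made up for illustration, and crop_transform is assumed to be in scope from the listing above:

    import numpy as np
    from detectron2.data import transforms as T

    # Synthetic 100x100 image and a mask split evenly between two classes.
    image = np.zeros((100, 100, 3), dtype=np.uint8)
    sem_seg = np.zeros((100, 100), dtype=np.uint8)
    sem_seg[:, 50:] = 1

    crop_gen = T.RandomCrop("absolute", (64, 64))
    image_c, sem_seg_c = crop_transform(image, sem_seg, crop_gen, 0.9, ignore_value=255)

    labels, cnt = np.unique(sem_seg_c, return_counts=True)
    print(image_c.shape, sem_seg_c.shape)  # (64, 64, 3) (64, 64)
    print(cnt.max() / cnt.sum() < 0.9)     # True: no class exceeds the 0.9 cap here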