Skip to content

Commit

Permalink
Cropping for semantic segmentation models
Browse files Browse the repository at this point in the history
Summary: Cropping for semantic segmentation

Reviewed By: ppwwyyxx

Differential Revision: D19453097

fbshipit-source-id: a64d462e068d22b7e3be3703c16df3839474ea71
  • Loading branch information
alexander-kirillov authored and facebook-github-bot committed May 1, 2020
1 parent e7aa570 commit 9763402
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 3 deletions.
2 changes: 1 addition & 1 deletion detectron2/data/transforms/transform_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def get_crop_size(self, image_size):
ch, cw = crop_size + np.random.rand(2) * (1 - crop_size)
return int(h * ch + 0.5), int(w * cw + 0.5)
elif self.crop_type == "absolute":
return self.crop_size
return (min(self.crop_size[0], h), min(self.crop_size[1], w))
else:
NotImplementedError("Unknown crop type {}".format(self.crop_type))

Expand Down
1 change: 1 addition & 0 deletions projects/PointRend/point_rend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .config import add_pointrend_config
from .coarse_mask_head import CoarseMaskHead
from .roi_heads import PointRendROIHeads
from .dataset_mapper import SemSegDatasetMapper
4 changes: 4 additions & 0 deletions projects/PointRend/point_rend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ def add_pointrend_config(cfg):
"""
Add config for PointRend.
"""
# We retry random cropping until no single category in semantic segmentation GT occupies more
# than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0

# Names of the input feature maps to be used by a coarse mask head.
cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",)
cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024
Expand Down
113 changes: 113 additions & 0 deletions projects/PointRend/point_rend/dataset_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from fvcore.transforms.transform import CropTransform
from PIL import Image

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

"""
This file contains the mapping that's applied to "dataset dicts" for semantic segmentation models.
Unlike the default DatasetMapper this mapper uses cropping as the last transformation.
"""

__all__ = ["SemSegDatasetMapper"]


class SemSegDatasetMapper:
"""
A callable which takes a dataset dict in Detectron2 Dataset format,
and map it into a format used by semantic segmentation models.
The callable currently does the following:
1. Read the image from "file_name"
2. Applies geometric transforms to the image and annotation
3. Find and applies suitable cropping to the image and annotation
4. Prepare image and annotation to Tensors
"""

def __init__(self, cfg, is_train=True):
if cfg.INPUT.CROP.ENABLED and is_train:
self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
logging.getLogger(__name__).info("CropGen used in training: " + str(self.crop_gen))
else:
self.crop_gen = None

self.tfm_gens = utils.build_transform_gen(cfg, is_train)

# fmt: off
self.img_format = cfg.INPUT.FORMAT
self.single_category_max_area = cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA
self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
# fmt: on

self.is_train = is_train

def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
utils.check_image_size(dataset_dict, image)
assert "sem_seg_file_name" in dataset_dict

image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.is_train:
with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
sem_seg_gt = Image.open(f)
sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
if self.crop_gen:
image, sem_seg_gt = crop_transform(
image,
sem_seg_gt,
self.crop_gen,
self.single_category_max_area,
self.ignore_value,
)
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
# Therefore it's important to use torch.Tensor.
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

if not self.is_train:
dataset_dict.pop("sem_seg_file_name", None)
return dataset_dict

return dataset_dict


def crop_transform(image, sem_seg, crop_gen, single_category_max_area, ignore_value):
"""
Find a cropping window such that no single category occupies more than
`single_category_max_area` in `sem_seg`. The function retries random cropping 10 times max.
"""
if single_category_max_area >= 1.0:
crop_tfm = crop_gen.get_transform(image)
sem_seg_temp = crop_tfm.apply_segmentation(sem_seg)
else:
h, w = sem_seg.shape
crop_size = crop_gen.get_crop_size((h, w))
for _ in range(10):
y0 = np.random.randint(h - crop_size[0] + 1)
x0 = np.random.randint(w - crop_size[1] + 1)
sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
labels, cnt = np.unique(sem_seg_temp, return_counts=True)
cnt = cnt[labels != ignore_value]
if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < single_category_max_area:
break
crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0])
image = crop_tfm.apply_image(image)
return image, sem_seg_temp
12 changes: 10 additions & 2 deletions projects/PointRend/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
Expand All @@ -24,7 +24,7 @@
verify_results,
)

from point_rend import add_pointrend_config
from point_rend import SemSegDatasetMapper, add_pointrend_config


class Trainer(DefaultTrainer):
Expand Down Expand Up @@ -71,6 +71,14 @@ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
return evaluator_list[0]
return DatasetEvaluators(evaluator_list)

@classmethod
def build_train_loader(cls, cfg):
if "SemanticSegmentor" in cfg.MODEL.META_ARCHITECTURE:
mapper = SemSegDatasetMapper(cfg, True)
else:
mapper = None
return build_detection_train_loader(cfg, mapper=mapper)


def setup(args):
"""
Expand Down

0 comments on commit 9763402

Please sign in to comment.