Skip to content

Commit

Permalink
PointRend semantic segmentation
Browse files Browse the repository at this point in the history
Summary: semantic segmentation PointRend

Reviewed By: ppwwyyxx

Differential Revision: D19350389

fbshipit-source-id: cec04422ae5b76d730257de336d871cf281b625a
  • Loading branch information
alexander-kirillov authored and facebook-github-bot committed May 1, 2020
1 parent 9763402 commit 9725ea1
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
MODEL:
META_ARCHITECTURE: "SemanticSegmentor"
BACKBONE:
FREEZE_AT: 0
SEM_SEG_HEAD:
NAME: "PointRendSemSegHead"
POINT_HEAD:
NUM_CLASSES: 54
FC_DIM: 256
NUM_FC: 3
IN_FEATURES: ["p2"]
TRAIN_NUM_POINTS: 1024
SUBDIVISION_STEPS: 2
SUBDIVISION_NUM_POINTS: 8192
COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead"
DATASETS:
TRAIN: ("coco_2017_train_panoptic_stuffonly",)
TEST: ("coco_2017_val_panoptic_stuffonly",)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl
RESNETS:
DEPTH: 101
SEM_SEG_HEAD:
NUM_CLASSES: 19
POINT_HEAD:
NUM_CLASSES: 19
TRAIN_NUM_POINTS: 2048
SUBDIVISION_NUM_POINTS: 8192
DATASETS:
TRAIN: ("cityscapes_fine_sem_seg_train",)
TEST: ("cityscapes_fine_sem_seg_val",)
SOLVER:
BASE_LR: 0.01
STEPS: (40000, 55000)
MAX_ITER: 65000
IMS_PER_BATCH: 32
INPUT:
MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048)
MIN_SIZE_TRAIN_SAMPLING: "choice"
MIN_SIZE_TEST: 1024
MAX_SIZE_TRAIN: 4096
MAX_SIZE_TEST: 2048
CROP:
ENABLED: True
TYPE: "absolute"
SIZE: (512, 1024)
SINGLE_CATEGORY_MAX_AREA: 0.75
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
RESNETS:
DEPTH: 50
1 change: 1 addition & 0 deletions projects/PointRend/point_rend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .coarse_mask_head import CoarseMaskHead
from .roi_heads import PointRendROIHeads
from .dataset_mapper import SemSegDatasetMapper
from .semantic_seg import PointRendSemSegHead
1 change: 1 addition & 0 deletions projects/PointRend/point_rend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ def add_pointrend_config(cfg):
cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False
# If True, then coarse prediction features are used as inout for each layer in PointRend's MLP.
cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead"
134 changes: 134 additions & 0 deletions projects/PointRend/point_rend/semantic_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
from typing import Dict
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.layers import ShapeSpec, cat
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from .point_features import (
get_uncertain_point_coords_on_grid,
get_uncertain_point_coords_with_randomness,
point_sample,
)
from .point_head import build_point_head


def calculate_uncertainty(sem_seg_logits):
"""
For each location of the prediction `sem_seg_logits` we estimate uncerainty as the
difference between top first and top second predicted logits.
Args:
mask_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
C is the number of foreground classes. The values are logits.
Returns:
scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
the most uncertain locations having the highest uncertainty score.
"""
top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)


@SEM_SEG_HEADS_REGISTRY.register()
class PointRendSemSegHead(nn.Module):
"""
A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME`
and a point head set in `MODEL.POINT_HEAD.NAME`.
"""

def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
super().__init__()

self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE

self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
)(cfg, input_shape)
self._init_point_head(cfg, input_shape)

def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
# fmt: off
assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
feature_channels = {k: v.channels for k, v in input_shape.items()}
self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
# fmt: on

in_channels = np.sum([feature_channels[f] for f in self.in_features])
self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

def forward(self, features, targets=None):
coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)

if self.training:
losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)

with torch.no_grad():
point_coords = get_uncertain_point_coords_with_randomness(
coarse_sem_seg_logits,
calculate_uncertainty,
self.train_num_points,
self.oversample_ratio,
self.importance_sample_ratio,
)
coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)

fine_grained_features = cat(
[
point_sample(features[in_feature], point_coords, align_corners=False)
for in_feature in self.in_features
]
)
point_logits = self.point_head(fine_grained_features, coarse_features)
point_targets = (
point_sample(
targets.unsqueeze(1).to(torch.float),
point_coords,
mode="nearest",
align_corners=False,
)
.squeeze(1)
.to(torch.long)
)
losses["loss_sem_seg_point"] = F.cross_entropy(
point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
)
return None, losses
else:
sem_seg_logits = coarse_sem_seg_logits.clone()
for _ in range(self.subdivision_steps):
sem_seg_logits = F.interpolate(
sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
)
uncertainty_map = calculate_uncertainty(sem_seg_logits)
point_indices, point_coords = get_uncertain_point_coords_on_grid(
uncertainty_map, self.subdivision_num_points
)
fine_grained_features = cat(
[
point_sample(features[in_feature], point_coords, align_corners=False)
for in_feature in self.in_features
]
)
coarse_features = point_sample(
coarse_sem_seg_logits, point_coords, align_corners=False
)
point_logits = self.point_head(fine_grained_features, coarse_features)

# put sem seg point predictions to the right places on the upsampled grid.
N, C, H, W = sem_seg_logits.shape
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
sem_seg_logits = (
sem_seg_logits.reshape(N, C, H * W)
.scatter_(2, point_indices, point_logits)
.view(N, C, H, W)
)
return sem_seg_logits, {}
9 changes: 9 additions & 0 deletions projects/PointRend/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
COCOEvaluator,
DatasetEvaluators,
LVISEvaluator,
SemSegEvaluator,
verify_results,
)

Expand Down Expand Up @@ -51,6 +52,14 @@ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
return LVISEvaluator(dataset_name, cfg, True, output_folder)
if evaluator_type == "coco":
return COCOEvaluator(dataset_name, cfg, True, output_folder)
if evaluator_type == "sem_seg":
return SemSegEvaluator(
dataset_name,
distributed=True,
num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
output_dir=output_folder,
)
if evaluator_type == "cityscapes_instance":
assert (
torch.cuda.device_count() >= comm.get_rank()
Expand Down

0 comments on commit 9725ea1

Please sign in to comment.