forked from facebookresearch/detectron2
Commit
Summary: semantic segmentation PointRend
Reviewed By: ppwwyyxx
Differential Revision: D19350389
fbshipit-source-id: cec04422ae5b76d730257de336d871cf281b625a
1 parent 9763402 · commit 9725ea1
Showing 7 changed files with 199 additions and 0 deletions.
projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml (19 additions & 0 deletions)
@@ -0,0 +1,19 @@
_BASE_: "../../../../configs/Base-RCNN-FPN.yaml"
MODEL:
  META_ARCHITECTURE: "SemanticSegmentor"
  BACKBONE:
    FREEZE_AT: 0
  SEM_SEG_HEAD:
    NAME: "PointRendSemSegHead"
  POINT_HEAD:
    NUM_CLASSES: 54
    FC_DIM: 256
    NUM_FC: 3
    IN_FEATURES: ["p2"]
    TRAIN_NUM_POINTS: 1024
    SUBDIVISION_STEPS: 2
    SUBDIVISION_NUM_POINTS: 8192
    COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead"
DATASETS:
  TRAIN: ("coco_2017_train_panoptic_stuffonly",)
  TEST: ("coco_2017_val_panoptic_stuffonly",)
projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_50_FPN_1x_cityscapes.yaml (30 additions & 0 deletions)
@@ -0,0 +1,30 @@
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
  WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl
  RESNETS:
    DEPTH: 101
  SEM_SEG_HEAD:
    NUM_CLASSES: 19
  POINT_HEAD:
    NUM_CLASSES: 19
    TRAIN_NUM_POINTS: 2048
    SUBDIVISION_NUM_POINTS: 8192
DATASETS:
  TRAIN: ("cityscapes_fine_sem_seg_train",)
  TEST: ("cityscapes_fine_sem_seg_val",)
SOLVER:
  BASE_LR: 0.01
  STEPS: (40000, 55000)
  MAX_ITER: 65000
  IMS_PER_BATCH: 32
INPUT:
  MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048)
  MIN_SIZE_TRAIN_SAMPLING: "choice"
  MIN_SIZE_TEST: 1024
  MAX_SIZE_TRAIN: 4096
  MAX_SIZE_TEST: 2048
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (512, 1024)
    SINGLE_CATEGORY_MAX_AREA: 0.75
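The Cityscapes config can, as a rough sketch, be driven by detectron2's generic DefaultTrainer; the project's own train_net.py is the reference entry point, and the batch-size override below is only an assumption for a single-machine run.

# Minimal training sketch, not part of the commit. The project's train_net.py is the
# reference entry point; this stand-in skips any project-specific data augmentation.
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from point_rend import add_pointrend_config  # assumed import, as in the sketch above

cfg = get_cfg()
add_pointrend_config(cfg)
cfg.merge_from_file(
    "projects/PointRend/configs/SemanticSegmentation/"
    "pointrend_semantic_R_50_FPN_1x_cityscapes.yaml"
)
cfg.SOLVER.IMS_PER_BATCH = 8  # assumption: the 32-image batch above targets a multi-GPU launch

trainer = DefaultTrainer(cfg)        # builds model, optimizer, and data loaders from cfg
trainer.resume_or_load(resume=False)
trainer.train()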
projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_50_FPN_1x_coco.yaml (5 additions & 0 deletions)
@@ -0,0 +1,5 @@
_BASE_: Base-PointRend-Semantic-FPN.yaml
MODEL:
  WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl
  RESNETS:
    DEPTH: 50
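For the COCO variant, a hedged single-image inference sketch: DefaultPredictor wraps preprocessing and checkpoint loading, and a SemanticSegmentor model returns its per-pixel logits under the "sem_seg" key. The checkpoint path below is a placeholder.

# Minimal inference sketch, not part of the commit. The weights path is hypothetical.
import cv2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from point_rend import add_pointrend_config  # assumed import, as above

cfg = get_cfg()
add_pointrend_config(cfg)
cfg.merge_from_file(
    "projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_50_FPN_1x_coco.yaml"
)
cfg.MODEL.WEIGHTS = "output/model_final.pth"  # placeholder for a trained checkpoint

predictor = DefaultPredictor(cfg)
image = cv2.imread("input.jpg")                # BGR, detectron2's default input format
sem_seg_logits = predictor(image)["sem_seg"]   # (num_classes, H, W) logits
pred = sem_seg_logits.argmax(dim=0)            # per-pixel class ids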
@@ -0,0 +1,134 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
from typing import Dict
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.layers import ShapeSpec, cat
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from .point_features import (
    get_uncertain_point_coords_on_grid,
    get_uncertain_point_coords_with_randomness,
    point_sample,
)
from .point_head import build_point_head


def calculate_uncertainty(sem_seg_logits):
    """
    For each location of the prediction `sem_seg_logits` we estimate uncertainty as the
    difference between the top first and top second predicted logits.
    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
            C is the number of foreground classes. The values are logits.
    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)

@SEM_SEG_HEADS_REGISTRY.register()
class PointRendSemSegHead(nn.Module):
    """
    A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME`
    and a point head set in `MODEL.POINT_HEAD.NAME`.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()

        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE

        self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get(
            cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME
        )(cfg, input_shape)
        self._init_point_head(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]):
        # fmt: off
        assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        feature_channels = {k: v.channels for k, v in input_shape.items()}
        self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        in_channels = np.sum([feature_channels[f] for f in self.in_features])
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

    def forward(self, features, targets=None):
        coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features)

        if self.training:
            losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets)

            with torch.no_grad():
                point_coords = get_uncertain_point_coords_with_randomness(
                    coarse_sem_seg_logits,
                    calculate_uncertainty,
                    self.train_num_points,
                    self.oversample_ratio,
                    self.importance_sample_ratio,
                )
            coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False)

            fine_grained_features = cat(
                [
                    point_sample(features[in_feature], point_coords, align_corners=False)
                    for in_feature in self.in_features
                ]
            )
            point_logits = self.point_head(fine_grained_features, coarse_features)
            point_targets = (
                point_sample(
                    targets.unsqueeze(1).to(torch.float),
                    point_coords,
                    mode="nearest",
                    align_corners=False,
                )
                .squeeze(1)
                .to(torch.long)
            )
            losses["loss_sem_seg_point"] = F.cross_entropy(
                point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value
            )
            return None, losses
        else:
            sem_seg_logits = coarse_sem_seg_logits.clone()
            for _ in range(self.subdivision_steps):
                sem_seg_logits = F.interpolate(
                    sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False
                )
                uncertainty_map = calculate_uncertainty(sem_seg_logits)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.subdivision_num_points
                )
                fine_grained_features = cat(
                    [
                        point_sample(features[in_feature], point_coords, align_corners=False)
                        for in_feature in self.in_features
                    ]
                )
                coarse_features = point_sample(
                    coarse_sem_seg_logits, point_coords, align_corners=False
                )
                point_logits = self.point_head(fine_grained_features, coarse_features)

                # Put sem seg point predictions to the right places on the upsampled grid.
                N, C, H, W = sem_seg_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                sem_seg_logits = (
                    sem_seg_logits.reshape(N, C, H * W)
                    .scatter_(2, point_indices, point_logits)
                    .view(N, C, H, W)
                )
            return sem_seg_logits, {}
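To make the inference branch above concrete, here is a self-contained sketch of one subdivision step on dummy tensors: upsample, score uncertainty as the gap between the two largest logits, pick the most uncertain grid locations, and scatter refined predictions back. The shapes and the noise-based stand-in for the point head are assumptions for illustration only; the real code uses the point_features helpers and the learned point head.

# Self-contained sketch (torch only) of one subdivision step; the "point head" here is a
# stand-in that perturbs the sampled logits, purely to illustrate the gather/scatter flow.
import torch
import torch.nn.functional as F

N, C, H, W = 1, 54, 32, 32
coarse = torch.randn(N, C, H, W)                 # coarse sem seg logits
num_points = 256                                  # cf. SUBDIVISION_NUM_POINTS

# 1) Upsample the logits 2x, as in the forward loop.
logits = F.interpolate(coarse, scale_factor=2, mode="bilinear", align_corners=False)

# 2) Uncertainty = difference between the two largest logits (see calculate_uncertainty).
top2 = torch.topk(logits, k=2, dim=1)[0]
uncertainty = (top2[:, 1] - top2[:, 0]).unsqueeze(1)          # (N, 1, 2H, 2W)

# 3) Flattened indices of the most uncertain grid locations
#    (roughly what get_uncertain_point_coords_on_grid returns).
point_indices = torch.topk(uncertainty.view(N, -1), k=num_points, dim=1)[1]
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)  # (N, C, num_points)

# 4) "Re-predict" those points with a stand-in point head.
flat = logits.view(N, C, -1)                                   # (N, C, 4*H*W)
sampled = flat.gather(2, point_indices)                        # (N, C, num_points)
point_logits = sampled + 0.1 * torch.randn_like(sampled)       # stand-in prediction

# 5) Scatter the refined point logits back into the upsampled grid.
refined = flat.scatter(2, point_indices, point_logits).view(N, C, 2 * H, 2 * W)
print(refined.shape)  # torch.Size([1, 54, 64, 64])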