forked from PaddlePaddle/PaddleSeg
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add CCNet (PaddlePaddle#2005)
- Loading branch information
Showing
11 changed files
with
324 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# CCNet: Criss-cross attention for semantic segmentation | ||
|
||
## Reference | ||
|
||
> Zilong Huang, Xinggang Wang, Yunchao Wei, Lichao Huang, Humphrey Shi, Wenyu Liu, Thomas S. Huang. "CCNet: Criss-cross attention for semantic segmentation." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019. | ||
## Performance | ||
|
||
### Cityscapes | ||
|
||
| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links | | ||
|-|-|-|-|-|-|-|-| | ||
|CCNet|ResNet101_OS8|769x769|60000|80.95%|81.23%|81.32%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/ccnet_resnet101_os8_cityscapes_769x769_60k/model.pdparams)\|[log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/ccnet_resnet101_os8_cityscapes_769x769_60k/train.log)\|[vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=6828616e27a1e15f1442beb3b4834048)| |
27 changes: 27 additions & 0 deletions
27
configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
_base_: '../_base_/cityscapes_769x769.yml' | ||
|
||
batch_size: 2 | ||
iters: 60000 | ||
|
||
model: | ||
type: CCNet | ||
backbone: | ||
type: ResNet101_vd | ||
output_stride: 8 | ||
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz | ||
backbone_indices: [2, 3] | ||
enable_auxiliary_loss: True | ||
dropout_prob: 0.1 | ||
recurrence: 2 | ||
|
||
loss: | ||
types: | ||
- type: OhemCrossEntropyLoss | ||
- type: CrossEntropyLoss | ||
coef: [1, 0.4] | ||
|
||
lr_scheduler: | ||
type: PolynomialDecay | ||
learning_rate: 0.01 | ||
power: 0.9 | ||
end_lr: 1.0e-4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,3 +58,4 @@ | |
from .pfpnnet import PFPNNet | ||
from .glore import GloRe | ||
from .ddrnet import DDRNet_23 | ||
from .ccnet import CCNet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
|
||
from paddleseg.cvlibs import manager | ||
from paddleseg.models import layers | ||
from paddleseg.utils import utils | ||
|
||
|
||
@manager.MODELS.add_component
class CCNet(nn.Layer):
    """
    The CCNet implementation based on PaddlePaddle.

    The original article refers to
    Zilong Huang, et al. "CCNet: Criss-Cross Attention for Semantic Segmentation"
    (https://arxiv.org/abs/1811.11721)

    Args:
        num_classes (int): The unique number of target classes.
        backbone (paddle.nn.Layer): Backbone network, currently support Resnet18_vd/Resnet34_vd/Resnet50_vd/Resnet101_vd.
        backbone_indices (tuple, list, optional): Indices of the backbone outputs to use. The last entry feeds the
            main head and the second to last entry feeds the auxiliary head. Default: (2, 3).
        enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
        dropout_prob (float, optional): The probability of dropout. Default: 0.0.
        recurrence (int, optional): The number of recurrent operations. Default: 1.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of pretrained model. Default: None.

    Raises:
        ValueError: If `backbone_indices` holds fewer entries than the heads need
            (two with the auxiliary head, one without).
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(2, 3),
                 enable_auxiliary_loss=True,
                 dropout_prob=0.0,
                 recurrence=1,
                 align_corners=False,
                 pretrained=None):
        super().__init__()
        self.enable_auxiliary_loss = enable_auxiliary_loss
        self.recurrence = recurrence
        self.align_corners = align_corners

        self.backbone = backbone
        self.backbone_indices = backbone_indices

        # Fail early with a readable message instead of an IndexError later on.
        required_indices = 2 if enable_auxiliary_loss else 1
        if len(backbone_indices) < required_indices:
            raise ValueError(
                "`backbone_indices` needs at least {} entries, but got {}.".
                format(required_indices, len(backbone_indices)))
        backbone_channels = [
            backbone.feat_channels[i] for i in backbone_indices
        ]

        # Channel counts are picked with the same [-2] / [-1] positions that
        # `forward` uses to pick the features, so both always agree.
        if enable_auxiliary_loss:
            self.aux_head = layers.AuxLayer(
                backbone_channels[-2],
                512,
                num_classes,
                dropout_prob=dropout_prob)
        self.head = RCCAModule(
            backbone_channels[-1],
            512,
            num_classes,
            dropout_prob=dropout_prob,
            recurrence=recurrence)
        self.pretrained = pretrained
        # Without this call the `pretrained` argument would never be loaded.
        self.init_weight()

    def init_weight(self):
        """Load the weights given by `pretrained` into the whole model, if any."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def forward(self, x):
        """
        Compute the segmentation logits for `x`.

        Returns:
            list: The main logit, followed by the auxiliary logit when the
                layer is in training mode and the auxiliary loss is enabled.
                Every logit is bilinearly resized to the spatial size of `x`.
        """
        feat_list = self.backbone(x)
        logit_list = [self.head(feat_list[self.backbone_indices[-1]])]
        if self.training and self.enable_auxiliary_loss:
            aux_out = self.aux_head(feat_list[self.backbone_indices[-2]])
            logit_list.append(aux_out)
        return [
            F.interpolate(
                logit,
                paddle.shape(x)[2:],
                mode='bilinear',
                align_corners=self.align_corners) for logit in logit_list
        ]
|
||
|
||
class RCCAModule(nn.Layer):
    """
    Recurrent criss-cross attention head.

    The input is reduced to `in_channels // 4` channels, passed `recurrence`
    times through one shared `CrissCrossAttention` layer, refined by a second
    convolution, concatenated with the original input along the channel axis
    and finally fed to a `layers.AuxLayer` built with `num_classes`.

    Args:
        in_channels (int): The number of channels of the input feature map.
        out_channels (int): The channel argument handed to the output `layers.AuxLayer`.
        num_classes (int): The unique number of target classes.
        dropout_prob (float, optional): The probability of dropout. Default: 0.1.
        recurrence (int, optional): How many times the attention layer is applied. Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_classes,
                 dropout_prob=0.1,
                 recurrence=1):
        super().__init__()
        self.recurrence = recurrence
        reduced_channels = in_channels // 4

        # Sub-layers are created in a fixed order under fixed attribute names
        # so that existing checkpoints keep matching.
        self.conva = layers.ConvBNLeakyReLU(
            in_channels, reduced_channels, 3, padding=1, bias_attr=False)
        self.cca = CrissCrossAttention(reduced_channels)
        self.convb = layers.ConvBNLeakyReLU(
            reduced_channels, reduced_channels, 3, padding=1, bias_attr=False)
        self.out = layers.AuxLayer(
            in_channels + reduced_channels,
            out_channels,
            num_classes,
            dropout_prob=dropout_prob)

    def forward(self, x):
        attended = self.conva(x)
        # The same attention layer (shared weights) is applied repeatedly.
        for _ in range(self.recurrence):
            attended = self.cca(attended)
        attended = self.convb(attended)
        fused = paddle.concat([x, attended], axis=1)
        return self.out(fused)
|
||
|
||
class CrissCrossAttention(nn.Layer):
    """
    Criss-cross attention on a 4-D feature map.

    For every position, attention energies are computed against the positions
    of its column and of its row, normalised by one softmax over both sets,
    and used to aggregate the value features. The aggregated result is scaled
    by the learnable `gamma` (initialised to 0) and added to the input.

    Args:
        in_channels (int): The number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.q_conv = nn.Conv2D(in_channels, in_channels // 8, kernel_size=1)
        self.k_conv = nn.Conv2D(in_channels, in_channels // 8, kernel_size=1)
        self.v_conv = nn.Conv2D(in_channels, in_channels, kernel_size=1)
        self.softmax = nn.Softmax(axis=3)
        self.gamma = self.create_parameter(
            shape=(1, ), default_initializer=nn.initializer.Constant(0))
        # Single `inf` element; `Inf` tiles it into the diagonal mask.
        self.inf_tensor = paddle.full(shape=(1, ), fill_value=float('inf'))

    def forward(self, x):
        b, _, h, w = paddle.shape(x)

        # Queries: one batch entry per column ([b * w, h, c']) and per row
        # ([b * h, w, c']).
        query = self.q_conv(x)
        query_col = query.transpose([0, 3, 1, 2]).reshape([b * w, -1, h])
        query_col = query_col.transpose([0, 2, 1])
        query_row = query.transpose([0, 2, 1, 3]).reshape([b * h, -1, w])
        query_row = query_row.transpose([0, 2, 1])

        # Keys: [b * w, c', h] and [b * h, c', w].
        key = self.k_conv(x)
        key_col = key.transpose([0, 3, 1, 2]).reshape([b * w, -1, h])
        key_row = key.transpose([0, 2, 1, 3]).reshape([b * h, -1, w])

        # Values: [b * w, c, h] and [b * h, c, w].
        value = self.v_conv(x)
        value_col = value.transpose([0, 3, 1, 2]).reshape([b * w, -1, h])
        value_row = value.transpose([0, 2, 1, 3]).reshape([b * h, -1, w])

        # The column energies get -inf on their diagonal: each position is
        # then only counted through the row energies.
        energy_col = paddle.bmm(query_col, key_col) + self.Inf(b, h, w)
        energy_col = energy_col.reshape([b, w, h, h]).transpose([0, 2, 1, 3])
        energy_row = paddle.bmm(query_row, key_row).reshape([b, h, w, w])

        # One softmax over the h column entries followed by the w row entries.
        attention = self.softmax(
            paddle.concat([energy_col, energy_row], axis=3))

        attn_col = attention[:, :, :, 0:h].transpose([0, 2, 1, 3])
        attn_col = attn_col.reshape([b * w, h, h])
        attn_row = attention[:, :, :, h:h + w].reshape([b * h, w, w])

        out_col = paddle.bmm(value_col, attn_col.transpose([0, 2, 1]))
        out_col = out_col.reshape([b, w, -1, h]).transpose([0, 2, 3, 1])
        out_row = paddle.bmm(value_row, attn_row.transpose([0, 2, 1]))
        out_row = out_row.reshape([b, h, -1, w]).transpose([0, 2, 1, 3])

        return self.gamma * (out_col + out_row) + x

    def Inf(self, B, H, W):
        """Return a [B * W, H, H] mask with -inf on each diagonal and zeros elsewhere."""
        diagonal = paddle.diag(paddle.tile(self.inf_tensor, [H]), 0)
        return -paddle.tile(diagonal.unsqueeze(0), [B * W, 1, 1])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 27 additions & 0 deletions
27
test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
_base_: '../_base_/cityscapes_769x769.yml' | ||
|
||
batch_size: 2 | ||
iters: 60000 | ||
|
||
model: | ||
type: CCNet | ||
backbone: | ||
type: ResNet101_vd | ||
output_stride: 8 | ||
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz | ||
backbone_indices: [2, 3] | ||
enable_auxiliary_loss: True | ||
dropout_prob: 0.1 | ||
recurrence: 2 | ||
|
||
loss: | ||
types: | ||
- type: OhemCrossEntropyLoss | ||
- type: CrossEntropyLoss | ||
coef: [1, 0.4] | ||
|
||
lr_scheduler: | ||
type: PolynomialDecay | ||
learning_rate: 0.01 | ||
power: 0.9 | ||
end_lr: 1.0e-4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
===========================train_params=========================== | ||
model_name:ccnet | ||
python:python3 | ||
gpu_list:0|0,1 | ||
Global.use_gpu:null|null | ||
--precision:null | ||
--iters:lite_train_lite_infer=50|lite_train_whole_infer=50|whole_train_whole_infer=1000 | ||
--save_dir: | ||
--batch_size:lite_train_lite_infer=2|lite_train_whole_infer=2|whole_train_whole_infer=3 | ||
--model_path:null | ||
train_model_name:best_model/model.pdparams | ||
train_infer_img_dir:test_tipc/data/cityscapes/cityscapes_val_5.list | ||
null:null | ||
## | ||
trainer:norm | ||
norm_train:train.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml --save_interval 500 --seed 100 --num_workers 8 | ||
pact_train:null | ||
fpgm_train:null | ||
distill_train:null | ||
null:null | ||
null:null | ||
## | ||
===========================eval_params=========================== | ||
eval:val.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml --num_workers 8 | ||
null:null | ||
## | ||
===========================export_params=========================== | ||
--save_dir: | ||
--model_path: | ||
norm_export:export.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml | ||
quant_export:null | ||
fpgm_export:null | ||
distill_export:null | ||
export1:null | ||
export2:null | ||
===========================infer_params=========================== | ||
infer_model:./test_tipc/output/ccnet/model.pdparams | ||
infer_export:export.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml | ||
infer_quant:False | ||
inference:deploy/python/infer.py | ||
--device:cpu|gpu | ||
--enable_mkldnn:True|False | ||
--cpu_threads:1|6 | ||
--batch_size:1 | ||
--use_trt:False | ||
--precision:fp32|int8|fp16 | ||
--config: | ||
--image_path:./test_tipc/data/cityscapes/cityscapes_val_5.list | ||
--save_log_path:null | ||
--benchmark:True | ||
--save_dir: | ||
--model_name:ccnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters