Skip to content

Commit

Permalink
[Feature] Add CCNet (PaddlePaddle#2005)
Browse files Browse the repository at this point in the history
  • Loading branch information
justld authored Apr 15, 2022
1 parent 6c6dcf9 commit 21be150
Show file tree
Hide file tree
Showing 11 changed files with 324 additions and 2 deletions.
13 changes: 13 additions & 0 deletions configs/ccnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# CCNet: Criss-cross attention for semantic segmentation

## Reference

> Zilong Huang, Xinggang Wang, Yunchao Wei, Lichao Huang, Humphrey Shi, Wenyu Liu, Thomas S. Huang. "CCNet: Criss-cross attention for semantic segmentation." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.
## Performance

### Cityscapes

| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
|-|-|-|-|-|-|-|-|
|CCNet|ResNet101_OS8|769x769|60000|80.95%|81.23%|81.32%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/ccnet_resnet101_os8_cityscapes_769x769_60k/model.pdparams)\|[log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/ccnet_resnet101_os8_cityscapes_769x769_60k/train.log)\|[vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=6828616e27a1e15f1442beb3b4834048)|
27 changes: 27 additions & 0 deletions configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_base_: '../_base_/cityscapes_769x769.yml'

batch_size: 2
iters: 60000

model:
type: CCNet
backbone:
type: ResNet101_vd
output_stride: 8
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz
backbone_indices: [2, 3]
enable_auxiliary_loss: True
dropout_prob: 0.1
recurrence: 2

loss:
types:
- type: OhemCrossEntropyLoss
- type: CrossEntropyLoss
coef: [1, 0.4]

lr_scheduler:
type: PolynomialDecay
learning_rate: 0.01
power: 0.9
end_lr: 1.0e-4
1 change: 1 addition & 0 deletions paddleseg/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@
from .pfpnnet import PFPNNet
from .glore import GloRe
from .ddrnet import DDRNet_23
from .ccnet import CCNet
174 changes: 174 additions & 0 deletions paddleseg/models/ccnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager
from paddleseg.models import layers
from paddleseg.utils import utils


@manager.MODELS.add_component
class CCNet(nn.Layer):
"""
The CCNet implementation based on PaddlePaddle.
The original article refers to
Zilong Huang, et al. "CCNet: Criss-Cross Attention for Semantic Segmentation"
(https://arxiv.org/abs/1811.11721)
Args:
num_classes (int): The unique number of target classes.
backbone (paddle.nn.Layer): Backbone network, currently support Resnet18_vd/Resnet34_vd/Resnet50_vd/Resnet101_vd.
backbone_indices (tuple, list, optional): Two values in the tuple indicate the indices of output of backbone. Default: (2, 3).
enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
dropout_prob (float, optional): The probability of dropout. Default: 0.0.
recurrence (int, optional): The number of recurrent operations. Defautl: 1.
align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
pretrained (str, optional): The path or url of pretrained model. Default: None.
"""

def __init__(self,
num_classes,
backbone,
backbone_indices=(2, 3),
enable_auxiliary_loss=True,
dropout_prob=0.0,
recurrence=1,
align_corners=False,
pretrained=None):
super().__init__()
self.enable_auxiliary_loss = enable_auxiliary_loss
self.recurrence = recurrence
self.align_corners = align_corners

self.backbone = backbone
self.backbone_indices = backbone_indices
backbone_channels = [
backbone.feat_channels[i] for i in backbone_indices
]

if enable_auxiliary_loss:
self.aux_head = layers.AuxLayer(
backbone_channels[0],
512,
num_classes,
dropout_prob=dropout_prob)
self.head = RCCAModule(
backbone_channels[1],
512,
num_classes,
dropout_prob=dropout_prob,
recurrence=recurrence)
self.pretrained = pretrained

def init_weight(self):
if self.pretrained is not None:
utils.load_entire_model(self, self.pretrained)

def forward(self, x):
feat_list = self.backbone(x)
logit_list = []
output = self.head(feat_list[self.backbone_indices[-1]])
logit_list.append(output)
if self.training and self.enable_auxiliary_loss:
aux_out = self.aux_head(feat_list[self.backbone_indices[-2]])
logit_list.append(aux_out)
return [
F.interpolate(
logit,
paddle.shape(x)[2:],
mode='bilinear',
align_corners=self.align_corners) for logit in logit_list
]


class RCCAModule(nn.Layer):
def __init__(self,
in_channels,
out_channels,
num_classes,
dropout_prob=0.1,
recurrence=1):
super().__init__()
inter_channels = in_channels // 4
self.recurrence = recurrence
self.conva = layers.ConvBNLeakyReLU(
in_channels, inter_channels, 3, padding=1, bias_attr=False)
self.cca = CrissCrossAttention(inter_channels)
self.convb = layers.ConvBNLeakyReLU(
inter_channels, inter_channels, 3, padding=1, bias_attr=False)
self.out = layers.AuxLayer(
in_channels + inter_channels,
out_channels,
num_classes,
dropout_prob=dropout_prob)

def forward(self, x):
feat = self.conva(x)
for i in range(self.recurrence):
feat = self.cca(feat)
feat = self.convb(feat)
output = self.out(paddle.concat([x, feat], axis=1))
return output


class CrissCrossAttention(nn.Layer):
def __init__(self, in_channels):
super().__init__()
self.q_conv = nn.Conv2D(in_channels, in_channels // 8, kernel_size=1)
self.k_conv = nn.Conv2D(in_channels, in_channels // 8, kernel_size=1)
self.v_conv = nn.Conv2D(in_channels, in_channels, kernel_size=1)
self.softmax = nn.Softmax(axis=3)
self.gamma = self.create_parameter(
shape=(1, ), default_initializer=nn.initializer.Constant(0))
self.inf_tensor = paddle.full(shape=(1, ), fill_value=float('inf'))

def forward(self, x):
b, c, h, w = paddle.shape(x)
proj_q = self.q_conv(x)
proj_q_h = proj_q.transpose([0, 3, 1, 2]).reshape(
[b * w, -1, h]).transpose([0, 2, 1])
proj_q_w = proj_q.transpose([0, 2, 1, 3]).reshape(
[b * h, -1, w]).transpose([0, 2, 1])

proj_k = self.k_conv(x)
proj_k_h = proj_k.transpose([0, 3, 1, 2]).reshape([b * w, -1, h])
proj_k_w = proj_k.transpose([0, 2, 1, 3]).reshape([b * h, -1, w])

proj_v = self.v_conv(x)
proj_v_h = proj_v.transpose([0, 3, 1, 2]).reshape([b * w, -1, h])
proj_v_w = proj_v.transpose([0, 2, 1, 3]).reshape([b * h, -1, w])

energy_h = (paddle.bmm(proj_q_h, proj_k_h) + self.Inf(b, h, w)).reshape(
[b, w, h, h]).transpose([0, 2, 1, 3])
energy_w = paddle.bmm(proj_q_w, proj_k_w).reshape([b, h, w, w])
concate = self.softmax(paddle.concat([energy_h, energy_w], axis=3))

attn_h = concate[:, :, :, 0:h].transpose([0, 2, 1, 3]).reshape(
[b * w, h, h])
attn_w = concate[:, :, :, h:h + w].reshape([b * h, w, w])
out_h = paddle.bmm(proj_v_h, attn_h.transpose([0, 2, 1])).reshape(
[b, w, -1, h]).transpose([0, 2, 3, 1])
out_w = paddle.bmm(proj_v_w, attn_w.transpose([0, 2, 1])).reshape(
[b, h, -1, w]).transpose([0, 2, 1, 3])
return self.gamma * (out_h + out_w) + x

def Inf(self, B, H, W):
return -paddle.tile(
paddle.diag(paddle.tile(self.inf_tensor, [H]), 0).unsqueeze(0),
[B * W, 1, 1])
2 changes: 1 addition & 1 deletion paddleseg/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .layer_libs import ConvBNReLU, ConvBN, SeparableConvBNReLU, DepthwiseConvBN, AuxLayer, SyncBatchNorm, JPU, ConvBNPReLU, ConvBNAct
from .layer_libs import ConvBNReLU, ConvBN, SeparableConvBNReLU, DepthwiseConvBN, AuxLayer, SyncBatchNorm, JPU, ConvBNPReLU, ConvBNAct, ConvBNLeakyReLU
from .activation import Activation
from .pyramid_pool import ASPPModule, PPModule
from .attention import AttentionBlock
Expand Down
26 changes: 26 additions & 0 deletions paddleseg/models/layers/layer_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,29 @@ def forward(self, x):
x = self._batch_norm(x)
x = self._prelu(x)
return x


class ConvBNLeakyReLU(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super().__init__()

self._conv = nn.Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)

if 'data_format' in kwargs:
data_format = kwargs['data_format']
else:
data_format = 'NCHW'
self._batch_norm = SyncBatchNorm(out_channels, data_format=data_format)
self._relu = layers.Activation("leakyrelu")

def forward(self, x):
x = self._conv(x)
x = self._batch_norm(x)
x = self._relu(x)
return x
1 change: 1 addition & 0 deletions test_tipc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
| ENet | ENet | 支持 | - | - | - |
| FastSCNN | FastSCNN | 支持 | - | - | - |
| DDRNet | DDRNet_23 | 支持 | - | - | - |
| CCNet | CCNet | 支持 | - | - | - |


## 3. 测试工具简介
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_base_: '../_base_/cityscapes_769x769.yml'

batch_size: 2
iters: 60000

model:
type: CCNet
backbone:
type: ResNet101_vd
output_stride: 8
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz
backbone_indices: [2, 3]
enable_auxiliary_loss: True
dropout_prob: 0.1
recurrence: 2

loss:
types:
- type: OhemCrossEntropyLoss
- type: CrossEntropyLoss
coef: [1, 0.4]

lr_scheduler:
type: PolynomialDecay
learning_rate: 0.01
power: 0.9
end_lr: 1.0e-4
52 changes: 52 additions & 0 deletions test_tipc/configs/ccnet/train_infer_python.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
===========================train_params===========================
model_name:ccnet
python:python3
gpu_list:0|0,1
Global.use_gpu:null|null
--precision:null
--iters:lite_train_lite_infer=50|lite_train_whole_infer=50|whole_train_whole_infer=1000
--save_dir:
--batch_size:lite_train_lite_infer=2|lite_train_whole_infer=2|whole_train_whole_infer=3
--model_path:null
train_model_name:best_model/model.pdparams
train_infer_img_dir:test_tipc/data/cityscapes/cityscapes_val_5.list
null:null
##
trainer:norm
norm_train:train.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml --save_interval 500 --seed 100 --num_workers 8
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:val.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml --num_workers 8
null:null
##
===========================export_params===========================
--save_dir:
--model_path:
norm_export:export.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:./test_tipc/output/ccnet/model.pdparams
infer_export:export.py --config test_tipc/configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml
infer_quant:False
inference:deploy/python/infer.py
--device:cpu|gpu
--enable_mkldnn:True|False
--cpu_threads:1|6
--batch_size:1
--use_trt:False
--precision:fp32|int8|fp16
--config:
--image_path:./test_tipc/data/cityscapes/cityscapes_val_5.list
--save_log_path:null
--benchmark:True
--save_dir:
--model_name:ccnet
1 change: 1 addition & 0 deletions test_tipc/docs/test_train_inference_python.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Linux端基础训练预测功能测试的主程序为`test_train_inference_pytho
| ENet | ENet | 正常训练 | 正常训练 | | |
| FastSCNN | FastSCNN | 正常训练 | 正常训练 | | |
| DDRNet | DDRNet_23 | 正常训练 | 正常训练 | | |
| CCNet | CCNet | 正常训练 | 正常训练 | | |


- 预测相关:基于训练是否使用量化,可以将训练产出的模型可以分为`正常模型``量化模型`,这两类模型对应的预测功能汇总如下,
Expand Down
2 changes: 1 addition & 1 deletion test_tipc/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ else
fi

models=("enet" "bisenetv2" "ocrnet_hrnetw18" "ocrnet_hrnetw48" "deeplabv3p_resnet50_cityscapes" \
"fastscnn" "fcn_hrnetw18" "pp_liteseg_stdc1" "pp_liteseg_stdc2")
"fastscnn" "fcn_hrnetw18" "pp_liteseg_stdc1" "pp_liteseg_stdc2" "ccnet")
if [ $(contains "${models[@]}" "${model_name}") == "y" ]; then
cp ./test_tipc/data/cityscapes_val_5.list ./test_tipc/data/cityscapes
fi

0 comments on commit 21be150

Please sign in to comment.