add separable conv
Summary: Pull Request resolved: fairinternal/detectron2#471

Reviewed By: theschnitz

Differential Revision: D24602311

Pulled By: ppwwyyxx

fbshipit-source-id: fa8b59bbbf11a8a0f8f47a97d756f353081157ea
ppwwyyxx authored and facebook-github-bot committed Oct 29, 2020
1 parent 406f5a8 commit 29e85f9
Showing 6 changed files with 94 additions and 7 deletions.
2 changes: 1 addition & 1 deletion detectron2/layers/__init__.py
@@ -7,7 +7,7 @@
 from .roi_align_rotated import ROIAlignRotated, roi_align_rotated
 from .shape_spec import ShapeSpec
 from .wrappers import BatchNorm2d, Conv2d, ConvTranspose2d, cat, interpolate, Linear, nonzero_tuple
-from .blocks import CNNBlockBase
+from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
 from .aspp import ASPP
 
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
12 changes: 7 additions & 5 deletions detectron2/layers/aspp.py
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
+from copy import deepcopy
 import fvcore.nn.weight_init as weight_init
 import torch
 from torch import nn
@@ -19,6 +20,7 @@ def __init__(
         in_channels,
         out_channels,
         dilations,
+        *,
         norm,
         activation,
         pool_kernel_size=None,
@@ -60,7 +62,7 @@ def __init__(
                 kernel_size=1,
                 bias=use_bias,
                 norm=get_norm(norm, out_channels),
-                activation=activation,
+                activation=deepcopy(activation),
             )
         )
         weight_init.c2_xavier_fill(self.convs[-1])
@@ -75,7 +77,7 @@ def __init__(
                     dilation=dilation,
                     bias=use_bias,
                     norm=get_norm(norm, out_channels),
-                    activation=activation,
+                    activation=deepcopy(activation),
                 )
             )
             weight_init.c2_xavier_fill(self.convs[-1])
@@ -85,12 +87,12 @@ def __init__(
         if pool_kernel_size is None:
             image_pooling = nn.Sequential(
                 nn.AdaptiveAvgPool2d(1),
-                Conv2d(in_channels, out_channels, 1, bias=True, activation=activation),
+                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
             )
         else:
             image_pooling = nn.Sequential(
                 nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1),
-                Conv2d(in_channels, out_channels, 1, bias=True, activation=activation),
+                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
             )
         weight_init.c2_xavier_fill(image_pooling[1])
         self.convs.append(image_pooling)
@@ -101,7 +103,7 @@ def __init__(
             kernel_size=1,
             bias=use_bias,
             norm=get_norm(norm, out_channels),
-            activation=activation,
+            activation=deepcopy(activation),
         )
         weight_init.c2_xavier_fill(self.project)
 
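The deepcopy(activation) calls are the substance of this file's change: an activation passed in as a module can carry learnable parameters (nn.PReLU, for instance, owns a weight tensor), so reusing one instance across every ASPP branch would silently tie those parameters together. A minimal sketch of the aliasing being avoided, with names local to this example:

from copy import deepcopy
from torch import nn

act = nn.PReLU()  # PReLU carries a learnable `weight` parameter
shared = [nn.Sequential(act), nn.Sequential(act)]                        # one instance in two branches
separate = [nn.Sequential(deepcopy(act)), nn.Sequential(deepcopy(act))]  # a copy per branch

assert shared[0][0].weight is shared[1][0].weight          # aliased: one tensor updated by both branches
assert separate[0][0].weight is not separate[1][0].weight  # deepcopy gives each branch its own parameter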
2 changes: 2 additions & 0 deletions detectron2/layers/batch_norm.py
@@ -134,6 +136,8 @@ def get_norm(norm, out_channels):
     Returns:
         nn.Module or None: the normalization layer
     """
+    if norm is None:
+        return None
     if isinstance(norm, str):
         if len(norm) == 0:
             return None
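This guard makes norm=None behave like the empty string, which is what lets DepthwiseSeparableConv2d below default norm1/norm2 to None. A quick sketch of the resulting behavior (the channel count is illustrative):

from detectron2.layers.batch_norm import get_norm

assert get_norm(None, 32) is None  # new: None now means "no normalization"
assert get_norm("", 32) is None    # the empty string already meant the same
bn = get_norm("BN", 32)            # named norms still return a module (here BatchNorm2d)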
63 changes: 62 additions & 1 deletion detectron2/layers/blocks.py
@@ -1,9 +1,16 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
+import fvcore.nn.weight_init as weight_init
 from torch import nn
 
-from .batch_norm import FrozenBatchNorm2d
+from .batch_norm import FrozenBatchNorm2d, get_norm
+from .wrappers import Conv2d
+
+
+"""
+CNN building blocks.
+"""
 
 
 class CNNBlockBase(nn.Module):
@@ -46,3 +53,57 @@ def freeze(self):
             p.requires_grad = False
         FrozenBatchNorm2d.convert_frozen_batchnorm(self)
         return self
+
+
+class DepthwiseSeparableConv2d(nn.Module):
+    """
+    A kxk depthwise convolution + a 1x1 convolution.
+    In :paper:`xception`, norm & activation are applied on the second conv.
+    :paper:`mobilenet` uses norm & activation on both convs.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=3,
+        padding=1,
+        *,
+        norm1=None,
+        activation1=None,
+        norm2=None,
+        activation2=None,
+    ):
+        """
+        Args:
+            norm1, norm2 (str or callable): normalization for the two conv layers.
+            activation1, activation2 (callable(Tensor) -> Tensor): activation
+                function for the two conv layers.
+        """
+        super().__init__()
+        self.depthwise = Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=in_channels,
+            bias=not norm1,
+            norm=get_norm(norm1, in_channels),
+            activation=activation1,
+        )
+        self.pointwise = Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            bias=not norm2,
+            norm=get_norm(norm2, out_channels),
+            activation=activation2,
+        )
+
+        # default initialization
+        weight_init.c2_msra_fill(self.depthwise)
+        weight_init.c2_msra_fill(self.pointwise)
+
+    def forward(self, x):
+        return self.pointwise(self.depthwise(x))
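For orientation, a small usage sketch of the new block; the channel sizes and input shape are illustrative, not part of the commit. The first instance follows the Xception-style default described in the docstring (norm & activation only through the second conv), the second the MobileNet style (both convs):

import torch
from torch import nn
from detectron2.layers import DepthwiseSeparableConv2d

# Xception-style: raw depthwise conv, norm/activation on the pointwise conv
xception_block = DepthwiseSeparableConv2d(
    32, 64, kernel_size=3, padding=1, norm2="BN", activation2=nn.ReLU()
)

# MobileNet-style: norm/activation on both convs
mobilenet_block = DepthwiseSeparableConv2d(
    32, 64, norm1="BN", activation1=nn.ReLU(), norm2="BN", activation2=nn.ReLU()
)

x = torch.randn(2, 32, 56, 56)
assert xception_block(x).shape == (2, 64, 56, 56)   # kernel_size=3, padding=1 preserves H, W
assert mobilenet_block(x).shape == (2, 64, 56, 56)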
5 changes: 5 additions & 0 deletions docs/conf.py
@@ -301,6 +301,11 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
     "lvis": ("1908.03195", "LVIS: A Dataset for Large Vocabulary Instance Segmentation"),
     "rrpn": ("1703.01086", "Arbitrary-Oriented Scene Text Detection via Rotation Proposals"),
     "imagenet in 1h": ("1706.02677", "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"),
+    "xception": ("1610.02357", "Xception: Deep Learning with Depthwise Separable Convolutions"),
+    "mobilenet": (
+        "1704.04861",
+        "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications",
+    ),
 }
 
 
17 changes: 17 additions & 0 deletions tests/layers/test_blocks.py
@@ -0,0 +1,17 @@
+# -*- coding: utf-8 -*-
+
+
+import unittest
+from torch import nn
+
+from detectron2.layers import ASPP, DepthwiseSeparableConv2d
+
+
+class TestBlocks(unittest.TestCase):
+    def test_separable_conv(self):
+        DepthwiseSeparableConv2d(3, 10, norm1="BN", activation1=nn.PReLU())
+
+    def test_aspp(self):
+        m = ASPP(3, 10, [2, 3, 4], norm="", activation=nn.PReLU())
+        self.assertIsNot(m.convs[0].activation.weight, m.convs[1].activation.weight)
+        self.assertIsNot(m.convs[0].activation.weight, m.project.activation.weight)

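The test_aspp assertions pin down the deepcopy fix above: had ASPP reused one PReLU instance, convs[0].activation.weight and convs[1].activation.weight would be the very same tensor. The file should work under a standard runner, e.g. (invocation assumed, not part of the commit):

python -m pytest tests/layers/test_blocks.py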