学习和复现ResNet
原始论文:https://arxiv.org/pdf/1512.03385.pdf
改进论文:https://arxiv.org/pdf/1603.05027.pdf
代码实现部分参考https://blog.csdn.net/frighting_ing/article/details/121324000
一个残差构成块有两条路径 F ( x )和 x,F ( x ) 路径被称为residual mapping,x 路径被称为identity mapping或者 shortcut,⨁ 表示相加,要求 F ( x ) 与 x 的尺寸相同
对于一个神经网络的结构块,假设想要模拟的函数是H(x),理想情况下希望输入x,输出H(x),引入F(x)=H(x)-x,即F(x)+x=H(x),而上图中可见,在网络块中引入了一条支路直接把输入网络块前的x块模拟后的结果相加,正旨在让网络通过模拟F(x)+x来拟合H(x) 而不是直接模拟H(x),在极端情况下,网络块模拟的F(x)为0,至少也是个恒等映射,网络性能不会变差,网络深度得以继续变深。论文还提到,除了保证不变差的情况下,这种结构能够更好的拟合最终函数,举个例子:$H(5)=5.1=F(5)+5$,则$F(5)=0.1$,假设改变对5的映射使输出变化为5.2,则F 需要将映射的输出增加100%,这需要对权重更大幅度的改变(相对于使用传统结构直接拟合H(5)的1-5.2/5.1*100%),可见新的结构对权重调整较大。
用数学公式描述残差块
假设残差块输入x,输出y,有
其中$F(x,{W_i})$表示残差块想要拟合的函数(residual mapping to be learned),比如上文中Figure2里面残差块有两层,则$F=W_2\sigma(W_1x)$,其中$\sigma$表示ReLu,接下来为了确保维数相同,可以给让x通过1x1卷积层,这时公式如下
注意:如果残留块部分(residual mapping)只有一层,公式退化为
反向传播求梯度时可以发现(下图),对恒等映射求x偏导直接为1,而对另一函数求则结果不可能为-1,这样避免了梯度消失
另外,初始论文中提到了残差构建块的两种结构,有bottleneck和无bottleneck,bottleneck结构为右图使用1x1卷积核降维再升维,如同张量流入“瓶颈”
在后续一篇论文中,对于identity mapping有更深入的比较和研究
https://arxiv.org/pdf/1603.05027.pdf
主要是将shortcut的x从原分不动加入residual mapping后的结果,改为运用函数映射并分为多种(如下图)并讨论,实验,比如让h(x)的映射不再是恒等,比如成为$x_{l+1}=\lambda_lx_l+F(x_l,W_l)$,
另外,可以用下图直观的感受以下short cut对网络梯度下降的作用,网络越深,若容易落入局部最优,但short cut让error surface更加平滑,从而更容易到达全局最优
error surface图
且不使用dropout
论文提到,在每次卷积后,激活前都是用BN。
BN让数据满足均值为0,方差为1的分布
论文提到,在vgg的结构中加入shortcut并加深网络深度,结构和详细结构如下
如果取上图中18layer参数,则网络结构如下
可见resnet有很多重复结构,如果按照之前AlexNet,VGG的复现方式,可以一行一行慢慢构建,比较简单,但是我参考了网上流传的博客,都是运用循环结构构建网络,使用一个list,比如layers=[],不断往里放入nn.Module类对象,然后直接转换为nn.Sequential:nn.Sequential(*layers)
,就可以按照nn.Sequential使用了
本代码产生的ResNet对象结构(18 layers)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 36,864
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
BasicBlock-11 [-1, 64, 56, 56] 0
Conv2d-12 [-1, 64, 56, 56] 36,864
BatchNorm2d-13 [-1, 64, 56, 56] 128
ReLU-14 [-1, 64, 56, 56] 0
Conv2d-15 [-1, 64, 56, 56] 36,864
BatchNorm2d-16 [-1, 64, 56, 56] 128
ReLU-17 [-1, 64, 56, 56] 0
BasicBlock-18 [-1, 64, 56, 56] 0
Conv2d-19 [-1, 128, 28, 28] 73,728
BatchNorm2d-20 [-1, 128, 28, 28] 256
ReLU-21 [-1, 128, 28, 28] 0
Conv2d-22 [-1, 128, 28, 28] 147,456
BatchNorm2d-23 [-1, 128, 28, 28] 256
Conv2d-24 [-1, 128, 28, 28] 8,192
BatchNorm2d-25 [-1, 128, 28, 28] 256
ReLU-26 [-1, 128, 28, 28] 0
BasicBlock-27 [-1, 128, 28, 28] 0
Conv2d-28 [-1, 128, 28, 28] 147,456
BatchNorm2d-29 [-1, 128, 28, 28] 256
ReLU-30 [-1, 128, 28, 28] 0
Conv2d-31 [-1, 128, 28, 28] 147,456
BatchNorm2d-32 [-1, 128, 28, 28] 256
ReLU-33 [-1, 128, 28, 28] 0
BasicBlock-34 [-1, 128, 28, 28] 0
Conv2d-35 [-1, 256, 14, 14] 294,912
BatchNorm2d-36 [-1, 256, 14, 14] 512
ReLU-37 [-1, 256, 14, 14] 0
Conv2d-38 [-1, 256, 14, 14] 589,824
BatchNorm2d-39 [-1, 256, 14, 14] 512
Conv2d-40 [-1, 256, 14, 14] 32,768
BatchNorm2d-41 [-1, 256, 14, 14] 512
ReLU-42 [-1, 256, 14, 14] 0
BasicBlock-43 [-1, 256, 14, 14] 0
Conv2d-44 [-1, 256, 14, 14] 589,824
BatchNorm2d-45 [-1, 256, 14, 14] 512
ReLU-46 [-1, 256, 14, 14] 0
Conv2d-47 [-1, 256, 14, 14] 589,824
BatchNorm2d-48 [-1, 256, 14, 14] 512
ReLU-49 [-1, 256, 14, 14] 0
BasicBlock-50 [-1, 256, 14, 14] 0
Conv2d-51 [-1, 512, 7, 7] 1,179,648
BatchNorm2d-52 [-1, 512, 7, 7] 1,024
ReLU-53 [-1, 512, 7, 7] 0
Conv2d-54 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-55 [-1, 512, 7, 7] 1,024
Conv2d-56 [-1, 512, 7, 7] 131,072
BatchNorm2d-57 [-1, 512, 7, 7] 1,024
ReLU-58 [-1, 512, 7, 7] 0
BasicBlock-59 [-1, 512, 7, 7] 0
Conv2d-60 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-61 [-1, 512, 7, 7] 1,024
ReLU-62 [-1, 512, 7, 7] 0
Conv2d-63 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-64 [-1, 512, 7, 7] 1,024
ReLU-65 [-1, 512, 7, 7] 0
BasicBlock-66 [-1, 512, 7, 7] 0
AvgPool2d-67 [-1, 512, 1, 1] 0
Linear-68 [-1, 2] 1,026
================================================================
Total params: 11,177,538
Trainable params: 11,177,538
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 42.64
Estimated Total Size (MB): 106.00
----------------------------------------------------------------