diff --git a/README.md b/README.md
index 2cf47bb..884a728 100644
--- a/README.md
+++ b/README.md
@@ -133,6 +133,8 @@ $ pip install fightingcv
 
     - [2. ResNeXt Usage](#2-ResNeXt-Usage)
 
+    - [3. MobileViT Usage](#3-MobileViT-Usage)
+
 - [MLP Series](#mlp-series)
 
     - [1. RepMLP Usage](#1-RepMLP-Usage)
 
@@ -984,6 +986,9 @@ if __name__ == '__main__':
 
 - Pytorch implementation of ["Aggregated Residual Transformations for Deep Neural Networks---CVPR2017"](https://arxiv.org/abs/1611.05431v2)
 
+- Pytorch implementation of ["MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05"](https://arxiv.org/abs/2103.02907)
+
+
 ### 1. ResNet Usage
 #### 1.1. Paper
 ["Deep Residual Learning for Image Recognition---CVPR2016 Best Paper"](https://arxiv.org/pdf/1512.03385.pdf)
@@ -1031,11 +1036,47 @@ if __name__ == '__main__':
     print(out.shape)
-
 ```
+### 3. MobileViT Usage
+#### 3.1. Paper
+["MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer---ArXiv 2021.10.05"](https://arxiv.org/abs/2103.02907)
+
+#### 3.2. Overview
+![](./fightingcv/img/mobileViT.jpg)
+
+#### 3.3. Usage Code
+```python
+from fightingcv.backbone.MobileViT import *
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+if __name__ == '__main__':
+    # MobileViT is built for 256x256 inputs: the final 8x8 feature map
+    # must be divisible by the 2x2 patches of the attention stages
+    input=torch.randn(1,3,256,256)
+
+    ### mobilevit_xxs
+    mvit_xxs=mobilevit_xxs()
+    out=mvit_xxs(input)
+    print(out.shape)
+
+    ### mobilevit_xs
+    mvit_xs=mobilevit_xs()
+    out=mvit_xs(input)
+    print(out.shape)
+
+    ### mobilevit_s
+    mvit_s=mobilevit_s()
+    out=mvit_s(input)
+    print(out.shape)
+```
+
diff --git a/fightingcv/backbone/MobileViT.py b/fightingcv/backbone/MobileViT.py
new file mode 100644
index 0000000..bca650e
--- /dev/null
+++ b/fightingcv/backbone/MobileViT.py
@@ -0,0 +1,237 @@
+from torch import nn
+import torch
+from einops import rearrange
+
+
+def conv_bn(inp,oup,kernel_size=3,stride=1):
+    return nn.Sequential(
+        nn.Conv2d(inp,oup,kernel_size=kernel_size,stride=stride,padding=kernel_size//2),
+        nn.BatchNorm2d(oup),
+        nn.ReLU()
+    )
+
+class PreNorm(nn.Module):
+    def __init__(self,dim,fn):
+        super().__init__()
+        self.ln=nn.LayerNorm(dim)
+        self.fn=fn
+    def forward(self,x,**kwargs):
+        return self.fn(self.ln(x),**kwargs)
+
+class FeedForward(nn.Module):
+    def __init__(self,dim,mlp_dim,dropout):
+        super().__init__()
+        self.net=nn.Sequential(
+            nn.Linear(dim,mlp_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(mlp_dim,dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self,x):
+        return self.net(x)
+
+class Attention(nn.Module):
+    def __init__(self,dim,heads,head_dim,dropout):
+        super().__init__()
+        inner_dim=heads*head_dim
+        project_out=not(heads==1 and head_dim==dim)
+
+        self.heads=heads
+        self.scale=head_dim**-0.5
+
+        self.attend=nn.Softmax(dim=-1)
+        self.to_qkv=nn.Linear(dim,inner_dim*3,bias=False)
+
+        self.to_out=nn.Sequential(
+            nn.Linear(inner_dim,dim),
+            nn.Dropout(dropout)
+        ) if project_out else nn.Identity()
+
+    def forward(self,x):
+        qkv=self.to_qkv(x).chunk(3,dim=-1)
+        q,k,v=map(lambda t:rearrange(t,'b p n (h d) -> b p h n d',h=self.heads),qkv)
+        dots=torch.matmul(q,k.transpose(-1,-2))*self.scale
+        attn=self.attend(dots)
+        out=torch.matmul(attn,v)
+        out=rearrange(out,'b p h n d -> b p n (h d)')
+        return self.to_out(out)
+
+
+class Transformer(nn.Module):
+    def __init__(self,dim,depth,heads,head_dim,mlp_dim,dropout=0.):
+        super().__init__()
+        self.layers=nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PreNorm(dim,Attention(dim,heads,head_dim,dropout)),
+                PreNorm(dim,FeedForward(dim,mlp_dim,dropout))
+            ]))
+
+    def forward(self,x):
+        out=x
+        for att,ffn in self.layers:
+            out=out+att(out)
+            out=out+ffn(out)
+        return out
+
+class MobileViTAttention(nn.Module):
+    def __init__(self,in_channel=3,dim=512,kernel_size=3,patch_size=7,depth=3,mlp_dim=1024):
+        super().__init__()
+        self.ph,self.pw=patch_size,patch_size
+        self.conv1=nn.Conv2d(in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)
+        self.conv2=nn.Conv2d(in_channel,dim,kernel_size=1)
+
+        self.trans=Transformer(dim=dim,depth=depth,heads=8,head_dim=64,mlp_dim=mlp_dim)
+
+        self.conv3=nn.Conv2d(dim,in_channel,kernel_size=1)
+        self.conv4=nn.Conv2d(2*in_channel,in_channel,kernel_size=kernel_size,padding=kernel_size//2)
+
+    def forward(self,x):
+        y=x.clone() #bs,c,h,w
+
+        ## Local Representation
+        y=self.conv2(self.conv1(x)) #bs,dim,h,w
+
+        ## Global Representation
+        _,_,h,w=y.shape
+        y=rearrange(y,'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim',ph=self.ph,pw=self.pw) #bs,ph*pw,nh*nw,dim
+        y=self.trans(y)
+        y=rearrange(y,'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)',ph=self.ph,pw=self.pw,nh=h//self.ph,nw=w//self.pw) #bs,dim,h,w
+
+        ## Fusion
+        y=self.conv3(y) #bs,c,h,w
+        y=torch.cat([x,y],1) #bs,2*c,h,w
+        y=self.conv4(y) #bs,c,h,w
+
+        return y
+
+
+class MV2Block(nn.Module):
+    def __init__(self,inp,out,stride=1,expansion=4):
+        super().__init__()
+        self.stride=stride
+        hidden_dim=inp*expansion
+        self.use_res_connection=stride==1 and inp==out
+
+        if expansion==1:
+            self.conv=nn.Sequential(
+                nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=self.stride,padding=1,groups=hidden_dim,bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU(),
+                nn.Conv2d(hidden_dim,out,kernel_size=1,stride=1,bias=False),
+                nn.BatchNorm2d(out)
+            )
+        else:
+            self.conv=nn.Sequential(
+                nn.Conv2d(inp,hidden_dim,kernel_size=1,stride=1,bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU(),
+                # the depthwise conv carries the block's stride, so stride-2 blocks actually downsample
+                nn.Conv2d(hidden_dim,hidden_dim,kernel_size=3,stride=self.stride,padding=1,groups=hidden_dim,bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                nn.ReLU(),
+                # linear bottleneck: no activation after the final projection
+                nn.Conv2d(hidden_dim,out,kernel_size=1,stride=1,bias=False),
+                nn.BatchNorm2d(out)
+            )
+    def forward(self,x):
+        if(self.use_res_connection):
+            out=x+self.conv(x)
+        else:
+            out=self.conv(x)
+        return out
+
+class MobileViT(nn.Module):
+    def __init__(self,image_size,dims,channels,num_classes,depths=[2,4,3],expansion=4,kernel_size=3,patch_size=2):
+        super().__init__()
+        ih,iw=image_size,image_size
+        ph,pw=patch_size,patch_size
+        assert iw%pw==0 and ih%ph==0
+
+        self.conv1=conv_bn(3,channels[0],kernel_size=3,stride=patch_size)
+        self.mv2=nn.ModuleList([])
+        self.m_vits=nn.ModuleList([])
+
+        self.mv2.append(MV2Block(channels[0],channels[1],1,expansion))
+        self.mv2.append(MV2Block(channels[1],channels[2],2,expansion))
+        self.mv2.append(MV2Block(channels[2],channels[3],1,expansion))
+        self.mv2.append(MV2Block(channels[2],channels[3],1,expansion)) # x2
+        self.mv2.append(MV2Block(channels[3],channels[4],2,expansion))
+        self.m_vits.append(MobileViTAttention(channels[4],dim=dims[0],kernel_size=kernel_size,patch_size=patch_size,depth=depths[0],mlp_dim=int(2*dims[0])))
+        self.mv2.append(MV2Block(channels[4],channels[5],2,expansion))
+        self.m_vits.append(MobileViTAttention(channels[5],dim=dims[1],kernel_size=kernel_size,patch_size=patch_size,depth=depths[1],mlp_dim=int(4*dims[1])))
+        self.mv2.append(MV2Block(channels[5],channels[6],2,expansion))
+        self.m_vits.append(MobileViTAttention(channels[6],dim=dims[2],kernel_size=kernel_size,patch_size=patch_size,depth=depths[2],mlp_dim=int(4*dims[2])))
+
+        self.conv2=conv_bn(channels[-2],channels[-1],kernel_size=1)
+        self.pool=nn.AvgPool2d(image_size//32,1)
+        self.fc=nn.Linear(channels[-1],num_classes,bias=False)
+
+    def forward(self,x):
+        y=self.conv1(x) #bs,channels[0],128,128 (for a 256x256 input)
+        y=self.mv2[0](y)
+        y=self.mv2[1](y) #bs,channels[2],64,64
+        y=self.mv2[2](y)
+        y=self.mv2[3](y)
+        y=self.mv2[4](y) #bs,channels[4],32,32
+        y=self.m_vits[0](y)
+
+        y=self.mv2[5](y) #bs,channels[5],16,16
+        y=self.m_vits[1](y)
+
+        y=self.mv2[6](y) #bs,channels[6],8,8
+        y=self.m_vits[2](y)
+
+        y=self.conv2(y)
+        y=self.pool(y).view(y.shape[0],-1)
+        y=self.fc(y)
+        return y
+
+def mobilevit_xxs():
+    dims=[60,80,96]
+    channels=[16, 16, 24, 24, 48, 64, 80, 320]
+    return MobileViT(256,dims,channels,num_classes=1000)
+
+def mobilevit_xs():
+    dims = [96, 120, 144]
+    channels = [16, 32, 48, 48, 64, 80, 96, 384]
+    return MobileViT(256, dims, channels, num_classes=1000)
+
+def mobilevit_s():
+    dims = [144, 192, 240]
+    channels = [16, 32, 64, 64, 96, 128, 160, 640]
+    return MobileViT(256, dims, channels, num_classes=1000)
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+if __name__ == '__main__':
+    input=torch.randn(1,3,256,256)
+
+    ### mobilevit_xxs
+    mvit_xxs=mobilevit_xxs()
+    out=mvit_xxs(input)
+    print(out.shape)
+
+    ### mobilevit_xs
+    mvit_xs=mobilevit_xs()
+    out=mvit_xs(input)
+    print(out.shape)
+
+    ### mobilevit_s
+    mvit_s=mobilevit_s()
+    out=mvit_s(input)
+    print(out.shape)
+
\ No newline at end of file
diff --git a/fightingcv/backbone/__pycache__/MobileViT.cpython-38.pyc b/fightingcv/backbone/__pycache__/MobileViT.cpython-38.pyc
new file mode 100644
index 0000000..94866eb
Binary files /dev/null and b/fightingcv/backbone/__pycache__/MobileViT.cpython-38.pyc differ
diff --git a/fightingcv/backbone/resnest.py b/fightingcv/backbone/resnest.py
deleted file mode 100644
index 2de103e..0000000
--- a/fightingcv/backbone/resnest.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import torch
-from torch import nn
-
-
-"""
-    # in_channel: number of channels entering the block
-    # channel: number of channels used inside the block (1/4 of the output dimension)
-    # channel * block.expansion: output dimension
-"""
-class BottleNeck(nn.Module):
-    expansion = 4
-    def __init__(self,in_channel,channel,stride=1,downsample=None,radix=2,
-        cardinality=1, bottleneck_width=64,dropblock_prob=0.0,
-        avd=False,is_first=False,avd_first=False):
-        super().__init__()
-
-        '''
-        cardinality is the ResNeXt cardinality
-        radix is the number of splits within each cardinal group
-        '''
-        group_with=int(channel*(bottleneck_width/64))*cardinality
-        self.dropblock_prob=dropblock_prob
-        self.radix=radix
-        self.avd = avd and (stride > 1 or is_first) # self.avd decides whether downsampling is needed
-        self.avd_first = avd_first
-
-        ### Difference 1: the first conv here has stride 1; spatial downsampling is done with AvgPool
-
-        self.conv1=nn.Conv2d(in_channel,group_with,kernel_size=1,stride=1,bias=False)
-        self.bn1=nn.BatchNorm2d(channel)
-
-        # if self.avd:
-        #     self.avg=
-
-
-
-        self.conv2=nn.Conv2d(channel,channel,kernel_size=3,padding=1,bias=False,stride=1)
-        self.bn2=nn.BatchNorm2d(channel)
-
-        self.conv3=nn.Conv2d(channel,channel*self.expansion,kernel_size=1,stride=1,bias=False)
-        self.bn3=nn.BatchNorm2d(channel*self.expansion)
-
-        self.relu=nn.ReLU(False)
-
-        self.downsample=downsample
-        self.stride=stride
-
-    def forward(self,x):
-        residual=x
-
-        out=self.relu(self.bn1(self.conv1(x))) #bs,c,h,w
-        out=self.relu(self.bn2(self.conv2(out))) #bs,c,h,w
-        out=self.relu(self.bn3(self.conv3(out))) #bs,4c,h,w
-
-        if(self.downsample != None):
-            residual=self.downsample(residual)
-
-        out+=residual
-        return self.relu(out)
-
-
-class ResNeSt(nn.Module):
-    def __init__(self,block,layers,num_classes=1000,radix=2, groups=1, bottleneck_width=64,):
-        super().__init__()
-        # channel dimension entering the first stage
-        self.in_channel=64
-        ### stem layer
-        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
-        self.bn1=nn.BatchNorm2d(64)
-        self.relu=nn.ReLU(False)
-        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=0,ceil_mode=True)
-
-        ### main layer
-        self.layer1=self._make_layer(block,64,layers[0])
-        self.layer2=self._make_layer(block,128,layers[1],stride=2)
-        self.layer3=self._make_layer(block,256,layers[2],stride=2)
-        self.layer4=self._make_layer(block,512,layers[3],stride=2)
-
-        #classifier
-        self.avgpool=nn.AdaptiveAvgPool2d(1)
-        self.classifier=nn.Linear(512*block.expansion,num_classes)
-        self.softmax=nn.Softmax(-1)
-
-    def forward(self,x):
-        ##stem layer
-        out=self.relu(self.bn1(self.conv1(x))) #bs,112,112,64
-        out=self.maxpool(out) #bs,56,56,64
-
-        ##layers:
-        out=self.layer1(out) #bs,56,56,64*4
-        out=self.layer2(out) #bs,28,28,128*4
-        out=self.layer3(out) #bs,14,14,256*4
-        out=self.layer4(out) #bs,7,7,512*4
-
-        ##classifier
-        out=self.avgpool(out) #bs,1,1,512*4
-        out=out.reshape(out.shape[0],-1) #bs,512*4
-        out=self.classifier(out) #bs,1000
-        out=self.softmax(out)
-
-        return out
-
-
-
-    def _make_layer(self,block,channel,blocks,stride=1):
-        # downsample handles the channel mismatch between F(x) and x in H(x)=F(x)+x: it projects the residual input so that its width, height and depth match the output before the residual addition
-        # e.g. when stride!=1 or in_channel!=channel*block.expansion
-        downsample = None
-        if(stride!=1 or self.in_channel!=channel*block.expansion):
-            self.downsample=nn.Conv2d(self.in_channel,channel*block.expansion,stride=stride,kernel_size=1,bias=False)
-        # the first block of the stage may need the downsample
-        layers=[]
-        layers.append(block(self.in_channel,channel,downsample=self.downsample,stride=stride))
-        self.in_channel=channel*block.expansion
-        for _ in range(1,blocks):
-            layers.append(block(self.in_channel,channel))
-        return nn.Sequential(*layers)
-
-
-def ResNeSt50(num_classes=1000):
-    return ResNeSt(BottleNeck,[3,4,6,3],num_classes=num_classes,radix=2, groups=1, bottleneck_width=64,)
-
-
-def ResNeSt101(num_classes=1000):
-    return ResNeSt(BottleNeck,[3,4,23,3],num_classes=num_classes,radix=2, groups=1, bottleneck_width=64,)
-
-
-def ResNeSt152(num_classes=1000):
-    return ResNeSt(BottleNeck,[3,8,36,3],num_classes=num_classes,radix=2, groups=1, bottleneck_width=64,)
-
-
-if __name__ == '__main__':
-    input=torch.randn(50,3,224,224)
-    resnest50=ResNeSt50(1000)
-    # resnest101=ResNeSt101(1000)
-    # resnest152=ResNeSt152(1000)
-    out=resnest50(input)
-    print(out.shape)
-
\ No newline at end of file
diff --git a/fightingcv/img/mobileViT.jpg b/fightingcv/img/mobileViT.jpg
new file mode 100644
index 0000000..139d134
Binary files /dev/null and b/fightingcv/img/mobileViT.jpg differ
diff --git a/main.py b/main.py
index b97f57a..dac1fe2 100644
--- a/main.py
+++ b/main.py
@@ -1,13 +1,27 @@
-from fightingcv.attention.MobileViTAttention import MobileViTAttention
+from fightingcv.backbone.MobileViT import *
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 if __name__ == '__main__':
-    m=MobileViTAttention()
-    input=torch.randn(1,3,49,49)
-    output=m(input)
-    print(output.shape)
+    # MobileViT expects 256x256 inputs
+    input=torch.randn(1,3,256,256)
+
+    ### mobilevit_xxs
+    mvit_xxs=mobilevit_xxs()
+    out=mvit_xxs(input)
+    print(out.shape)
+
+    ### mobilevit_xs
+    mvit_xs=mobilevit_xs()
+    out=mvit_xs(input)
+    print(out.shape)
+
+    ### mobilevit_s
+    mvit_s=mobilevit_s()
+    out=mvit_s(input)
+    print(out.shape)
+
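The least obvious step in the `MobileViTAttention` added above is the unfold/fold pair of `rearrange` calls: pixels at the same offset inside each patch are grouped together, so the Transformer attends across patches rather than within them. Below is a minimal, self-contained sketch of that round trip (only `torch` and `einops` required; the toy sizes are illustrative, not the model's real ones):

```python
import torch
from einops import rearrange

# Toy sizes: a dim=8 feature map on a 4x4 grid, split into 2x2 patches.
bs, dim, h, w = 1, 8, 4, 4
ph = pw = 2

y = torch.randn(bs, dim, h, w)

# Unfold: each of the ph*pw "position" slots collects one token per patch,
# giving nh*nw tokens of size dim for the Transformer to attend over.
t = rearrange(y, 'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim', ph=ph, pw=pw)
print(t.shape)  # torch.Size([1, 4, 4, 8])

# Fold: the exact inverse, recovering the original feature map layout.
back = rearrange(t, 'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)',
                 ph=ph, pw=pw, nh=h // ph, nw=w // pw)
assert torch.equal(y, back)  # the unfold/fold pair is a lossless permutation
```

This is also why the spatial size entering each attention stage must be divisible by the patch size: with 256x256 inputs the last stage runs at 8x8 with 2x2 patches, whereas a 224x224 input would reach that stage at 7x7 and the first `rearrange` would fail.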