Commit

MobileVitAttention
xmu-xiaoma666 committed Oct 9, 2021
1 parent 17a8c3d commit e097f3b
Showing 1 changed file with 41 additions and 3 deletions.
44 changes: 41 additions & 3 deletions fightingcv/attention/MobileViTAttention.py
@@ -1,9 +1,47 @@
from torch import nn
import torch
from einops import rearrange


class PreNorm(nn.Module):
    # Applies LayerNorm to the input before calling the wrapped module (attention or feed-forward)
    def __init__(self,dim,fn):
        super().__init__()
        self.ln=nn.LayerNorm(dim)
        self.fn=fn
    def forward(self,x,**kwargs):
        return self.fn(self.ln(x),**kwargs)

class FeedForward(nn.Module):
    # Two-layer MLP (dim -> mlp_dim -> dim) with ReLU and dropout
    def __init__(self,dim,mlp_dim,dropout):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(dim,mlp_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim,dim),
            nn.Dropout(dropout)
        )
    def forward(self,x):
        return self.net(x)

class Transformer(nn.Module):
    # depth x (pre-norm self-attention + pre-norm feed-forward), each with a residual connection
    def __init__(self,dim,depth,heads,head_dim,mlp_dim,dropout=0.):
        super().__init__()
        self.layers=nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim,Attention(dim,heads,head_dim,dropout)),
                PreNorm(dim,FeedForward(dim,mlp_dim,dropout))
            ]))

    def forward(self,x):
        out=x
        for att,ffn in self.layers:
            out=out+att(out)
            out=out+ffn(out)
        return out
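# Note: the Attention module instantiated above is defined in a portion of
# MobileViTAttention.py that this diff does not display; only its call signature
# Attention(dim,heads,head_dim,dropout) is visible here. A hedged sketch of a
# compatible multi-head self-attention over the (bs, ph*pw, nh*nw, dim) tokens
# produced in forward() below (an assumption, not the repository's actual code):
#
#   class Attention(nn.Module):
#       def __init__(self,dim,heads,head_dim,dropout):
#           super().__init__()
#           inner_dim=heads*head_dim
#           self.heads=heads
#           self.scale=head_dim**-0.5
#           self.to_qkv=nn.Linear(dim,inner_dim*3,bias=False)
#           self.to_out=nn.Sequential(nn.Linear(inner_dim,dim),nn.Dropout(dropout))
#       def forward(self,x):  # x: (bs, ph*pw, nh*nw, dim)
#           q,k,v=self.to_qkv(x).chunk(3,dim=-1)
#           q,k,v=[rearrange(t,'b p n (h d) -> b p h n d',h=self.heads) for t in (q,k,v)]
#           attn=((q@k.transpose(-1,-2))*self.scale).softmax(dim=-1)
#           out=rearrange(attn@v,'b p h n d -> b p n (h d)')
#           return self.to_out(out)  # (bs, ph*pw, nh*nw, dim)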

class MobileViTAttention(nn.Module):
    def __init__(self,in_channel=3,dim=512,kernel_size=3,patch_size=7):
        super().__init__()
@@ -25,7 +63,7 @@ def forward(self,x):
        ## Global Representation
        _,_,h,w=y.shape
        y=rearrange(y,'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim',ph=self.ph,pw=self.pw) # bs,ph*pw,nh*nw,dim
-        y=self.transformer(y)
+        y=self.trans(y)
        y=rearrange(y,'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)',ph=self.ph,pw=self.pw,nh=h//self.ph,nw=w//self.pw) # bs,dim,h,w

        ## Fusion
        # (remainder of forward collapsed in this diff view)
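
For a quick check of the module touched by this commit, a minimal usage sketch (an assumption, not code from this diff: the import path mirrors the file location, and the printed shape assumes the collapsed fusion step returns a tensor with the input's spatial size and channel count):

import torch
from fightingcv.attention.MobileViTAttention import MobileViTAttention

if __name__ == '__main__':
    m=MobileViTAttention(in_channel=3,dim=512,kernel_size=3,patch_size=7)  # defaults from the diff above
    x=torch.randn(1,3,49,49)  # H and W must be divisible by patch_size=7 so patches fold evenly
    y=m(x)                    # local conv -> unfold -> self.trans (Transformer) -> fold -> fusion
    print(y.shape)            # expected torch.Size([1, 3, 49, 49]) under the assumptions above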
