forked from xmu-xiaoma666/External-Attention-pytorch

g_mlp.py
from collections import OrderedDict

import torch
from torch import nn


def exist(x):
    return x is not None


class Residual(nn.Module):
    # Wraps a module with a skip connection: y = fn(x) + x
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class SpatialGatingUnit(nn.Module):
    # Splits the hidden features in half, layer-norms the gate half, mixes it across
    # the sequence dimension with a pointwise Conv1d, and multiplies it with the other half.
    def __init__(self, d_ff, len_sen):
        super().__init__()
        self.ln = nn.LayerNorm(d_ff)
        self.proj = nn.Conv1d(len_sen, len_sen, 1)
        # Near-identity init (weights = 0, bias = 1) so the gate starts close to 1
        # and each block behaves like a plain FFN early in training.
        nn.init.zeros_(self.proj.weight)
        nn.init.ones_(self.proj.bias)

    def forward(self, x):
        res, gate = torch.chunk(x, 2, dim=-1)  # each: bs, n, d_ff
        ### Norm
        gate = self.ln(gate)    # bs, n, d_ff
        ### Spatial projection (mixes information across the n positions)
        gate = self.proj(gate)  # bs, n, d_ff
        return res * gate


class gMLP(nn.Module):
    def __init__(self, num_tokens=None, len_sen=49, dim=512, d_ff=1024, num_layers=6):
        super().__init__()
        self.num_layers = num_layers
        # Token embedding when a vocabulary size is given; otherwise the input is
        # assumed to already be a (bs, len_sen, dim) feature tensor.
        self.embedding = nn.Embedding(num_tokens, dim) if exist(num_tokens) else nn.Identity()

        # Each block: LayerNorm -> Linear(dim, 2*d_ff) -> GELU -> SGU -> Linear(d_ff, dim),
        # wrapped in a residual connection.
        self.gmlp = nn.ModuleList([
            Residual(nn.Sequential(OrderedDict([
                ('ln1_%d' % i, nn.LayerNorm(dim)),
                ('fc1_%d' % i, nn.Linear(dim, d_ff * 2)),
                ('gelu_%d' % i, nn.GELU()),
                ('sgu_%d' % i, SpatialGatingUnit(d_ff, len_sen)),
                ('fc2_%d' % i, nn.Linear(d_ff, dim)),
            ]))) for i in range(num_layers)
        ])

        # Output head (note: the Softmax means this returns probabilities rather than
        # raw logits). Skipped when num_tokens is None, so the model can be used as a
        # feature encoder instead of crashing on nn.Linear(dim, None).
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens),
            nn.Softmax(-1),
        ) if exist(num_tokens) else nn.Identity()

    def forward(self, x):
        # embedding
        y = self.embedding(x)
        # gMLP blocks
        for block in self.gmlp:
            y = block(y)
        # to logits
        return self.to_logits(y)
if __name__ == '__main__':
    num_tokens = 10000
    bs = 50
    len_sen = 49
    num_layers = 6

    input = torch.randint(num_tokens, (bs, len_sen))  # bs, len_sen
    gmlp = gMLP(num_tokens=num_tokens, len_sen=len_sen, dim=512, d_ff=1024, num_layers=num_layers)
    output = gmlp(input)
    print(output.shape)  # torch.Size([50, 49, 10000])
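
    # A minimal sketch of the feature-input path (an assumption about intended usage,
    # not part of the original file): with num_tokens=None both the embedding and the
    # output head fall back to nn.Identity, so the model maps a pre-embedded
    # (bs, len_sen, dim) tensor to a tensor of the same shape.
    features = torch.randn(bs, len_sen, 512)  # bs, len_sen, dim
    encoder = gMLP(num_tokens=None, len_sen=len_sen, dim=512, d_ff=1024)
    print(encoder(features).shape)  # expected: torch.Size([50, 49, 512])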