# model_DecoderLayer.py
import torch
import torch.nn as nn


class Reconstruct_Decoder_MLP(nn.Module):
    '''
    MLP decoder: maps each time point from feature space back to the input space
    '''
    def __init__(self, in_dim, out_dim):  # in_dim -> 128 -> out_dim
        super().__init__()
        mid_dim = 128
        self.fc1 = nn.Linear(in_dim, mid_dim)
        self.bn = nn.BatchNorm1d(mid_dim, eps=0.001)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(mid_dim, out_dim)

    def forward(self, x):
        """
        :param x: B * L * C
        :return: B * L * out_dim
        """
        B, L, C = x.shape
        x = x.reshape(-1, C)  # B * L * C -> BL * C, so BatchNorm1d sees a 2D batch
        x = self.fc2(self.act(self.bn(self.fc1(x))))
        x = x.reshape(B, L, -1)  # BL * out_dim -> B * L * out_dim
        return x
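
# Usage sketch (illustrative only; the batch size, sequence length and feature
# sizes below are assumptions, not values taken from the project):
#
#   dec = Reconstruct_Decoder_MLP(in_dim=64, out_dim=10)
#   y = dec(torch.randn(4, 16, 64))   # y.shape == torch.Size([4, 16, 10])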


class Projector(nn.Module):
    '''
    Projector: two-layer MLP head
    '''
    def __init__(self, in_dim, out_dim):
        super().__init__()
        mid_dim = 256
        self.projector = nn.Sequential(
            nn.Linear(in_dim, mid_dim, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, out_dim, bias=True))

    def forward(self, x):
        x = self.projector(x)
        return x


class Point_Predictor(nn.Module):
    '''
    Point_Predictor: per-point prediction head (two-layer MLP)
    '''
    def __init__(self, in_dim, out_dim):
        super().__init__()
        mid_dim = 128
        self.projector = nn.Sequential(
            nn.Linear(in_dim, mid_dim, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, out_dim, bias=True))

    def forward(self, x):
        x = self.projector(x)
        return x


class Context_Predictor(nn.Module):
    '''
    Context_Predictor: multi-head attention over context representations
    '''
    def __init__(self, embed_dim=256, num_heads=4):
        super().__init__()
        self.context_predictor = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, q, k, v):
        attn_output, attn_output_weights = self.context_predictor(q, k, v)
        return attn_output, attn_output_weights
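
# Usage sketch (illustrative only; the B * L * 256 query/key/value tensors
# below are assumed shapes, not values taken from the project):
#
#   ctx_pred = Context_Predictor(embed_dim=256, num_heads=4)
#   q = k = v = torch.randn(8, 20, 256)     # B * L * embed_dim
#   attn_out, attn_w = ctx_pred(q, k, v)    # attn_out: 8 * 20 * 256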


class Poolinger(nn.Module):
    def __init__(self, kernel_size, stride, mode="avg"):
        super().__init__()
        if mode == "avg":
            self.m = nn.AvgPool1d(kernel_size, stride)
        elif mode == "max":
            self.m = nn.MaxPool1d(kernel_size, stride)
        else:
            raise ValueError(f"Unsupported pooling mode: {mode}")

    def forward(self, x):
        """Input: B * L * C; pooling is applied over the time dimension L."""
        x = x.permute(0, 2, 1)  # B * L * C -> B * C * L for nn.*Pool1d
        x = self.m(x)
        x = x.permute(0, 2, 1)  # back to B * L' * C
        return x


class Reconstruct_Decoder_RNN(nn.Module):
    ''' Decodes the hidden state output by the encoder '''
    def __init__(self, input_size, hidden_size, out_size, num_layers=1):
        '''
        : param input_size: the number of features in the input X
        : param hidden_size: the number of features in the hidden state h
        : param out_size: the number of output features per time step
        : param num_layers: number of recurrent layers (i.e., 2 means there are
        :                   2 stacked GRUs)
        '''
        super(Reconstruct_Decoder_RNN, self).__init__()
        self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, out_size)
        self.act = nn.ReLU()
        self.bn = nn.BatchNorm1d(hidden_size, eps=0.001)

    def forward(self, x_input, encoder_hidden_states):
        rnn_out, hn = self.rnn(x_input, encoder_hidden_states)
        rnn_out = self.act(rnn_out)
        rnn_out = rnn_out.permute(0, 2, 1)  # B * L * H -> B * H * L for BatchNorm1d
        rnn_out = self.bn(rnn_out)
        rnn_out = rnn_out.permute(0, 2, 1)  # back to B * L * H
        output = self.linear(rnn_out)
        return output


class Context_Decoder(nn.Module):
    def __init__(self, input_dim, output_dim, rnn_layers):
        super().__init__()
        self.decoder = nn.LSTM(input_size=input_dim, hidden_size=output_dim, num_layers=rnn_layers, batch_first=True)

    def forward(self, x):
        # nn.LSTM returns a tuple: (output, (h_n, c_n))
        z = self.decoder(x)
        return z


class Predict_Decoder(nn.Module):
    def __init__(self, in_dim, out_dim):  # in_dim -> 64 -> out_dim
        super().__init__()
        self.fc1 = nn.Linear(in_dim, 64)
        self.bn = nn.BatchNorm1d(64, eps=0.001)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(64, out_dim)

    def forward(self, x):
        """:param x: B * in_dim (2D input, as expected by BatchNorm1d)"""
        x = self.fc2(self.act(self.bn(self.fc1(x))))
        return x
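

# Minimal shape-check sketch, not part of the original module. All sizes
# (batch, sequence length, feature dimensions) are illustrative assumptions.
if __name__ == "__main__":
    B, L, C, X_DIM = 4, 16, 64, 10
    x = torch.randn(B, L, C)

    proj = Projector(in_dim=C, out_dim=128)
    print(proj(x).shape)                        # torch.Size([4, 16, 128])

    point_pred = Point_Predictor(in_dim=C, out_dim=X_DIM)
    print(point_pred(x).shape)                  # torch.Size([4, 16, 10])

    pool = Poolinger(kernel_size=2, stride=2, mode="avg")
    print(pool(x).shape)                        # torch.Size([4, 8, 64])

    rnn_dec = Reconstruct_Decoder_RNN(input_size=C, hidden_size=32, out_size=X_DIM)
    h0 = torch.zeros(1, B, 32)                  # (num_layers, B, hidden_size)
    print(rnn_dec(x, h0).shape)                 # torch.Size([4, 16, 10])

    ctx_dec = Context_Decoder(input_dim=C, output_dim=32, rnn_layers=1)
    out, (hn, cn) = ctx_dec(x)
    print(out.shape)                            # torch.Size([4, 16, 32])

    pred_dec = Predict_Decoder(in_dim=C, out_dim=X_DIM)
    print(pred_dec(torch.randn(B, C)).shape)    # torch.Size([4, 10])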