model_EmbedLayer.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the sinusoidal positional encodings once in log space:
        #   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        #   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Return the encodings for the first x.size(1) positions: (1, L, d_model).
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # The padding needed to keep the sequence length unchanged with 'circular'
        # mode differs across PyTorch versions (behaviour changed in 1.5.0).
        padding = 1 if torch.__version__ >= '1.5.0' else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Input: (B, L, C); output: (B, d_model, L)."""
        x = self.tokenConv(x.permute(0, 2, 1))
        return x


class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, dropout=0.05):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # value_embedding(x) is (B, d_model, L) while position_embedding(x) is
        # (1, L, d_model), so transpose the positional term before adding.
        x = self.value_embedding(x) + self.position_embedding(x).transpose(1, 2)
        return self.dropout(x)


class DataEmbedding_Value(nn.Module):
    def __init__(self, c_in, d_model, dropout=0.05):
        super(DataEmbedding_Value, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # Value embedding only, no positional term: (B, L, C) -> (B, d_model, L).
        x = self.value_embedding(x)
        return self.dropout(x)
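

if __name__ == "__main__":
    # Minimal shape check, included here as an illustrative sketch only; the batch
    # size, sequence length, channel count, and d_model below are arbitrary example
    # values and are not part of the original module.
    B, L, C, d_model = 4, 96, 7, 64
    x = torch.randn(B, L, C)  # (B, L, C) input series

    emb = DataEmbedding(c_in=C, d_model=d_model)
    print(emb(x).shape)  # expected: torch.Size([4, 64, 96]), i.e. (B, d_model, L)

    emb_v = DataEmbedding_Value(c_in=C, d_model=d_model)
    print(emb_v(x).shape)  # expected: torch.Size([4, 64, 96])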