Reformer.py
import torch
import torch.nn as nn
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import ReformerLayer
from layers.Embed import DataEmbedding


class Model(nn.Module):
"""
Reformer with O(LlogL) complexity
- It is notable that Reformer is not proposed for time series forecasting, in that it cannot accomplish the cross attention.
- Here is only one adaption in BERT-style, other possible implementations can also be acceptable.
- The hyper-parameters, such as bucket_size and n_hashes, need to be further tuned.
The official repo of Reformer (https://github.com/lucidrains/reformer-pytorch) can be very helpful, if you have any questiones.
"""
    def __init__(self, configs):
        super(Model, self).__init__()
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
# Embedding
self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
configs.dropout)
# Encoder
self.encoder = Encoder(
[
EncoderLayer(
ReformerLayer(None, configs.d_model, configs.n_heads, bucket_size=configs.bucket_size,
n_hashes=configs.n_hashes),
configs.d_model,
configs.d_ff,
dropout=configs.dropout,
activation=configs.activation
                ) for _ in range(configs.e_layers)
],
norm_layer=torch.nn.LayerNorm(configs.d_model)
)
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
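        # x_enc: [B, seq_len, enc_in]; x_dec: [B, label_len + pred_len, enc_in];
        # the *_mark tensors carry the matching time-feature encodings.
        # dec_self_mask and dec_enc_mask are unused here, since there is no decoder.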
        # Append the last pred_len steps of the decoder input as placeholders,
        # so the encoder-only model sees the full forecast horizon (BERT-style)
x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1)
x_mark_enc = torch.cat([x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1)
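        # e.g. with seq_len=96 and pred_len=24 the encoder input grows to 120 steps,
        # and the model fills in the last 24 positions in place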
# Reformer: encoder only
enc_out = self.enc_embedding(x_enc, x_mark_enc)
enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
enc_out = self.projection(enc_out)
if self.output_attention:
return enc_out[:, -self.pred_len:, :], attns
else:
return enc_out[:, -self.pred_len:, :] # [B, L, D]
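

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original repo file): it shows how the
    # BERT-style forward pass consumes the encoder/decoder inputs. The
    # SimpleNamespace config is a hypothetical stand-in for the argparse-style
    # `configs` object this Model reads; the attribute values are illustrative,
    # not tuned.
    from types import SimpleNamespace

    configs = SimpleNamespace(
        pred_len=24, output_attention=False,
        enc_in=7, d_model=64, embed='timeF', freq='h', dropout=0.1,
        n_heads=4, bucket_size=4, n_hashes=4,
        d_ff=128, activation='gelu', e_layers=2, c_out=7,
    )
    model = Model(configs)

    B, seq_len, label_len = 2, 96, 48
    x_enc = torch.randn(B, seq_len, configs.enc_in)
    x_mark_enc = torch.randn(B, seq_len, 4)  # 4 time features for freq='h'
    x_dec = torch.randn(B, label_len + configs.pred_len, configs.enc_in)
    x_mark_dec = torch.randn(B, label_len + configs.pred_len, 4)

    out = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
    print(out.shape)  # expected: torch.Size([2, 24, 7])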