forked from thuml/Time-Series-Library
Commit 096f5d5
Official implementation of iTransformer added
1 parent df10550
Showing 16 changed files with 1,176 additions and 26 deletions.
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import CorrAttention, AttentionLayer
from layers.Embed import DataEmbedding_bnt
import numpy as np


class Model(nn.Module):
    """
    iTransformer: the series of each variate is embedded as a single token, and the
    attention and feed-forward layers operate over this inverted (variate) dimension.
    Paper link: https://arxiv.org/abs/2310.06625
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        # Embedding: each variate's series of length seq_len becomes one d_model-dimensional token
        self.enc_embedding = DataEmbedding_bnt(configs.seq_len, configs.d_model, configs.embed,
                                               configs.freq, configs.dropout)
        # Encoder: attention across variate tokens
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        CorrAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder: a task-specific linear head projects each variate token back to the time axis
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, _, N = x_enc.shape

        # Embedding: [B, L, N] -> [B, N, L] -> [B, N, d_model]
        enc_out = self.enc_embedding(x_enc.permute(0, 2, 1), x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Projection back to the time axis: [B, N, d_model] -> [B, N, pred_len] -> [B, pred_len, N]
        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc.permute(0, 2, 1), x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def anomaly_detection(self, x_enc):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding (no time-mark features are used for anomaly detection)
        enc_out = self.enc_embedding(x_enc.permute(0, 2, 1), None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Embedding
        enc_out = self.enc_embedding(x_enc.permute(0, 2, 1), None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)  # the encoder output embeddings do not include a non-linearity
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)  # (batch_size, enc_in * d_model)
        output = self.projection(output)  # (batch_size, num_class)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
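
For reference, a minimal smoke-test sketch of the forecasting path of the model above. This is not part of the commit: the config values are illustrative assumptions, the import path is hypothetical (adjust it to wherever this file lives in the repository), and None is passed for the time-mark and decoder inputs, mirroring how the anomaly-detection path above calls the same embedding.

import torch
from types import SimpleNamespace
from models.iTransformer import Model  # hypothetical path for the file shown above

# Illustrative config; only the fields read by __init__ and forecast() are set.
configs = SimpleNamespace(
    task_name='long_term_forecast', seq_len=96, pred_len=96, output_attention=False,
    d_model=128, embed='timeF', freq='h', dropout=0.1, factor=1,
    n_heads=8, d_ff=128, activation='gelu', e_layers=2)

model = Model(configs)
x_enc = torch.randn(4, configs.seq_len, 7)   # [B, L, N]: 4 windows, 7 variates
out = model(x_enc, None, None, None)         # time marks / decoder inputs left as None
print(out.shape)                             # expected: torch.Size([4, 96, 7])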
@@ -0,0 +1,20 @@
export CUDA_VISIBLE_DEVICES=1

python -u run.py \
  --task_name anomaly_detection \
  --is_training 1 \
  --root_path ./dataset/MSL \
  --model_id MSL \
  --model iTransformer \
  --data MSL \
  --features M \
  --seq_len 100 \
  --pred_len 0 \
  --d_model 128 \
  --d_ff 128 \
  --e_layers 3 \
  --enc_in 55 \
  --c_out 55 \
  --anomaly_ratio 1 \
  --batch_size 128 \
  --train_epochs 10
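
To make the dimensioning in this script concrete, here is a short shape walk-through. It is illustrative only and assumes DataEmbedding_bnt acts as an inverted embedding that maps each variate's full window of length seq_len to a single d_model-dimensional token, as the model code above suggests.

import torch

B, seq_len, enc_in, d_model = 128, 100, 55, 128   # values taken from the flags above
x_enc = torch.randn(B, seq_len, enc_in)           # [B, L, N] input window

tokens = x_enc.permute(0, 2, 1)                   # [B, N, L]: one token per variate
print(tokens.shape)                               # torch.Size([128, 55, 100])

# After enc_embedding: [B, N, d_model] = [128, 55, 128]; the encoder attends across
# the 55 variate tokens, and the linear projection maps d_model back to seq_len,
# giving a [B, seq_len, N] reconstruction used for anomaly scoring (pred_len is 0).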