forked from thuml/Time-Series-Library
Merge remote-tracking branch 'upstream/main'
Showing 32 changed files with 2,679 additions and 17 deletions.
@@ -3,24 +3,30 @@ TSlib is an open-source library for deep learning researchers, especially for de

We provide a neat code base to evaluate advanced deep time series models or develop your own model, covering five mainstream tasks: **long- and short-term forecasting, imputation, anomaly detection, and classification.**

:triangular_flag_on_post:**News** (2024.03) Given the inconsistent look-back lengths used in different papers, we split the long-term forecasting task in the leaderboard into two categories: Look-Back-96 and Look-Back-Searching. We recommend researchers read [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2), whose experiments include both look-back settings for scientific rigor.

:triangular_flag_on_post:**News** (2023.10) We add an implementation of [iTransformer](https://arxiv.org/abs/2310.06625), the state-of-the-art model for long-term forecasting. The official code and complete scripts of iTransformer can be found [here](https://github.com/thuml/iTransformer).

:triangular_flag_on_post:**News** (2023.09) We added a detailed [tutorial](https://github.com/thuml/Time-Series-Library/blob/main/tutorial/TimesNet_tutorial.ipynb) for [TimesNet](https://openreview.net/pdf?id=ju_Uqw384Oq) and this library; it is a beginner-friendly introduction to deep time series analysis.

:triangular_flag_on_post:**News** (2023.02) We release TSlib as a comprehensive benchmark and code base for time series models, extended from our previous GitHub repository [Autoformer](https://github.com/thuml/Autoformer).

## Leaderboard for Time Series Analysis

-Till October 2023, the top three models for five different tasks are:
+As of March 2024, the top three models for each task are:
-| Model<br>Ranking | Long-term<br>Forecasting | Short-term<br>Forecasting | Imputation | Classification | Anomaly<br>Detection |
-| ---------------- |---------------------------------------------------| ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | -------------------------------------------------- |
-| 🥇 1st | [iTransformer](https://arxiv.org/abs/2310.06625) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) |
-| 🥈 2nd | [PatchTST](https://github.com/yuqinie98/PatchTST) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [FEDformer](https://github.com/MAZiqing/FEDformer) |
-| 🥉 3rd | [TimesNet](https://arxiv.org/abs/2210.02186) | [FEDformer](https://github.com/MAZiqing/FEDformer) | [Autoformer](https://github.com/thuml/Autoformer) | [Informer](https://github.com/zhouhaoyi/Informer2020) | [Autoformer](https://github.com/thuml/Autoformer) |
+| Model<br>Ranking | Long-term<br>Forecasting<br>Look-Back-96 | Long-term<br/>Forecasting<br/>Look-Back-Searching | Short-term<br>Forecasting | Imputation | Classification | Anomaly<br>Detection |
+| ---------------- | ----------------------------------------------------- | ----------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | -------------------------------------------------- |
+| 🥇 1st | [iTransformer](https://arxiv.org/abs/2310.06625) | [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) |
+| 🥈 2nd | [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2) | [PatchTST](https://github.com/yuqinie98/PatchTST) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [FEDformer](https://github.com/MAZiqing/FEDformer) |
+| 🥉 3rd | [TimesNet](https://arxiv.org/abs/2210.02186) | [DLinear](https://arxiv.org/pdf/2205.13504.pdf) | [FEDformer](https://github.com/MAZiqing/FEDformer) | [Autoformer](https://github.com/thuml/Autoformer) | [Informer](https://github.com/zhouhaoyi/Informer2020) | [Autoformer](https://github.com/thuml/Autoformer) |

**Note: We will keep updating this leaderboard.** If you have proposed an advanced model, you can send us your paper/code link or open a pull request; we will add it to this repo and update the leaderboard as soon as possible.

**Compared models of this leaderboard.** ☑ means that the model's code is already included in this repo.
- [x] **TimeMixer** - TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting [[ICLR 2024]](https://openreview.net/pdf?id=7oLshfEIC2) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimeMixer.py).
- [x] **TSMixer** - TSMixer: An All-MLP Architecture for Time Series Forecasting [[arXiv 2023]](https://arxiv.org/pdf/2303.06053.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TSMixer.py).
- [x] **iTransformer** - iTransformer: Inverted Transformers Are Effective for Time Series Forecasting [[ICLR 2024]](https://arxiv.org/abs/2310.06625) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/iTransformer.py).
- [x] **PatchTST** - A Time Series is Worth 64 Words: Long-term Forecasting with Transformers [[ICLR 2023]](https://openreview.net/pdf?id=Jbdc0vTOcol) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/PatchTST.py).
- [x] **TimesNet** - TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis [[ICLR 2023]](https://openreview.net/pdf?id=ju_Uqw384Oq) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimesNet.py).

@@ -38,6 +44,7 @@ Till October 2023, the top three models for five different tasks are:
See our latest paper [[TimesNet]](https://arxiv.org/abs/2210.02186) for the comprehensive benchmark. We will release a real-time updated online version soon.

**Newly added baselines.** We will add them to the leaderboard after a comprehensive evaluation.
- [x] **SegRNN** - SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting [[arXiv 2023]](https://arxiv.org/abs/2308.11200) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/SegRNN.py).
- [x] **Koopa** - Koopa: Learning Non-stationary Time Series Dynamics with Koopman Predictors [[NeurIPS 2023]](https://arxiv.org/pdf/2305.18803.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Koopa.py).
- [x] **FreTS** - Frequency-domain MLPs are More Effective Learners in Time Series Forecasting [[NeurIPS 2023]](https://arxiv.org/pdf/2311.06184.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/FreTS.py).
- [x] **TiDE** - Long-term Forecasting with TiDE: Time-series Dense Encoder [[arXiv 2023]](https://arxiv.org/pdf/2304.08424.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TiDE.py).

@@ -98,6 +105,7 @@ If you have any questions or suggestions, feel free to contact:

- Haixu Wu ([email protected])
- Tengge Hu ([email protected])
- Yong Liu ([email protected])
- Haoran Zhang ([email protected])
- Jiawei Guo ([email protected])
New file: a RevIN-style reversible instance normalization layer.

@@ -0,0 +1,68 @@
```python
import torch
import torch.nn as nn


class Normalize(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step instead of the mean
        :param non_norm: if True, the layer is a pass-through (no normalization)
        """
        super(Normalize, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # stdev is needed by both branches of _normalize, so compute it unconditionally
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
```
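A quick usage sketch (an editorial addition, not part of the commit): the layer caches its statistics on the `norm` call and reuses them on `denorm`, so the same instance must serve both calls. The tensor shapes and parameter values below are illustrative assumptions.

```python
import torch

# assumed shapes: batch=4, seq_len=96, channels=7
norm = Normalize(num_features=7, affine=True)
x = torch.randn(4, 96, 7)

x_norm = norm(x, mode='norm')         # computes and caches per-series mean/stdev
# ... a forecasting model would map x_norm to predictions here ...
x_back = norm(x_norm, mode='denorm')  # reapplies the cached statistics

assert torch.allclose(x_back, x, atol=1e-4)  # the round trip recovers the input
```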
New file (`models/SegRNN.py`, per the README link above): the SegRNN model.

@@ -0,0 +1,119 @@
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp  # unused in this model


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2308.11200
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # get parameters
        self.seq_len = configs.seq_len
        self.enc_in = configs.enc_in
        self.d_model = configs.d_model
        self.dropout = configs.dropout

        self.task_name = configs.task_name
        if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len

        self.seg_len = configs.seg_len
        self.seg_num_x = self.seq_len // self.seg_len
        self.seg_num_y = self.pred_len // self.seg_len

        # build model
        self.valueEmbedding = nn.Sequential(
            nn.Linear(self.seg_len, self.d_model),
            nn.ReLU()
        )
        self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
                          batch_first=True, bidirectional=False)
        self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
        self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))

        self.predict = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model, self.seg_len)
        )

        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.enc_in * configs.seq_len, configs.num_class)

    def encoder(self, x):
        # notation: b=batch_size, c=channel_size, s=seq_len,
        #           d=d_model, w=seg_len, n=seg_num_x, m=seg_num_y
        batch_size = x.size(0)

        # normalization and permute: b,s,c -> b,c,s
        seq_last = x[:, -1:, :].detach()
        x = (x - seq_last).permute(0, 2, 1)  # b,c,s

        # segment and embed: b,c,s -> bc,n,w -> bc,n,d
        x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len))

        # encoding
        _, hn = self.rnn(x)  # bc,n,d -> 1,bc,d

        # position embedding: m,d//2 -> 1,m,d//2 -> c,m,d//2
        # channel embedding:  c,d//2 -> c,1,d//2 -> c,m,d//2
        # concatenated:       c,m,d -> cm,1,d -> bcm,1,d
        pos_emb = torch.cat([
            self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
            self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
        ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size, 1, 1)

        # decoding: inputs bcm,1,d with initial state 1,bcm,d
        _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model))

        # project back: 1,bcm,d -> 1,bcm,w -> b,c,s
        y = self.predict(hy).view(-1, self.enc_in, self.pred_len)

        # permute and denormalize
        y = y.permute(0, 2, 1) + seq_last
        return y

    def forecast(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def imputation(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def anomaly_detection(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def classification(self, x_enc):
        # Encoder
        enc_out = self.encoder(x_enc)
        # Output
        # flatten: (batch_size, seq_len * enc_in)
        output = enc_out.reshape(enc_out.shape[0], -1)
        # project: (batch_size, num_classes)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc)
            return dec_out  # [B, N]
        return None
```
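To sanity-check the forward pass, here is a hedged driver sketch (an editorial addition, not part of the commit). The config fields mirror exactly those read in `__init__`; the concrete values and the use of `types.SimpleNamespace` in place of the library's argparse-based config are assumptions. Note that `seq_len` and `pred_len` should be divisible by `seg_len`.

```python
from types import SimpleNamespace
import torch

# assumed hyperparameters
configs = SimpleNamespace(
    task_name='long_term_forecast',
    seq_len=96, pred_len=192, seg_len=48,  # seg_num_x=2, seg_num_y=4
    enc_in=7, d_model=512, dropout=0.1,
)

model = Model(configs)
x_enc = torch.randn(32, configs.seq_len, configs.enc_in)  # [B, L, C]
# time marks and decoder inputs are ignored by SegRNN's forward
out = model(x_enc, None, None, None)
print(out.shape)  # torch.Size([32, 192, 7])
```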