Skip to content

Commit

Permalink
add=add TimeMixer and TSMixer
Browse files Browse the repository at this point in the history
  • Loading branch information
weiming.wsy committed Mar 27, 2024
1 parent 9a875da commit adc8819
Show file tree
Hide file tree
Showing 22 changed files with 2,345 additions and 12 deletions.
6 changes: 4 additions & 2 deletions exp/exp_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
Koopa, TiDE, FreTS
Koopa, TiDE, FreTS, TimeMixer, TSMixer


class Exp_Basic(object):
Expand All @@ -27,7 +27,9 @@ def __init__(self, args):
'iTransformer': iTransformer,
'Koopa': Koopa,
'TiDE': TiDE,
'FreTS': FreTS
'FreTS': FreTS,
'TimeMixer': TimeMixer,
'TSMixer': TSMixer
}
self.device = self._acquire_device()
self.model = self._build_model().to(self.device)
Expand Down
68 changes: 68 additions & 0 deletions layers/StandardNorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class Normalize(nn.Module):
def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
"""
:param num_features: the number of features or channels
:param eps: a value added for numerical stability
:param affine: if True, RevIN has learnable affine parameters
"""
super(Normalize, self).__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.subtract_last = subtract_last
self.non_norm = non_norm
if self.affine:
self._init_params()

def forward(self, x, mode: str):
if mode == 'norm':
self._get_statistics(x)
x = self._normalize(x)
elif mode == 'denorm':
x = self._denormalize(x)
else:
raise NotImplementedError
return x

def _init_params(self):
# initialize RevIN params: (C,)
self.affine_weight = nn.Parameter(torch.ones(self.num_features))
self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

def _get_statistics(self, x):
dim2reduce = tuple(range(1, x.ndim - 1))
if self.subtract_last:
self.last = x[:, -1, :].unsqueeze(1)
else:
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

def _normalize(self, x):
if self.non_norm:
return x
if self.subtract_last:
x = x - self.last
else:
x = x - self.mean
x = x / self.stdev
if self.affine:
x = x * self.affine_weight
x = x + self.affine_bias
return x

def _denormalize(self, x):
if self.non_norm:
return x
if self.affine:
x = x - self.affine_bias
x = x / (self.affine_weight + self.eps * self.eps)
x = x * self.stdev
if self.subtract_last:
x = x + self.last
else:
x = x + self.mean
return x
15 changes: 8 additions & 7 deletions models/FreTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
"""
Paper link: https://arxiv.org/pdf/2311.06184.pdf
"""

def __init__(self, configs):
super(Model, self).__init__()
self.task_name = configs.task_name
if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
self.pred_len = configs.seq_len
else:
self.pred_len = configs.pred_len
self.embed_size = 128 #embed_size
self.hidden_size = 256 #hidden_size
self.embed_size = 128 # embed_size
self.hidden_size = 256 # hidden_size
self.pred_len = configs.pred_len
self.feature_size = configs.enc_in #channels
self.feature_size = configs.enc_in # channels
self.seq_len = configs.seq_len
self.channel_independence = configs.channel_independence
self.sparsity_threshold = 0.01
Expand Down Expand Up @@ -50,7 +52,7 @@ def tokenEmb(self, x):
# frequency temporal learner
def MLP_temporal(self, x, B, N, L):
# [B, N, T, D]
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on L dimension
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on L dimension
y = self.FreMLP(B, N, L, x, self.r2, self.i2, self.rb2, self.ib2)
x = torch.fft.irfft(y, n=self.seq_len, dim=2, norm="ortho")
return x
Expand All @@ -60,7 +62,7 @@ def MLP_channel(self, x, B, N, L):
# [B, N, T, D]
x = x.permute(0, 2, 1, 3)
# [B, T, N, D]
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on N dimension
x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on N dimension
y = self.FreMLP(B, L, N, x, self.r1, self.i1, self.rb1, self.ib1)
x = torch.fft.irfft(y, n=self.feature_size, dim=2, norm="ortho")
x = x.permute(0, 2, 1, 3)
Expand Down Expand Up @@ -100,7 +102,7 @@ def forecast(self, x_enc):
x = self.tokenEmb(x_enc)
bias = x
# [B, N, T, D]
if self.channel_independence == '1':
if self.channel_independence == '0':
x = self.MLP_channel(x, B, N, T)
# [B, N, T, D]
x = self.MLP_temporal(x, B, N, T)
Expand All @@ -114,4 +116,3 @@ def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
return dec_out[:, -self.pred_len:, :] # [B, L, D]
else:
raise ValueError('Only forecast tasks implemented yet')

54 changes: 54 additions & 0 deletions models/TSMixer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch.nn as nn


class ResBlock(nn.Module):
def __init__(self, configs):
super(ResBlock, self).__init__()

self.temporal = nn.Sequential(
nn.Linear(configs.seq_len, configs.d_model),
nn.ReLU(),
nn.Linear(configs.d_model, configs.seq_len),
nn.Dropout(configs.dropout)
)

self.channel = nn.Sequential(
nn.Linear(configs.enc_in, configs.d_model),
nn.ReLU(),
nn.Linear(configs.d_model, configs.enc_in),
nn.Dropout(configs.dropout)
)

def forward(self, x):
# x: [B, L, D]
x = x + self.temporal(x.transpose(1, 2)).transpose(1, 2)
x = x + self.channel(x)

return x


class Model(nn.Module):
def __init__(self, configs):
super(Model, self).__init__()
self.task_name = configs.task_name
self.layer = configs.e_layers
self.model = nn.ModuleList([ResBlock(configs)
for _ in range(configs.e_layers)])
self.pred_len = configs.pred_len
self.projection = nn.Linear(configs.seq_len, configs.pred_len)

def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):

# x: [B, L, D]
for i in range(self.layer):
x_enc = self.model[i](x_enc)
enc_out = self.projection(x_enc.transpose(1, 2)).transpose(1, 2)

return enc_out

def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :] # [B, L, D]
else:
raise ValueError('Only forecast tasks implemented yet')
Loading

0 comments on commit adc8819

Please sign in to comment.