add TiDE and its ETTh1 script
GuoKaku committed Nov 20, 2023
1 parent 51e80b6 commit 966a472
Showing 4 changed files with 215 additions and 1 deletion.
3 changes: 2 additions & 1 deletion exp/exp_basic.py
@@ -2,7 +2,7 @@
 import torch
 from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
     Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
-    Koopa
+    Koopa, TiDE


 class Exp_Basic(object):
@@ -26,6 +26,7 @@ def __init__(self, args):
             'FiLM': FiLM,
             'iTransformer': iTransformer,
             'Koopa': Koopa,
+            'TiDE': TiDE,
         }
         self.device = self._acquire_device()
         self.model = self._build_model().to(self.device)
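
With this registry, adding a model needs only the import and one dict entry; Exp_Basic resolves the --model name through it when building the network. A minimal sketch of that dispatch (simplified, not the full class):

from models import TiDE

model_dict = {'TiDE': TiDE}             # maps the --model CLI name to a module
model_cls = model_dict['TiDE'].Model    # each model module exposes a Model class
# model = model_cls(configs).float()    # configs comes from run.py's argparse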
98 changes: 98 additions & 0 deletions models/TiDE.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        # F.layer_norm accepts bias=None, which simply omits the bias term
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)



class ResBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.dropout(out)
        out = out + self.fc3(x)  # skip connection through a linear projection
        out = self.ln(out)
        return out
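
# Illustrative shape check for ResBlock (a sketch, not part of the commit):
#   block = ResBlock(input_dim=8, hidden_dim=32, output_dim=16)
#   block(torch.randn(4, 8)).shape  # -> torch.Size([4, 16])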


# TiDE
class Model(nn.Module):
    '''paper: https://arxiv.org/pdf/2304.08424.pdf'''

    def __init__(self, configs, bias=True, feature_encode_dim=2):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len  # L: look-back length
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len  # H: forecast horizon
        self.hidden_dim = configs.d_model
        self.res_hidden = configs.d_model
        self.encoder_num = configs.e_layers
        self.decoder_num = configs.d_layers
        self.freq = configs.freq
        self.feature_encode_dim = feature_encode_dim
        self.decode_dim = configs.c_out
        self.temporalDecoderHidden = configs.d_ff
        dropout = configs.dropout

        # number of time-stamp features generated for each sampling frequency
        freq_map = {'h': 4, 't': 5, 's': 6,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        self.feature_dim = freq_map[self.freq]

        # per-channel encoder input: L past values plus the encoded covariates of
        # all L + H steps, e.g. seq_len=96, pred_len=96, feature_encode_dim=2
        # gives 96 + 192 * 2 = 480
        flatten_dim = self.seq_len + (self.seq_len + self.pred_len) * self.feature_encode_dim

        self.feature_encoder = ResBlock(self.feature_dim, self.res_hidden, self.feature_encode_dim, dropout, bias)
        # note: [block] * n repeats the same instance, so with e_layers > 2
        # (or d_layers > 2) the repeated ResBlocks share parameters
        self.encoders = nn.Sequential(
            ResBlock(flatten_dim, self.res_hidden, self.hidden_dim, dropout, bias),
            *([ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)] * (self.encoder_num - 1)))
        self.decoders = nn.Sequential(
            *([ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)] * (self.decoder_num - 1)),
            ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.pred_len, dropout, bias))
        self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias)
        self.residual_proj = nn.Linear(self.seq_len, self.pred_len, bias=bias)


    def forecast(self, x_enc, x_mark_enc, x_dec, batch_y_mark):
        # encode the covariates of all L + H steps, then flatten them per sample
        feature = self.feature_encoder(batch_y_mark)
        hidden = self.encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1))
        decoded = self.decoders(hidden).reshape(hidden.shape[0], self.pred_len, self.decode_dim)
        # the temporal decoder refines each horizon step with its covariates; a
        # linear projection of the look-back window is added as a global residual
        prediction = (self.temporalDecoder(torch.cat([feature[:, self.seq_len:], decoded], dim=-1)).squeeze(-1)
                      + self.residual_proj(x_enc))
        return prediction

    def forward(self, x_enc, x_mark_enc, x_dec, batch_y_mark):
        '''x_mark_enc carries the exogenous dynamic features described in the original paper'''
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            batch_y_mark = torch.concat([x_mark_enc, batch_y_mark[:, -self.pred_len:, :]], dim=1)
            # channel independence: forecast each input variable separately, then stack
            dec_out = torch.stack([self.forecast(x_enc[:, :, feature], x_mark_enc, x_dec, batch_y_mark)
                                   for feature in range(x_enc.shape[-1])], dim=-1)
            return dec_out  # [B, L, D]
        return None
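
# As a sanity check, the model runs standalone. A minimal smoke-test sketch
# (not part of the commit; the Namespace stands in for run.py's argparse
# config, using ETTh1's 7 variables):
#
#   import argparse
#   configs = argparse.Namespace(
#       task_name='long_term_forecast', seq_len=96, label_len=48, pred_len=96,
#       e_layers=2, d_layers=2, d_model=256, d_ff=256, freq='h', c_out=7,
#       dropout=0.3)
#   model = Model(configs)
#   x_enc = torch.randn(4, 96, 7)              # look-back window [B, L, D]
#   x_mark_enc = torch.randn(4, 96, 4)         # freq='h' -> 4 time features
#   batch_y_mark = torch.randn(4, 48 + 96, 4)  # label_len + pred_len covariates
#   model(x_enc, x_mark_enc, None, batch_y_mark).shape  # torch.Size([4, 96, 7])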





112 changes: 112 additions & 0 deletions scripts/long_term_forecast/ETT_script/TiDE_ETTh1.sh
@@ -0,0 +1,112 @@
export CUDA_VISIBLE_DEVICES=2

model_name=TiDE

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_96 \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 96 \
--e_layers 2 \
--d_layers 2 \
--enc_in 7 \
--dec_in 7 \
--c_out 8 \
--d_model 256 \
--d_ff 256 \
--dropout 0.3 \
--batch_size 512 \
--learning_rate 0.1 \
--patience 5 \
--train_epochs 10



python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_192 \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 192 \
--e_layers 2 \
--d_layers 2 \
--enc_in 7 \
--dec_in 7 \
--c_out 8 \
--d_model 256 \
--d_ff 256 \
--dropout 0.3 \
--batch_size 512 \
--learning_rate 0.1 \
--patience 5 \
--train_epochs 10




python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_336 \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 336 \
--e_layers 2 \
--d_layers 2 \
--enc_in 7 \
--dec_in 7 \
--c_out 8 \
--d_model 256 \
--d_ff 256 \
--dropout 0.3 \
--batch_size 512 \
--learning_rate 0.1 \
--patience 5 \
--train_epochs 10




python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_720 \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 720 \
--e_layers 2 \
--d_layers 2 \
--enc_in 7 \
--dec_in 7 \
--c_out 8 \
--d_model 256 \
--d_ff 256 \
--dropout 0.3 \
--batch_size 512 \
--learning_rate 0.1 \
--patience 5 \
--train_epochs 10
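
# The four invocations above differ only in --model_id and --pred_len; an
# equivalent loop form (a sketch, not part of the commit):
#
# for pred_len in 96 192 336 720; do
#   python -u run.py \
#     --task_name long_term_forecast --is_training 1 \
#     --root_path ./dataset/ETT-small/ --data_path ETTh1.csv \
#     --model_id ETTh1_96_${pred_len} --model $model_name --data ETTh1 \
#     --features M --seq_len 96 --label_len 48 --pred_len ${pred_len} \
#     --e_layers 2 --d_layers 2 --enc_in 7 --dec_in 7 --c_out 8 \
#     --d_model 256 --d_ff 256 --dropout 0.3 \
#     --batch_size 512 --learning_rate 0.1 --patience 5 --train_epochs 10
# done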

3 changes: 3 additions & 0 deletions utils/tools.py
@@ -4,6 +4,7 @@
 import torch
 import matplotlib.pyplot as plt
 import pandas as pd
+import math

 plt.switch_backend('agg')

@@ -17,6 +18,8 @@ def adjust_learning_rate(optimizer, epoch, args):
             2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
             10: 5e-7, 15: 1e-7, 20: 5e-8
         }
+    elif args.lradj == "cosine":
+        lr_adjust = {epoch: args.learning_rate / 2 * (1 + math.cos(epoch / args.train_epochs * math.pi))}
     if epoch in lr_adjust.keys():
         lr = lr_adjust[epoch]
         for param_group in optimizer.param_groups:
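
For intuition, the new "cosine" branch anneals the learning rate from its initial value to zero over train_epochs. A standalone sketch of the schedule it produces (assuming learning_rate=0.1 and train_epochs=10, the values in the TiDE script):

import math

learning_rate, train_epochs = 0.1, 10   # values from TiDE_ETTh1.sh
for epoch in range(1, train_epochs + 1):
    lr = learning_rate / 2 * (1 + math.cos(epoch / train_epochs * math.pi))
    print(f"epoch {epoch:2d}: lr = {lr:.4f}")
# decays smoothly from ~0.0976 at epoch 1 to 0.0 at epoch 10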
