NoahWangc/STEP (forked from GestaltCogTeam/STEP)

Code for our SIGKDD'22 paper Pre-training-Enhanced Spatial-Temporal Graph Neural Network For Multivariate Time Series Forecasting.


Pre-training-Enhanced Spatial-Temporal Graph Neural Network For Multivariate Time Series Forecasting


The code is developed with EasyTorch, an easy-to-use and powerful open-source neural network training framework.


All the training logs of the pre-training stage and the forecasting stage can be found in train_logs/.

Multivariate Time Series (MTS) forecasting plays a vital role in a wide range of applications. Recently, Spatial-Temporal Graph Neural Networks (STGNNs) have become increasingly popular MTS forecasting methods. STGNNs jointly model the spatial and temporal patterns of MTS through graph neural networks and sequential models, significantly improving prediction accuracy. Limited by model complexity, however, most STGNNs consider only short-term historical MTS data, such as data from the past hour. Yet the patterns of time series and the dependencies between them (i.e., the temporal and spatial patterns) need to be analyzed based on long-term historical MTS data. To address this issue, we propose a novel framework in which an STGNN is Enhanced by a scalable time series Pre-training model (STEP). Specifically, we design a pre-training model that efficiently learns temporal patterns from very long-term historical time series (e.g., the past two weeks) and generates segment-level representations. These representations provide contextual information for the short-term time series input to STGNNs and facilitate modeling the dependencies between time series. Experiments on three public real-world datasets demonstrate that our framework significantly enhances downstream STGNNs and that our pre-training model aptly captures temporal patterns.
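As a rough illustration of what "segment-level" means, the sketch below cuts a two-week history for one sensor into non-overlapping segments. It assumes 5-minute sampling (288 steps per day) and a segment length of 12 steps (one hour); both numbers and the helper `split_into_segments` are illustrative assumptions, not values or code taken from this repository.

```python
from typing import List, Sequence

STEPS_PER_DAY = 24 * 60 // 5  # 288 steps per day, assuming 5-minute sampling
SEGMENT_LEN = 12  # hypothetical segment length: 12 steps = one hour


def split_into_segments(series: Sequence[float], segment_len: int = SEGMENT_LEN) -> List[List[float]]:
    """Split a 1-D history into non-overlapping, equal-length segments.

    Leading points that do not fill a whole segment are dropped, so the
    last segment always ends at the most recent observation.
    """
    if segment_len <= 0:
        raise ValueError("segment_len must be positive")
    remainder = len(series) % segment_len
    trimmed = list(series[remainder:])
    return [trimmed[i:i + segment_len] for i in range(0, len(trimmed), segment_len)]


# Two weeks of history for a single sensor (dummy values 0.0, 1.0, ...).
history = [float(t) for t in range(14 * STEPS_PER_DAY)]
segments = split_into_segments(history)
print(len(history), len(segments))  # 4032 336
```

In a pre-training model along these lines, each segment would be encoded into one representation, so the downstream model would receive 336 segment representations instead of 4,032 raw points per sensor. A real implementation would typically do this with a tensor reshape rather than Python lists.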

1. Directory Structure

config          -->     Training and model configurations for each dataset
dataloader      -->     MTS dataset
easytorch       -->     EasyTorch
model           -->     Model architecture
checkpoints     -->     Saved checkpoints, organized by the MD5 hash of the configuration file
datasets        -->     Raw datasets and preprocessed data
train_logs      -->     Our training logs
TSFormer_CKPT   -->     Our pre-trained TSFormer checkpoints

2. Requirements

pip install -r requirements.txt

3. Data Preparation

3.1 Download Data

Download the data from Google Drive or BaiduYun to the root directory of the code.

Then unzip it:

unzip TSFormer_CKPT.zip
mkdir datasets
unzip raw_data.zip -d datasets
unzip sensor_graph.zip -d datasets 
rm *.zip
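On systems without `unzip` (for example, Windows), the same steps can be done with Python's standard `zipfile` module. This is a minimal sketch; `extract_archives` is a hypothetical helper, and it assumes the three archives sit in the code root as described above.

```python
import zipfile
from pathlib import Path


def extract_archives(root: Path = Path("."), delete_zips: bool = True) -> None:
    """Python equivalent of the mkdir/unzip/rm commands above (hypothetical helper)."""
    datasets = root / "datasets"
    datasets.mkdir(exist_ok=True)  # mkdir datasets
    targets = {
        "TSFormer_CKPT.zip": root,     # unzip TSFormer_CKPT.zip
        "raw_data.zip": datasets,      # unzip raw_data.zip -d datasets
        "sensor_graph.zip": datasets,  # unzip sensor_graph.zip -d datasets
    }
    for name, dest in targets.items():
        with zipfile.ZipFile(root / name) as archive:
            archive.extractall(dest)
    if delete_zips:
        # rm *.zip (materialize the glob first, then delete)
        for zip_path in list(root.glob("*.zip")):
            zip_path.unlink()
```

Call `extract_archives()` from the code root; pass `delete_zips=False` to keep the archives.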

TSFormer_CKPT/ contains the pre-trained model for each dataset.

You can also find all the training logs of the pre-training and forecasting stages in train_logs/.

3.2 Preprocess Data

python datasets/raw_data/$DATASET_NAME/generate_data.py

Replace $DATASET_NAME with one of METR-LA, PEMS-BAY, PEMS04.

The processed data is placed in datasets/$DATASET_NAME.
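To preprocess all three datasets in one go, a small driver script can loop over them. The sketch below assumes it is run from the code root (the same working directory as the command above); `preprocess_commands` and `preprocess_all` are hypothetical helpers, not part of the repository.

```python
import subprocess
import sys
from typing import List, Sequence

DATASETS = ("METR-LA", "PEMS-BAY", "PEMS04")


def preprocess_commands(datasets: Sequence[str] = DATASETS) -> List[List[str]]:
    """Build one `python datasets/raw_data/<name>/generate_data.py` call per dataset."""
    return [[sys.executable, f"datasets/raw_data/{name}/generate_data.py"] for name in datasets]


def preprocess_all(datasets: Sequence[str] = DATASETS, dry_run: bool = False) -> None:
    """Run (or, with dry_run=True, only print) each preprocessing script in turn."""
    for cmd in preprocess_commands(datasets):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing script
```

`preprocess_all(dry_run=True)` only prints the commands, which is a cheap way to check the paths before running anything.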

4. Train STEP based on a Pre-trained TSFormer

python main.py --cfg="config/$DATASET_NAME/forecasting.py" --gpu='0, 1'
# python main.py --cfg='config/METR-LA/forecasting.py' --gpu='0, 1'
# python main.py --cfg='config/PEMS-BAY/forecasting.py' --gpu='0, 1'
# python main.py --cfg='config/PEMS04/forecasting.py' --gpu='0, 1'

Replace $DATASET_NAME with one of METR-LA, PEMS-BAY, PEMS04 as shown in the code above.
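If you prefer launching runs from Python, for example to loop over datasets, the sketch below builds the same command lines as above. `forecasting_command` is a hypothetical helper; it uses only the `--cfg` and `--gpu` flags shown in this README.

```python
import shlex
import sys
from typing import List, Sequence

DATASETS = ("METR-LA", "PEMS-BAY", "PEMS04")


def forecasting_command(dataset: str, gpus: Sequence[int] = (0, 1)) -> List[str]:
    """Build the `python main.py --cfg=... --gpu=...` call for the forecasting stage."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    gpu_arg = ", ".join(str(g) for g in gpus)  # same '0, 1' format as the shell examples
    return [sys.executable, "main.py", f"--cfg=config/{dataset}/forecasting.py", f"--gpu={gpu_arg}"]


for name in DATASETS:
    # Print a shell-quoted version for inspection; pass the list to
    # subprocess.run(cmd, check=True) to actually start training.
    print(shlex.join(forecasting_command(name)))
```

Passing the argument list directly to `subprocess.run` sidesteps shell quoting of the dataset name. The GPU list must still agree with GPU_NUM in the configuration file.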

Configuration file config/$DATASET_NAME/forecasting.py describes the forecasting configurations.

By default, we use 2 GPUs for the forecasting stage. To run on your own hardware, edit the GPU_NUM property in the configuration file and the --gpu argument on the command line.

Note that the number of GPUs changes the effective (total) batch size, which affects both the appropriate learning rate and the forecasting accuracy.
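The arithmetic behind this note is simple, assuming (as the note implies) that BATCH_SIZE in the configuration file is a per-GPU value. The learning-rate adjustment below uses the linear scaling rule, a common heuristic for data-parallel training; it is not necessarily how the learning rates in this repository were tuned, and the numbers are hypothetical.

```python
def effective_batch_size(per_gpu_batch: int, gpu_num: int) -> int:
    """Total samples per optimizer step when each GPU processes its own batch."""
    return per_gpu_batch * gpu_num


def scaled_lr(base_lr: float, base_gpu_num: int, new_gpu_num: int) -> float:
    """Linear scaling heuristic: keep lr / effective batch size roughly constant."""
    return base_lr * new_gpu_num / base_gpu_num


# Hypothetical numbers: a per-GPU batch of 32 and a learning rate of 0.005
# tuned for the default 2 GPUs, moved to a single-GPU machine.
print(effective_batch_size(32, 2), effective_batch_size(32, 1))  # 64 32
print(scaled_lr(0.005, base_gpu_num=2, new_gpu_num=1))  # 0.0025
```

Treat the scaled value as a starting point for tuning rather than a guaranteed match to the 2-GPU results.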

Our training logs are shown in train_logs/Backend_metr.log, train_logs/Backend_pems04.log, and train_logs/Backend_pemsbay.log.

5. Train STEP from Scratch

5.1 Pre-training Stage

python main.py --cfg="config/$DATASET_NAME/pretraining.py" --gpu='0'
# python main.py --cfg='config/METR-LA/pretraining.py' --gpu='0'
# python main.py --cfg='config/PEMS-BAY/pretraining.py' --gpu='0, 1, 2, 3, 4, 5, 6, 7'
# python main.py --cfg='config/PEMS04/pretraining.py' --gpu='0, 1'

Replace $DATASET_NAME with one of METR-LA, PEMS-BAY, PEMS04 as shown in the code above.

Configuration file config/$DATASET_NAME/pretraining.py describes the pre-training configurations.

To run on your own hardware, edit BATCH_SIZE and GPU_NUM in the configuration file, and the --gpu argument on the command line.
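The example pre-training commands use a different GPU set for each dataset. The sketch below records those GPU lists and builds the matching command, optionally capping the number of GPUs. `pretraining_command` and `PRETRAIN_GPUS` are hypothetical; capping the GPU list does not edit the configuration file, so GPU_NUM (and BATCH_SIZE, if needed) must still be changed by hand.

```python
import shlex
import sys
from typing import Dict, List, Tuple

# GPU ids taken from the example pre-training commands above.
PRETRAIN_GPUS: Dict[str, Tuple[int, ...]] = {
    "METR-LA": (0,),
    "PEMS-BAY": tuple(range(8)),  # '0, 1, 2, 3, 4, 5, 6, 7'
    "PEMS04": (0, 1),
}


def pretraining_command(dataset: str, max_gpus: int = 8) -> List[str]:
    """Build the pre-training call, keeping at most `max_gpus` of the listed GPUs.

    This does not touch the configuration file: GPU_NUM (and BATCH_SIZE, if
    needed) in config/<dataset>/pretraining.py must be edited to match.
    """
    if max_gpus < 1:
        raise ValueError("max_gpus must be at least 1")
    gpus = PRETRAIN_GPUS[dataset][:max_gpus]
    gpu_arg = ", ".join(str(g) for g in gpus)
    return [sys.executable, "main.py", f"--cfg=config/{dataset}/pretraining.py", f"--gpu={gpu_arg}"]


# Example: PEMS-BAY on a 2-GPU machine instead of the 8 GPUs used above.
print(shlex.join(pretraining_command("PEMS-BAY", max_gpus=2)))
```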

5.2 Forecasting Stage

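Copy your pre-trained checkpoints to TSFormer_CKPT/, naming each one TSFormer_$DATASET_NAME.pt. Replace the MD5 directory and the checkpoint epoch with those from your own run. For example: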

cp checkpoints/TSFormer_200/9b4b52e25a30aabd21dc1c9429063196/TSFormer_180.pt TSFormer_CKPT/TSFormer_PEMS-BAY.pt
cp checkpoints/TSFormer_200/fac3814778135a6d46063e3cab20257c/TSFormer_147.pt TSFormer_CKPT/TSFormer_PEMS04.pt
cp checkpoints/TSFormer_200/3de38a467aef981dd6f24127b6fb5f50/TSFormer_030.pt TSFormer_CKPT/TSFormer_METR-LA.pt

Then train the downstream STGNN (Graph WaveNet) as in Section 4.
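The copy step can also be scripted. The sketch below assumes checkpoint files follow the TSFormer_<epoch>.pt naming seen in the cp examples; `list_checkpoints` and `install_checkpoint` are hypothetical helpers. Choosing which epoch to install is left to you: the examples use a different epoch for each dataset.

```python
import re
import shutil
from pathlib import Path
from typing import List, Tuple

# Assumed naming, inferred from the examples: TSFormer_180.pt, TSFormer_030.pt, ...
CKPT_PATTERN = re.compile(r"TSFormer_(\d+)\.pt")


def list_checkpoints(run_dir: Path) -> List[Tuple[int, Path]]:
    """Return (epoch, path) pairs for every TSFormer_<epoch>.pt in run_dir, sorted by epoch."""
    found = []
    for path in run_dir.glob("TSFormer_*.pt"):
        match = CKPT_PATTERN.fullmatch(path.name)
        if match:
            found.append((int(match.group(1)), path))
    return sorted(found)


def install_checkpoint(src: Path, dataset: str, dest_dir: Path = Path("TSFormer_CKPT")) -> Path:
    """Copy a chosen checkpoint to <dest_dir>/TSFormer_<dataset>.pt, like the cp commands."""
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / f"TSFormer_{dataset}.pt"
    shutil.copyfile(src, dest)
    return dest
```

For example, `install_checkpoint(Path("checkpoints/TSFormer_200/fac3814778135a6d46063e3cab20257c/TSFormer_147.pt"), "PEMS04")` performs the same copy as the second cp command.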

6. Performance and Visualization

[Table: performance results]

[Figure: visualization]

7. More Related Works

8. Citing

If you find this repository useful for your work, please consider citing it as follows:

@inproceedings{DBLP:conf/kdd/ShaoZWX22,
  author    = {Zezhi Shao and
               Zhao Zhang and
               Fei Wang and
               Yongjun Xu},
  title     = {Pre-training Enhanced Spatial-temporal Graph Neural Network for Multivariate
               Time Series Forecasting},
  booktitle = {{KDD} '22: The 28th {ACM} {SIGKDD} Conference on Knowledge Discovery
               and Data Mining, Washington, DC, USA, August 14 - 18, 2022},
  pages     = {1567--1577},
  publisher = {{ACM}},
  year      = {2022}
}
