
Commit b03976c: update readme
moreOver0 committed Jun 11, 2020 (1 parent: ec6c02d)
Showing 1 changed file (README.md) with 20 additions and 25 deletions.
pip install --upgrade pip
pip install -r requirements.txt
```



## Datasets


[PEMS03](http://pems.dot.ca.gov/?dnode=Clearinghouse&type=station_5min&district_id=3&submit=Submit),
[PEMS04](http://pems.dot.ca.gov/?dnode=Clearinghouse&type=station_5min&district_id=4&submit=Submit),
[PEMS07](http://pems.dot.ca.gov/?dnode=Clearinghouse&type=station_5min&district_id=7&submit=Submit),
[ECG5000](http://www.timeseriesclassification.com/description.php?Dataset=ECG5000)
[COVID-19](https://github.com/CSSEGISandData/COVID-19/tree/master)

We can get the raw data through the links above. We evaluate the performance of traffic flow forecasting on PEMS03, PEMS07, PEMS08 and traffic speed forecasting on PEMS04, PEMS-BAY and METR-LA, so we use the traffic flow tables of PEMS03, PEMS07, PEMS08 and the traffic speed tables of PEMS04, PEMS-BAY and METR-LA as our datasets. We download the solar power data of Alabama (Eastern States) and merge the 5-minute csv files (137 time series in total) as our Solar dataset. We delete the header and index of the Electricity file downloaded from the link above as our Electricity dataset. For the COVID-19 dataset, the raw data is under the folder `csse_covid_19_data/csse_covid_19_time_series/` of the GitHub link above. We use `time_series_covid19_confirmed_global.csv` to calculate the daily number of newly confirmed infected people from 1/22/2020 to 5/10/2020. The 25 countries we take into consideration are 'US', 'Canada', 'Mexico', 'Russia', 'UK', 'Italy', 'Germany', 'France', 'Belarus', 'Brazil', 'Peru', 'Ecuador', 'Chile', 'India', 'Turkey', 'Saudi Arabia', 'Pakistan', 'Iran', 'Singapore', 'Qatar', 'Bangladesh', 'Arab', 'China', 'Japan', 'Korea'. We name each file after its dataset.
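
The repo does not show the COVID-19 preprocessing script itself; assuming the standard JHU layout (one row per region, cumulative confirmed counts in the date columns), computing daily new cases can be sketched as below. The toy values, column subset, and output filename are illustrative only:

```python
import pandas as pd

# Tiny frame mimicking time_series_covid19_confirmed_global.csv:
# one row per region, cumulative confirmed counts in the date columns.
raw = pd.DataFrame({
    "Country/Region": ["US", "Italy"],
    "1/22/20": [1, 0],
    "1/23/20": [1, 2],
    "1/24/20": [4, 3],
})

cumulative = raw.set_index("Country/Region")
# Day-over-day difference turns cumulative totals into daily new cases;
# the first day has no predecessor, so keep its cumulative value.
daily_new = cumulative.diff(axis=1).fillna(cumulative).astype(int)

# Transpose to the T*N layout StemGNN expects (rows = timestamps,
# columns = countries) and save without header or index.
daily_new.T.to_csv("covid_daily.csv", header=False, index=False)
```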

The input csv file should have **no header** and its **shape should be `T*N`**, where `N` denotes the number of nodes and `T` the total number of timestamps.
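
Producing a file in that format can be sketched with numpy; the toy data and filename below are hypothetical:

```python
import numpy as np

# Toy series: T=96 timestamps for N=3 nodes (values are arbitrary).
T, N = 96, 3
data = np.random.rand(T, N)

# No header, no index column: just T rows of N comma-separated values.
np.savetxt("toy_dataset.csv", data, delimiter=",")

# Reading it back confirms the T*N shape.
loaded = np.loadtxt("toy_dataset.csv", delimiter=",")
assert loaded.shape == (T, N)
```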


Since the datasets above require complex data cleansing before they can be fed into the StemGNN model, we provide a cleaned version of ECG5000 ([./dataset/ECG_data.csv](./dataset/ECG_data.csv)) for convenient reproduction. ECG_data.csv has shape `5000*140`, where `5000` is the number of timestamps and `140` is the number of nodes. Run `python main.py` to trigger training and evaluation on ECG_data.csv.

## Training and Evaluation

Both the training and evaluation procedures are included in `main.py`. To train and evaluate on a dataset, run the following command:

```train & evaluate
python main.py --train True --evaluate True --dataset <path to csv file> --output_dir <path to output directory> --n_route 140 --n_his 12 --n_pred 3 --scalar z_score --train_length 6 --validate_length 2 --test_length 2
```

We set the flag `train` to `True` to train the model, and the flag `evaluate` to `True` to evaluate it after the model has been saved to `output_dir`. StemGNN reads data from `dataset`. Besides, the flag `n_route` is the number of time series, `n_his` is the sliding-window length, and `n_pred` is the forecasting horizon.
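
To make the meaning of `n_his` and `n_pred` concrete, here is an illustrative sliding-window construction over a `T*N` array. This is a sketch of the windowing idea only, not the repo's actual data loader:

```python
import numpy as np

def make_windows(series, n_his, n_pred):
    """Pair each window of n_his past steps with the n_pred steps to predict."""
    X, Y = [], []
    for t in range(len(series) - n_his - n_pred + 1):
        X.append(series[t : t + n_his])              # model input
        Y.append(series[t + n_his : t + n_his + n_pred])  # forecast target
    return np.stack(X), np.stack(Y)

series = np.arange(40).reshape(20, 2)  # T=20 timestamps, N=2 nodes
X, Y = make_windows(series, n_his=12, n_pred=3)
print(X.shape, Y.shape)  # (6, 12, 2) (6, 3, 2)
```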


### Complete settings for all datasets

**Table 1** Description of parameters
| Parameter | Description |
| --- | --- |
| train | whether to enable training |
| evaluate | whether to enable evaluation |
| dataset | path to the input csv file |
| output_dir | output directory; models and TensorBoard logs are stored here, and evaluation restores the model from it |
| scalar | normalization method (`z_score` or `min_max`) |
| train_length | length of the training split |
| validate_length | length of the validation split |
| test_length | length of the test split |
| n_route | number of time series, i.e. number of nodes |
| n_his | length of the sliding window |
| n_pred | prediction horizon |
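
The sample command uses `--train_length 6 --validate_length 2 --test_length 2`, which reads naturally as a 6:2:2 split of the timestamps. A sketch of such a ratio split follows; treating the three lengths as ratio parts is an assumption about the flags' semantics, not the repo's code:

```python
import numpy as np

def split_by_ratio(data, train=6, valid=2, test=2):
    """Split along the time axis in the proportion train:valid:test."""
    total = train + valid + test
    T = len(data)
    i = int(T * train / total)
    j = int(T * (train + valid) / total)
    return data[:i], data[i:j], data[j:]

data = np.arange(100)  # toy series of 100 timestamps
tr, va, te = split_by_ratio(data)
print(len(tr), len(va), len(te))  # 60 20 20
```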

**Table 2** Configurations for all datasets
| Dataset | train | evaluate | n_route | n_his | n_pred | scalar |
| ----- | ---- | ---- |---- |---- |---- | --- |
| PEMS03 | True | True | 358 | 12 | 12 | z_score |
| ECG5000| True | True | 140 | 12 | 3 | min_max |
| COVID-19| True | True | 25 | 28 | 28 | z_score |


In this code repo, ECG5000 has been processed as the sample dataset; the input is stored at `./dataset/ECG_data.csv` and the output of StemGNN will be stored at `./output/ECG_data`.

## Results

Our model achieves the following performance on the 9 datasets included in the code repo:
