Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
moreOver0 committed Jun 11, 2020
1 parent 17a917d commit 7d4b2d4
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ pip install -r requirements.txt
[PEMS-BAY](https://github.com/liyaguang/DCRNN),
[Solar](https://www.nrel.gov/grid/solar-power-data.html),
[Electricity](https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014),
[ECG5000](http://www.timeseriesclassification.com/description.php?Dataset=ECG5000)
[ECG5000](http://www.timeseriesclassification.com/description.php?Dataset=ECG5000),
[COVID-19](https://github.com/CSSEGISandData/COVID-19/tree/master)

We can get the raw data through the links above. We evaluate the performance of traffic flow forecasting on PEMS03, PEMS07, PEMS08 and traffic speed forecasting on PEMS04, PEMS-BAY and METR-LA. So we use the traffic flow table of PEMS03, PEMS07, PEMS08 and the traffic speed table of PEMS04, PEMS-BAY and METR-LA as our datasets. We download the solar power data of Alabama (Eastern States) and merge the 5-minute csv files (totally 137 time series) as our Solar dataset. We delete the header and index of Electricity file downloaded from the link above as our Electricity dataset. For COVID-19 dataset, the raw data is under the folder `csse_covid_19_data/csse_covid_19_time_series/` of the above github link. We use `time_series_covid19_confirmed_global.csv` to calculate the daily number of newly confirmed infected people from 1/22/2020 to 5/10/2020. The 25 countries we take into consideration are 'US','Canada','Mexico','Russia','UK','Italy','Germany','France','Belarus ','Brazil','Peru','Ecuador','Chile','India','Turkey','Saudi Arabia','Pakistan','Iran','Singapore','Qatar','Bangladesh','Arab','China','Japan','Korea'.

The input csv file should have **no header** and its **shape should be `T*N`**, where `N` denotes number of nodes, `T` denotes total number of timestamps.
The input csv file should contain **no header** and its **shape should be `T*N`**, where `T` denotes total number of timestamps, `N` denotes number of nodes.

Since complex data cleansing is needed on the above datasets provided in the urls before fed into the StemGNN model, we provide a cleaned version of ECG5000 ([./dataset/ECG_data.csv](./dataset/ECG_data.csv)) for reproduction convenience. The ECG_data.csv is in shape of `5000*140`, where `5000` denotes number of timestamps and `140` denotes total number of nodes. Run command `python main.py` to trigger training and evaluation on ECG_data.csv.

Expand All @@ -43,22 +43,23 @@ Since complex data cleansing is needed on the above datasets provided in the url
The training procedure and evaluation procedure are all included in the `main.py`. To train and evaluate on some dataset, run the following command:

```train & evaluate
python main.py --train True --evaluate True --dataset <path to csv file> --output_dir <path to output directory> --n_route 140 --n_his 12 --n_pred 3 --scalar z_score --train_length 6 --validate_length 2 --test_length 2
python main.py --train True --evaluate True --dataset <path to csv file> --output_dir <path to output directory> --n_route <number of nodes> --n_his <length of sliding window> --n_pred <predict horizon> --scalar z_score --train_length 6 --validate_length 2 --test_length 2
```

The detailed descriptions about the parameters are as following:
| Parameter name | Description of parameter |
| --- | --- |
| train | whether to enable training |
| evaluate | whether to enable evaluation |
| train | whether to enable training, default True |
| evaluate | whether to enable evaluation, default True |
| dataset | path to the input csv file |
| output_dir | output directory, will store models and tensorboard information in this directory, and evaluation restores model from this directory |
| scalar | method to normalize |
| train_length | length of training data |
| validate_length | length of validation data |
| test_length | length of testing data |
| scalar | method to normalize, 'z_score' or 'min_max' |
| train_length | length of training data, default 6 |
| validate_length | length of validation data, default 2 |
| test_length | length of testing data, default 2 |
| n_route | number of time series, i.e. number of nodes |
| n_his | length of sliding window |
| n_pred | predict horizon |
| n_his | length of sliding window, default 12 |
| n_pred | predict horizon, default 3 |

**Table 1** Configurations for all datasets
| Dataset | train | evaluate | n_route | n_his | n_pred | scalar |
Expand Down

0 comments on commit 7d4b2d4

Please sign in to comment.