forked from luckyyangrun/pytorch-forecasting
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request sktime#39 from jdb78/feature/apply_isort
Apply isort
- Loading branch information
Showing
29 changed files
with
221 additions
and
145 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,13 +21,24 @@ jobs: | |
python-version: 3.8 | ||
|
||
- name: Install Python dependencies | ||
run: pip install black flake8 mypy | ||
run: pip install -U black flake8 mypy isort | ||
|
||
- name: Run linters | ||
- name: Run style linters | ||
uses: wearerequired/lint-action@v1 | ||
with: | ||
github_token: ${{ secrets.github_token }} | ||
# Enable linters | ||
black: true | ||
flake8: true | ||
# mypy: true | ||
|
||
- name: Run additional linters | ||
uses: ricardochaves/[email protected] | ||
with: | ||
python-root-list: "pytorch_forecasting examples tests" | ||
use-pylint: false | ||
use-pycodestyle: false | ||
use-flake8: false | ||
use-black: false | ||
use-mypy: false | ||
use-isort: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Training script: fit a TemporalFusionTransformer on pre-serialized
# pytorch-forecasting TimeSeriesDataSet objects using PyTorch Lightning.
# (Uses the pre-1.0 Lightning API: LearningRateLogger and the
# early_stop_callback= Trainer argument.)
import copy
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Load pre-built datasets from disk.
# NOTE(review): assumes train.pkl / valid.pkl were written with torch.save
# and deserialize to TimeSeriesDataSet instances — confirm against the
# script that produced them.
training = torch.load("train.pkl")
validation = torch.load("valid.pkl")

# create dataloaders for model
batch_size = 128  # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
# Validation uses a larger batch (no gradients kept, so memory allows it).
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# Stop training once validation loss has not improved by >= 1e-4 for 10 epochs.
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
# Logs the learning rate over time (useful with LR schedulers).
lr_logger = LearningRateLogger()

# configure network and trainer
trainer = pl.Trainer(
    max_epochs=100,
    gpus=0,  # CPU-only run
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=1e-3,
    # cap each epoch at 30 training batches for faster iteration
    limit_train_batches=30,
    # fast_dev_run=True,
    early_stop_callback=early_stop_callback,
    callbacks=[lr_logger],
)


# Build the model from the dataset so input/output dimensions, encoders,
# and normalizers are inferred automatically.
tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.15,
    hidden_size=16,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,  # log example predictions every 10 batches
    # reduce learning rate if no improvement in validation loss after x epochs
    # reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find optimal learning rate (optional LR-range test; kept for reference)
# res = trainer.lr_find(
#     tft,
#     train_dataloader=train_dataloader,
#     val_dataloaders=val_dataloader,
#     max_lr=10.0,
#     min_lr=1e-9,
#     early_stop_threshold=1e10,
# )

# print(f"suggested learning rate: {res.suggestion()}")
# fig = res.plot(show=True, suggest=True)
# fig.show()

# Run the actual training loop (early stopping may end it before max_epochs).
trainer.fit(tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.