Skip to content

Commit

Permalink
Merge pull request sktime#39 from jdb78/feature/apply_isort
Browse files Browse the repository at this point in the history
Apply isort
  • Loading branch information
jdb78 authored Sep 9, 2020
2 parents e245b22 + 0a9b389 commit a096fb8
Show file tree
Hide file tree
Showing 29 changed files with 221 additions and 145 deletions.
15 changes: 13 additions & 2 deletions .github/workflows/code_quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,24 @@ jobs:
python-version: 3.8

- name: Install Python dependencies
run: pip install black flake8 mypy
run: pip install -U black flake8 mypy isort

- name: Run linters
- name: Run style linters
uses: wearerequired/lint-action@v1
with:
github_token: ${{ secrets.github_token }}
# Enable linters
black: true
flake8: true
# mypy: true

- name: Run additional linters
uses: ricardochaves/[email protected]
with:
python-root-list: "pytorch_forecasting examples tests"
use-pylint: false
use-pycodestyle: false
use-flake8: false
use-black: false
use-mypy: false
use-isort: true
14 changes: 8 additions & 6 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
from pathlib import Path
import shutil
import subprocess
from pathlib import Path
import sys

from sphinx.ext import apidoc

SOURCE_PATH = Path(os.path.dirname(__file__)) # docs source
PROJECT_PATH = SOURCE_PATH.joinpath("../..") # project root
SOURCE_PATH = Path(os.path.dirname(__file__)) # noqa # docs source
PROJECT_PATH = SOURCE_PATH.joinpath("../..") # noqa # project root

sys.path.insert(0, str(PROJECT_PATH)) # noqa

sys.path.insert(0, str(PROJECT_PATH)) # isort:skip
import pytorch_forecasting # isort:skip

import pytorch_forecasting # noqa

# -- Project information -----------------------------------------------------

Expand Down
22 changes: 9 additions & 13 deletions examples/ar.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
from pathlib import Path
import pickle
import warnings


import numpy as np
import pandas as pd
from pandas.core.common import SettingWithCopyWarning
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, GroupNormalizer, EncoderNormalizer
from pathlib import Path
import pandas as pd
import numpy as np

from pytorch_forecasting.metrics import MAE, PoissonLoss, QuantileLoss, SMAPE, RMSE
from pytorch_forecasting import EncoderNormalizer, GroupNormalizer, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_forecasting.metrics import MAE, RMSE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.utils import profile
from pytorch_forecasting.data import NaNLabelEncoder

from pandas.core.common import SettingWithCopyWarning

warnings.simplefilter("error", category=SettingWithCopyWarning)


from pytorch_forecasting.data.examples import generate_ar_data

data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)
data["static"] = "2"
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
Expand Down
8 changes: 5 additions & 3 deletions examples/nbeats.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import sys

sys.path.append("..")
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.preprocessing import scale

from pytorch_forecasting import TimeSeriesDataSet, NBeats
from pytorch_forecasting import NBeats, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from sklearn.preprocessing import scale
from pytorch_forecasting.data.examples import generate_ar_data

sys.path.append("..")


print("load data")
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)
data["static"] = 2
Expand Down
19 changes: 8 additions & 11 deletions examples/stallion.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
from pathlib import Path
import pickle
import warnings

import numpy as np
import pandas as pd
from pandas.core.common import SettingWithCopyWarning
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, GroupNormalizer
from pathlib import Path
import pandas as pd
import numpy as np

from pytorch_forecasting.metrics import MAE, PoissonLoss, QuantileLoss, SMAPE, RMSE
from pytorch_forecasting import GroupNormalizer, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data.examples import get_stallion_data
from pytorch_forecasting.metrics import MAE, RMSE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.utils import profile

from pandas.core.common import SettingWithCopyWarning

warnings.simplefilter("error", category=SettingWithCopyWarning)


from pytorch_forecasting.data.examples import get_stallion_data

data = get_stallion_data()

data["month"] = data.date.dt.month.astype("str").astype("category")
Expand Down
73 changes: 73 additions & 0 deletions examples/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Example / smoke-test script: fit a TemporalFusionTransformer on pickled datasets.

Loads pre-built training/validation datasets from ``train.pkl`` / ``valid.pkl``
in the current working directory, builds dataloaders, configures a PyTorch
Lightning trainer, instantiates a TFT from the training dataset, and fits it.
"""
import copy
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# NOTE(review): assumes the pickles contain dataset objects exposing
# ``to_dataloader`` (presumably TimeSeriesDataSet saved via ``torch.save``) —
# confirm against the script that produced them.
training = torch.load("train.pkl")
validation = torch.load("valid.pkl")

# create dataloaders for model
batch_size = 128  # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
# validation can use a larger batch since no gradients are accumulated
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# stop once val_loss has not improved by at least 1e-4 for 10 epochs
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateLogger()  # logs the learning rate over training

# configure network and trainer
trainer = pl.Trainer(
    max_epochs=100,
    gpus=0,  # CPU-only run
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=1e-3,
    limit_train_batches=30,  # cap batches per epoch to keep the run short
    # fast_dev_run=True,
    # NOTE(review): ``early_stop_callback=`` is the pre-1.0 Lightning API —
    # verify against the pinned pytorch-lightning version.
    early_stop_callback=early_stop_callback,
    callbacks=[lr_logger],
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.15,
    hidden_size=16,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,
    # reduce learning rate if no improvement in validation loss after x epochs
    # reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find optimal learning rate
# res = trainer.lr_find(
#     tft,
#     train_dataloader=train_dataloader,
#     val_dataloaders=val_dataloader,
#     max_lr=10.0,
#     min_lr=1e-9,
#     early_stop_threshold=1e10,
# )

# print(f"suggested learning rate: {res.suggestion()}")
# fig = res.plot(show=True, suggest=True)
# fig.show()

trainer.fit(tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
4 changes: 2 additions & 2 deletions pytorch_forecasting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
PyTorch Forecasting package for timeseries forecasting with PyTorch.
"""
from pytorch_forecasting.models import TemporalFusionTransformer, NBeats, Baseline
from pytorch_forecasting.data import TimeSeriesDataSet, GroupNormalizer, EncoderNormalizer
from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, TimeSeriesDataSet
from pytorch_forecasting.models import Baseline, NBeats, TemporalFusionTransformer

__all__ = [
"TimeSeriesDataSet",
Expand Down
2 changes: 1 addition & 1 deletion pytorch_forecasting/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Handling timeseries data is not trivial. It requires special treatment. This sub-package provides the necessary tools
to abstract the necessary work.
"""
from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TorchNormalizer
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import NaNLabelEncoder, GroupNormalizer, TorchNormalizer, EncoderNormalizer

__all__ = ["TimeSeriesDataSet", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", "EncoderNormalizer"]
7 changes: 3 additions & 4 deletions pytorch_forecasting/data/encoders.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""
Encoders for encoding categorical variables and scaling continuous data.
"""
from typing import Dict, Iterable, List, Tuple, Union
import warnings
from typing import Union, Dict, List, Tuple, Iterable
import pandas as pd
import numpy as np

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

import torch
import torch.nn.functional as F

Expand Down
3 changes: 2 additions & 1 deletion pytorch_forecasting/data/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Example datasets for tutorials and testing.
"""
from pathlib import Path
import pandas as pd

import numpy as np
import pandas as pd
import requests

BASE_URL = "https://raw.github.com/jdb78/pytorch-forecasting/master/examples/data/"
Expand Down
20 changes: 8 additions & 12 deletions pytorch_forecasting/data/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,23 @@
Timeseries data is special and has to be processed and fed to algorithms in a special way. This module
defines a class that is able to handle a wide variety of timeseries data problems.
"""
import warnings
from copy import deepcopy
import inspect
from typing import Union, Dict, List, Tuple, Any
from typing import Any, Dict, List, Tuple, Union
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import pandas as pd
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted
import torch
from torch.distributions import Beta
from torch.nn.utils import rnn
from torch.utils.data import Dataset, DataLoader

from sklearn.utils.validation import check_is_fitted
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

from pytorch_forecasting.data.encoders import NaNLabelEncoder, GroupNormalizer, EncoderNormalizer, TorchNormalizer
from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TorchNormalizer


class TimeSeriesDataSet(Dataset):
Expand Down Expand Up @@ -887,14 +885,12 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:

if self._overwrite_values["variable"] in self.reals:
idx = self.reals.index(self._overwrite_values["variable"])
data_cont = data_cont # not to overwrite original data
data_cont[positions, idx] = self._overwrite_values["values"]
else:
assert (
self._overwrite_values["variable"] in self.flat_categoricals
), "overwrite values variable has to be either in real or categorical variables"
idx = self.flat_categoricals.index(self._overwrite_values["variable"])
data_cat = data_cat # not to overwrite original data
data_cat[positions, idx] = self._overwrite_values["values"]

return (
Expand Down
7 changes: 3 additions & 4 deletions pytorch_forecasting/metrics.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""
Implementation of metrics for (multi-horizon) timeseries forecasting.
"""
import abc
from typing import Dict, List, Union

from pytorch_lightning.metrics import TensorMetric
import scipy.stats
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import rnn
import abc
from pytorch_lightning.metrics import TensorMetric

import scipy.stats


class Metric(TensorMetric):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_forecasting/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
Models for timeseries forecasting.
"""
from pytorch_forecasting.models.base_model import BaseModel
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
from pytorch_forecasting.models.nbeats import NBeats
from pytorch_forecasting.models.baseline import Baseline
from pytorch_forecasting.models.nbeats import NBeats
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer

__all__ = ["NBeats", "TemporalFusionTransformer", "BaseModel", "Baseline"]
29 changes: 13 additions & 16 deletions pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,27 @@
"""
from copy import deepcopy
import inspect
from pytorch_forecasting.data.encoders import GroupNormalizer
from torch import unsqueeze
from torch import optim
import cloudpickle

from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from pytorch_forecasting.metrics import SMAPE
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import cloudpickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics.metric import TensorMetric
from pytorch_forecasting.optim import Ranger
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
import torch
import numpy as np
import pandas as pd
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, OneCycleLR
from torch import optim, unsqueeze
from torch.optim.lr_scheduler import LambdaLR, OneCycleLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE
from pytorch_forecasting.optim import Ranger
from pytorch_forecasting.utils import groupby_apply

from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args


class BaseModel(LightningModule):
"""
Expand Down
Loading

0 comments on commit a096fb8

Please sign in to comment.