diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml index cf41635d..222ecff7 100644 --- a/.github/workflows/code_quality.yml +++ b/.github/workflows/code_quality.yml @@ -21,9 +21,9 @@ jobs: python-version: 3.8 - name: Install Python dependencies - run: pip install black flake8 mypy + run: pip install -U black flake8 mypy isort - - name: Run linters + - name: Run style linters uses: wearerequired/lint-action@v1 with: github_token: ${{ secrets.github_token }} @@ -31,3 +31,14 @@ jobs: black: true flake8: true # mypy: true + + - name: Run additional linters + uses: ricardochaves/python-lint@v1.3.0 + with: + python-root-list: "pytorch_forecasting examples tests" + use-pylint: false + use-pycodestyle: false + use-flake8: false + use-black: false + use-mypy: false + use-isort: true diff --git a/docs/source/conf.py b/docs/source/conf.py index 80b83a25..f67233a5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,18 +11,20 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import os -import sys +from pathlib import Path import shutil import subprocess -from pathlib import Path +import sys + from sphinx.ext import apidoc -SOURCE_PATH = Path(os.path.dirname(__file__)) # docs source -PROJECT_PATH = SOURCE_PATH.joinpath("../..") # project root +SOURCE_PATH = Path(os.path.dirname(__file__)) # noqa # docs source +PROJECT_PATH = SOURCE_PATH.joinpath("../..") # noqa # project root + +sys.path.insert(0, str(PROJECT_PATH)) # noqa -sys.path.insert(0, str(PROJECT_PATH)) # isort:skip +import pytorch_forecasting # isort:skip -import pytorch_forecasting # noqa # -- Project information ----------------------------------------------------- diff --git a/examples/ar.py b/examples/ar.py index 7be9b8b3..b2ea145b 100644 --- a/examples/ar.py +++ b/examples/ar.py @@ -1,29 +1,25 @@ +from pathlib import Path import pickle import warnings - +import numpy as np +import pandas as pd +from pandas.core.common import SettingWithCopyWarning import pytorch_lightning as pl -import torch from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger from pytorch_lightning.loggers import TensorBoardLogger +import torch -from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, GroupNormalizer, EncoderNormalizer -from pathlib import Path -import pandas as pd -import numpy as np - -from pytorch_forecasting.metrics import MAE, PoissonLoss, QuantileLoss, SMAPE, RMSE +from pytorch_forecasting import EncoderNormalizer, GroupNormalizer, TemporalFusionTransformer, TimeSeriesDataSet +from pytorch_forecasting.data import NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data +from pytorch_forecasting.metrics import MAE, RMSE, SMAPE, PoissonLoss, QuantileLoss from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters from pytorch_forecasting.utils import profile -from pytorch_forecasting.data import NaNLabelEncoder - -from pandas.core.common import SettingWithCopyWarning warnings.simplefilter("error", category=SettingWithCopyWarning) -from pytorch_forecasting.data.examples import generate_ar_data - data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) data["static"] = "2" data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") diff --git a/examples/nbeats.py b/examples/nbeats.py index eeee3b8b..132af26f 100644 --- a/examples/nbeats.py +++ b/examples/nbeats.py @@ -1,15 +1,17 @@ import sys -sys.path.append("..") import pandas as pd import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping +from sklearn.preprocessing import scale -from pytorch_forecasting import TimeSeriesDataSet, NBeats +from pytorch_forecasting import NBeats, TimeSeriesDataSet from pytorch_forecasting.data import NaNLabelEncoder -from sklearn.preprocessing import scale from pytorch_forecasting.data.examples import generate_ar_data +sys.path.append("..") + + print("load data") data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) data["static"] = 2 diff --git a/examples/stallion.py b/examples/stallion.py index 2c551c82..25c8057e 100644 --- a/examples/stallion.py +++ b/examples/stallion.py @@ -1,27 +1,24 @@ +from pathlib import Path import pickle import warnings +import numpy as np +import pandas as pd +from pandas.core.common import SettingWithCopyWarning import pytorch_lightning as pl -import torch from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger from pytorch_lightning.loggers import TensorBoardLogger +import torch -from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, GroupNormalizer -from pathlib import Path -import pandas as pd -import numpy as np - -from pytorch_forecasting.metrics import MAE, PoissonLoss, QuantileLoss, SMAPE, RMSE +from pytorch_forecasting import GroupNormalizer, TemporalFusionTransformer, TimeSeriesDataSet +from pytorch_forecasting.data.examples import get_stallion_data +from pytorch_forecasting.metrics import MAE, RMSE, SMAPE, PoissonLoss, QuantileLoss from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters from pytorch_forecasting.utils import profile -from pandas.core.common import SettingWithCopyWarning - warnings.simplefilter("error", category=SettingWithCopyWarning) -from pytorch_forecasting.data.examples import get_stallion_data - data = get_stallion_data() data["month"] = data.date.dt.month.astype("str").astype("category") diff --git a/examples/test.py b/examples/test.py new file mode 100644 index 00000000..3ade3bb5 --- /dev/null +++ b/examples/test.py @@ -0,0 +1,73 @@ +import copy +from pathlib import Path +import warnings + +import numpy as np +import pandas as pd +import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger +from pytorch_lightning.loggers import TensorBoardLogger +import torch + +from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet +from pytorch_forecasting.data import GroupNormalizer +from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss +from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters + +training = torch.load("train.pkl") +validation = torch.load("valid.pkl") + +# create dataloaders for model +batch_size = 128 # set this between 32 to 128 +train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0) +val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0) + +early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min") +lr_logger = LearningRateLogger() + +# configure network and trainer +trainer = pl.Trainer( + max_epochs=100, + gpus=0, + # clipping gradients is a hyperparameter and important to prevent divergance + # of the gradient for recurrent neural networks + gradient_clip_val=1e-3, + limit_train_batches=30, + # fast_dev_run=True, + early_stop_callback=early_stop_callback, + callbacks=[lr_logger], +) + + +tft = TemporalFusionTransformer.from_dataset( + training, + # not meaningful for finding the learning rate but otherwise very important + learning_rate=0.15, + hidden_size=16, # most important hyperparameter apart from learning rate + # number of attention heads. Set to up to 4 for large datasets + attention_head_size=1, + dropout=0.1, # between 0.1 and 0.3 are good values + hidden_continuous_size=8, # set to <= hidden_size + output_size=7, # 7 quantiles by default + loss=QuantileLoss(), + log_interval=10, + # reduce learning rate if no improvement in validation loss after x epochs + # reduce_on_plateau_patience=4, +) +print(f"Number of parameters in network: {tft.size()/1e3:.1f}k") + +# find optimal learning rate +# res = trainer.lr_find( +# tft, +# train_dataloader=train_dataloader, +# val_dataloaders=val_dataloader, +# max_lr=10.0, +# min_lr=1e-9, +# early_stop_threshold=1e10, +# ) + +# print(f"suggested learning rate: {res.suggestion()}") +# fig = res.plot(show=True, suggest=True) +# fig.show() + +trainer.fit(tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader) diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index a2fa1017..f7365591 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -1,8 +1,8 @@ """ PyTorch Forecasting package for timeseries forecasting with PyTorch. """ -from pytorch_forecasting.models import TemporalFusionTransformer, NBeats, Baseline -from pytorch_forecasting.data import TimeSeriesDataSet, GroupNormalizer, EncoderNormalizer +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, TimeSeriesDataSet +from pytorch_forecasting.models import Baseline, NBeats, TemporalFusionTransformer __all__ = [ "TimeSeriesDataSet", diff --git a/pytorch_forecasting/data/__init__.py b/pytorch_forecasting/data/__init__.py index bfa8ef67..a71bd9bc 100644 --- a/pytorch_forecasting/data/__init__.py +++ b/pytorch_forecasting/data/__init__.py @@ -4,7 +4,7 @@ Handling timeseries data is not trivial. It requires special treatment. This sub-package provides the necessary tools to abstracts the necessary work. """ +from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TorchNormalizer from pytorch_forecasting.data.timeseries import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder, GroupNormalizer, TorchNormalizer, EncoderNormalizer __all__ = ["TimeSeriesDataSet", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", "EncoderNormalizer"] diff --git a/pytorch_forecasting/data/encoders.py b/pytorch_forecasting/data/encoders.py index 38961701..7b431c08 100644 --- a/pytorch_forecasting/data/encoders.py +++ b/pytorch_forecasting/data/encoders.py @@ -1,13 +1,12 @@ """ Encoders for encoding categorical variables and scaling continuous data. """ +from typing import Dict, Iterable, List, Tuple, Union import warnings -from typing import Union, Dict, List, Tuple, Iterable -import pandas as pd -import numpy as np +import numpy as np +import pandas as pd from sklearn.base import BaseEstimator, TransformerMixin - import torch import torch.nn.functional as F diff --git a/pytorch_forecasting/data/examples.py b/pytorch_forecasting/data/examples.py index 762733af..1b0149d2 100644 --- a/pytorch_forecasting/data/examples.py +++ b/pytorch_forecasting/data/examples.py @@ -2,8 +2,9 @@ Example datasets for tutorials and testing. """ from pathlib import Path -import pandas as pd + import numpy as np +import pandas as pd import requests BASE_URL = "https://raw.github.com/jdb78/pytorch-forecasting/master/examples/data/" diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 35f87791..6c6bf3f3 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -4,25 +4,23 @@ Timeseries data is special and has to be processed and fed to algorithms in a special way. This module defines a class that is able to handle a wide variety of timeseries data problems. """ -import warnings from copy import deepcopy import inspect -from typing import Union, Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple, Union +import warnings import matplotlib.pyplot as plt -import pandas as pd import numpy as np - +import pandas as pd +from sklearn.exceptions import NotFittedError +from sklearn.preprocessing import StandardScaler +from sklearn.utils.validation import check_is_fitted import torch from torch.distributions import Beta from torch.nn.utils import rnn -from torch.utils.data import Dataset, DataLoader - -from sklearn.utils.validation import check_is_fitted -from sklearn.exceptions import NotFittedError -from sklearn.preprocessing import StandardScaler +from torch.utils.data import DataLoader, Dataset -from pytorch_forecasting.data.encoders import NaNLabelEncoder, GroupNormalizer, EncoderNormalizer, TorchNormalizer +from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TorchNormalizer class TimeSeriesDataSet(Dataset): @@ -887,14 +885,12 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: if self._overwrite_values["variable"] in self.reals: idx = self.reals.index(self._overwrite_values["variable"]) - data_cont = data_cont # not to overwrite original data data_cont[positions, idx] = self._overwrite_values["values"] else: assert ( self._overwrite_values["variable"] in self.flat_categoricals ), "overwrite values variable has to be either in real or categorical variables" idx = self.flat_categoricals.index(self._overwrite_values["variable"]) - data_cat = data_cat # not to overwrite original data data_cat[positions, idx] = self._overwrite_values["values"] return ( diff --git a/pytorch_forecasting/metrics.py b/pytorch_forecasting/metrics.py index ceaf6eaa..7dd8ab2d 100644 --- a/pytorch_forecasting/metrics.py +++ b/pytorch_forecasting/metrics.py @@ -1,16 +1,15 @@ """ Implementation of metrics for (mulit-horizon) timeseries forecasting. """ +import abc from typing import Dict, List, Union +from pytorch_lightning.metrics import TensorMetric +import scipy.stats import torch from torch import nn import torch.nn.functional as F from torch.nn.utils import rnn -import abc -from pytorch_lightning.metrics import TensorMetric - -import scipy.stats class Metric(TensorMetric): diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index fa0fb599..5d3db881 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -2,8 +2,8 @@ Models for timeseries forecasting. """ from pytorch_forecasting.models.base_model import BaseModel -from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer -from pytorch_forecasting.models.nbeats import NBeats from pytorch_forecasting.models.baseline import Baseline +from pytorch_forecasting.models.nbeats import NBeats +from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer __all__ = ["NBeats", "TemporalFusionTransformer", "BaseModel", "Baseline"] diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 7f568cc4..d429daa7 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -3,30 +3,27 @@ """ from copy import deepcopy import inspect -from pytorch_forecasting.data.encoders import GroupNormalizer -from torch import unsqueeze -from torch import optim -import cloudpickle - -from torch.utils.data import DataLoader -from tqdm.notebook import tqdm - -from pytorch_forecasting.metrics import SMAPE from typing import Any, Callable, Dict, Iterable, List, Tuple, Union + +import cloudpickle +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd from pytorch_lightning import LightningModule from pytorch_lightning.metrics.metric import TensorMetric -from pytorch_forecasting.optim import Ranger +from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args import torch -import numpy as np -import pandas as pd -from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, OneCycleLR +from torch import optim, unsqueeze +from torch.optim.lr_scheduler import LambdaLR, OneCycleLR, ReduceLROnPlateau +from torch.utils.data import DataLoader +from tqdm.notebook import tqdm -import matplotlib.pyplot as plt from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.data.encoders import GroupNormalizer +from pytorch_forecasting.metrics import SMAPE +from pytorch_forecasting.optim import Ranger from pytorch_forecasting.utils import groupby_apply -from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args - class BaseModel(LightningModule): """ diff --git a/pytorch_forecasting/models/baseline.py b/pytorch_forecasting/models/baseline.py index 45d8be87..5565a8dd 100644 --- a/pytorch_forecasting/models/baseline.py +++ b/pytorch_forecasting/models/baseline.py @@ -2,11 +2,12 @@ Baseline model. """ from typing import Dict + import torch from torch.nn.utils import rnn -from pytorch_forecasting.models import BaseModel from pytorch_forecasting.metrics import MultiHorizonMetric, QuantileLoss +from pytorch_forecasting.models import BaseModel class Baseline(BaseModel): diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index c6667e77..51a2ef37 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -2,15 +2,15 @@ N-Beats model for timeseries forecasting without covariates. """ from typing import Dict, List -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import torch from torch import nn from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.metrics import SMAPE from pytorch_forecasting.models.base_model import BaseModel -from pytorch_forecasting.models.nbeats.sub_modules import NBEATSTrendBlock, NBEATSGenericBlock, NBEATSSeasonalBlock +from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock class NBeats(BaseModel): diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index a1ed9af6..894f96c0 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -2,8 +2,9 @@ Implementation of ``nn.Modules`` for N-Beats model. """ from typing import Tuple -import torch + import numpy as np +import torch import torch.nn as nn import torch.nn.functional as F diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index 7c222371..2bdd7980 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -1,27 +1,27 @@ """ The temporal fusion transformer is a powerful predictive model for forecasting timeseries """ -from typing import Callable, Union, List, Dict, Tuple +from typing import Callable, Dict, List, Tuple, Union +from matplotlib import pyplot as plt import numpy as np import torch -from matplotlib import pyplot as plt from torch import nn from torch.nn.utils import rnn -from pytorch_forecasting.models.base_model import BaseModel, CovariatesMixin from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.metrics import MultiHorizonMetric, QuantileLoss, SMAPE, MAE, RMSE, MAPE +from pytorch_forecasting.metrics import MAE, MAPE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss +from pytorch_forecasting.models.base_model import BaseModel, CovariatesMixin from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( - VariableSelectionNetwork, - GatedResidualNetwork, + AddNorm, GateAddNorm, - InterpretableMultiHeadAttention, GatedLinearUnit, - AddNorm, + GatedResidualNetwork, + InterpretableMultiHeadAttention, TimeDistributedEmbeddingBag, + VariableSelectionNetwork, ) -from pytorch_forecasting.utils import autocorrelation, integer_histogram, get_embedding_size +from pytorch_forecasting.utils import autocorrelation, get_embedding_size, integer_histogram class TemporalFusionTransformer(BaseModel, CovariatesMixin): diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py b/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py index b3d39407..446b5ec1 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py @@ -1,12 +1,11 @@ """ Implementation of ``nn.Modules`` for temporal fusion transformer. """ -from typing import Union, List, Dict, Tuple - import math +from typing import Dict, List, Tuple, Union -import torch.nn as nn import torch +import torch.nn as nn import torch.nn.functional as F diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index 73f932dc..c44d787c 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -2,21 +2,21 @@ Hyperparameters can be efficiently tuned with `optuna `_. """ import os -from typing import Dict, Tuple, Any +from typing import Any, Dict, Tuple -import optuna -import torch import numpy as np -import statsmodels.api as sm +import optuna +from optuna.integration import PyTorchLightningPruningCallback, TensorBoardCallback +import pytorch_lightning as pl from pytorch_lightning import Callback from pytorch_lightning.callbacks import LearningRateLogger from pytorch_lightning.loggers import TensorBoardLogger +import statsmodels.api as sm +import torch from torch.utils.data import DataLoader from pytorch_forecasting import TemporalFusionTransformer from pytorch_forecasting.data import TimeSeriesDataSet -import pytorch_lightning as pl -from optuna.integration import PyTorchLightningPruningCallback, TensorBoardCallback class MetricsCallback(Callback): diff --git a/pytorch_forecasting/optim.py b/pytorch_forecasting/optim.py index 121d2676..ee82b2d2 100644 --- a/pytorch_forecasting/optim.py +++ b/pytorch_forecasting/optim.py @@ -2,11 +2,11 @@ Optimizers not provided by PyTorch natively. """ import math +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union + import torch from torch.optim.optimizer import Optimizer -from typing import Iterable, Union, Callable, Dict, Optional, Tuple, Any - Params = Union[Iterable[torch.Tensor], Iterable[dict]] LossClosure = Callable[[], float] diff --git a/pytorch_forecasting/utils.py b/pytorch_forecasting/utils.py index 6b68c203..c3ed26e6 100644 --- a/pytorch_forecasting/utils.py +++ b/pytorch_forecasting/utils.py @@ -1,11 +1,12 @@ """ Helper functions for PyTorch forecasting """ -import os +from contextlib import redirect_stdout import io -from typing import Callable, Union, Tuple +import os +from typing import Callable, Tuple, Union + import torch -from contextlib import redirect_stdout def integer_histogram( diff --git a/setup.cfg b/setup.cfg index 961f572d..e2afdcdc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,20 +23,12 @@ exclude = docs/build/*.py, .history/* [isort] -combine_as_imports = True -include_trailing_comma = True -multi_line_output = 3 -not_skip = __init__.py - -known_standard_library = dataclasses,typing_extensions -known_third_party = click,log -known_first_party = temporal_fusion_transformer_pytorch -force_grid_wrap = false - -lines_after_imports = 2 +profile = black +honor_noqa = true line_length = 120 -ensure_newline_before_comments = true - +combine_as_imports = true +force_sort_within_sections = true +known_first_party = pytorch_forecasting [tool:pytest] addopts = @@ -55,7 +47,7 @@ markers = [coverage:report] ignore_errors = False -show_missing = True +show_missing = true [mypy] diff --git a/tests/conftest.py b/tests/conftest.py index 9c394235..147fcd74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,14 @@ -import pytest -import numpy as np -import sys import os +import sys + +import numpy as np +import pytest + +sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) # isort:skip -sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) -sys.path.insert(0, "examples") -from pytorch_forecasting.data.examples import get_stallion_data -from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting import TimeSeriesDataSet # isort:skip +from pytorch_forecasting.data.examples import get_stallion_data # isort:skip @pytest.fixture diff --git a/tests/test_data.py b/tests/test_data.py index 973665b8..a16aa5fb 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,15 +1,15 @@ -import pytest -from typing import Dict from copy import deepcopy import itertools -import torch +from typing import Dict + import numpy as np import pandas as pd - +import pytest from sklearn.preprocessing import StandardScaler +import torch +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder, TimeSeriesDataSet from pytorch_forecasting.data.examples import get_stallion_data -from pytorch_forecasting.data import NaNLabelEncoder, GroupNormalizer, TimeSeriesDataSet, EncoderNormalizer torch.manual_seed(23) diff --git a/tests/test_models/conftest.py b/tests/test_models/conftest.py index 9e80e258..1992cee4 100644 --- a/tests/test_models/conftest.py +++ b/tests/test_models/conftest.py @@ -1,9 +1,10 @@ -import pytest import numpy as np +import pytest import torch -from pytorch_forecasting.data.examples import get_stallion_data, generate_ar_data + from pytorch_forecasting import TimeSeriesDataSet -from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder, EncoderNormalizer +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data @pytest.fixture @@ -100,7 +101,7 @@ def multiple_dataloaders_with_coveratiates(data_with_covariates, request): validation = TimeSeriesDataSet.from_dataset( training, data_with_covariates, min_prediction_idx=training.index.time.max() + 1 ) - batch_size = 32 + batch_size = 4 train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0) val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0) @@ -130,7 +131,7 @@ def dataloaders_with_coveratiates(data_with_covariates): validation = TimeSeriesDataSet.from_dataset( training, data_with_covariates, min_prediction_idx=training.index.time.max() + 1 ) - batch_size = 32 + batch_size = 4 train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0) val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0) @@ -163,7 +164,7 @@ def dataloaders_fixed_window_without_coveratiates(): data[lambda x: x.series.isin(validation)], stop_randomization=True, ) - batch_size = 64 + batch_size = 4 train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0) val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0) diff --git a/tests/test_models/test_nbeats.py b/tests/test_models/test_nbeats.py index 0ef3399e..d55499e9 100644 --- a/tests/test_models/test_nbeats.py +++ b/tests/test_models/test_nbeats.py @@ -1,7 +1,9 @@ import shutil + import pytorch_lightning as pl -from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + from pytorch_forecasting.metrics import QuantileLoss from pytorch_forecasting.models import NBeats diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 9443c6bf..a16e8ecf 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -1,14 +1,15 @@ -import pytest - -import torch import shutil import sys + +import pytest import pytorch_lightning as pl -from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_forecasting.metrics import QuantileLoss, PoissonLoss -from pytorch_forecasting.models import TemporalFusionTransformer +from pytorch_lightning.loggers import TensorBoardLogger +import torch +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.metrics import PoissonLoss, QuantileLoss +from pytorch_forecasting.models import TemporalFusionTransformer if sys.version.startswith("3.6"): # python 3.6 does not have nullcontext from contextlib import contextmanager @@ -106,9 +107,12 @@ def model(dataloaders_with_coveratiates): @pytest.mark.parametrize("kwargs", [dict(mode="dataframe"), dict(mode="series"), dict(mode="raw")]) def test_predict_dependency(model, dataloaders_with_coveratiates, data_with_covariates, kwargs): - dataset = dataloaders_with_coveratiates["val"].dataset + train_dataset = dataloaders_with_coveratiates["train"].dataset + dataset = TimeSeriesDataSet.from_dataset( + train_dataset, data_with_covariates[lambda x: x.agency == data_with_covariates.agency.iloc[0]], predict=True + ) model.predict_dependency(dataset, variable="discount", values=[0.1, 0.0], **kwargs) - model.predict_dependency(dataset, variable="agency", values=data_with_covariates.agency.unique(), **kwargs) + model.predict_dependency(dataset, variable="agency", values=data_with_covariates.agency.unique()[:2], **kwargs) def test_actual_vs_predicted_plot(model, dataloaders_with_coveratiates): diff --git a/tests/test_utils/test_autocorrelation.py b/tests/test_utils/test_autocorrelation.py index 0a0745c7..f27adef8 100644 --- a/tests/test_utils/test_autocorrelation.py +++ b/tests/test_utils/test_autocorrelation.py @@ -1,6 +1,7 @@ import math import torch + from pytorch_forecasting.utils import autocorrelation