Skip to content

Commit

Permalink
Merge pull request pycaret#3599 from pycaret/ts_plot_model_fix
Browse files Browse the repository at this point in the history
Fixes TS Plotting Issue
  • Loading branch information
ngupta23 authored Jun 10, 2023
2 parents c19adaf + 91edabb commit ab16135
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pycaret/internal/plots/utils/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ def _resolve_hoverinfo(
hoverinfo: Optional[str],
threshold: int,
data: Optional[pd.Series],
X: Optional[pd.DataFrame],
X: Optional[List[pd.DataFrame]],
) -> str:
"""Decide whether data tip obtained by hovering over a Plotly plot should be
enabled or disabled based user settings and size of data. If user provides the
Expand Down Expand Up @@ -989,7 +989,7 @@ def _resolve_renderer(
renderer: Optional[str],
threshold: int,
data: Optional[pd.Series],
X: Optional[pd.DataFrame],
X: Optional[List[pd.DataFrame]],
) -> str:
"""Decide the renderer to use for the Plotly plot based user settings and
size of data. If user provides the `renderer` option, it is honored, else it
Expand Down
5 changes: 5 additions & 0 deletions pycaret/time_series/forecasting/oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3778,6 +3778,11 @@ def _plot_model(
# Disable Prediction Intervals if more than 1 estimator is provided.
return_pred_int = False

if X is not None:
X = _reformat_dataframes_for_plots(
data=[X], labels_suffix=["original"]
)

elif plot == "insample":
# Try to get insample forecasts if possible
model_results = [
Expand Down
26 changes: 26 additions & 0 deletions tests/test_time_series_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
_return_model_names_for_plots_stats,
)

from pycaret.datasets import get_data
from pycaret.time_series import TSForecastingExperiment

pytestmark = pytest.mark.filterwarnings("ignore::UserWarning")
Expand Down Expand Up @@ -352,3 +353,28 @@ def test_plot_multiple_model_overlays(
assert (
"Please provide a label corresponding to each model to proceed." in exceptionmsg
)


def test_plot_final_model_exo():
"""Tests running plot model after running finalize_model when exogenous
variables are present. Fix for https://github.com/pycaret/pycaret/issues/3565
"""
data = get_data("uschange")
target = "Consumption"
FH = 3
train = data.iloc[: int(len(data) - FH)]
test = data.iloc[int(len(data)) - FH :]
test = test.drop(columns=[target], axis=1)

exp = TSForecastingExperiment()
exp.setup(data=train, target=target, fh=FH, session_id=42)
model = exp.create_model("arima")
final_model = exp.finalize_model(model)

# Previous issue coming from renderer resolution due to X

# This should not give an error (passing X explicitly)
exp.plot_model(final_model, data_kwargs={"X": test})

# Also, plotting without explicit passing X should also pass
exp.plot_model()

0 comments on commit ab16135

Please sign in to comment.