Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix failing plotting tests #6

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,13 +822,13 @@ def test_new_data_predict_method(


def test_get_valid_distribution(mmm):
normal_dist = mmm._get_distribution({"dist": "Normal"})
normal_dist = mmm._get_distribution_from_dict({"dist": "Normal"})
assert normal_dist is pm.Normal


def test_get_invalid_distribution(mmm):
with pytest.raises(ValueError, match="does not exist in PyMC"):
mmm._get_distribution({"dist": "NonExistentDist"})
mmm._get_distribution_from_dict({"dist": "NonExistentDist"})


def test_invalid_likelihood_type(mmm):
Expand Down
115 changes: 94 additions & 21 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
# limitations under the License.
import numpy as np
import pandas as pd
import pymc as pm
import pytest
from matplotlib import pyplot as plt

from pymc_marketing.mmm.delayed_saturated_mmm import BaseDelayedSaturatedMMM
from pymc_marketing.mmm.delayed_saturated_mmm import (
BaseDelayedSaturatedMMM,
DelayedSaturatedMMM,
)
from pymc_marketing.mmm.preprocessing import MaxAbsScaleTarget

seed: int = sum(map(ord, "pymc_marketing"))
Expand Down Expand Up @@ -49,6 +53,25 @@ def toy_y(toy_X) -> pd.Series:
return pd.Series(rng.integers(low=0, high=100, size=toy_X.shape[0]))


def mock_fit_base(model, X: pd.DataFrame, y: np.ndarray, **kwargs):
model.build_model(X=X, y=y)
with model.model:
idata = pm.sample_prior_predictive(random_seed=rng, **kwargs)

idata.add_groups(
{
"posterior": idata.prior,
"fit_data": pd.concat(
[X, pd.Series(y, index=X.index, name="y")], axis=1
).to_xarray(),
}
)
model.idata = idata
model.set_idata_attrs(idata=idata)

return model


class TestBasePlotting:
@pytest.fixture(
scope="module",
Expand Down Expand Up @@ -85,10 +108,7 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
channel_columns=["channel_1", "channel_2"],
)
# fit the model
mmm.fit(
X=toy_X,
y=toy_y,
)
mmm = mock_fit_base(mmm, toy_X, toy_y)
mmm.sample_prior_predictive(toy_X, toy_y, extend_idata=True, combined=True)
mmm.sample_posterior_predictive(toy_X, extend_idata=True, combined=True)
mmm._prior_predictive = mmm.prior_predictive
Expand All @@ -107,37 +127,90 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
("plot_errors", {}),
("plot_errors", {"original_scale": True}),
("plot_errors", {"ax": plt.subplots()[1]}),
("plot_components_contributions", {}),
("plot_channel_parameter", {"param_name": "adstock_alpha"}),
("plot_waterfall_components_decomposition", {"original_scale": True}),
("plot_direct_contribution_curves", {}),
("plot_direct_contribution_curves", {"same_axes": True}),
("plot_direct_contribution_curves", {"channels": ["channel_2"]}),
("plot_channel_contribution_share_hdi", {"hdi_prob": 0.95}),
("plot_grouped_contribution_breakdown_over_time", {}),
(
"plot_grouped_contribution_breakdown_over_time",
{
"stack_groups": {"controls": ["control_1"]},
"original_scale": True,
"area_kwargs": {"adstock_alpha": 0.5},
"area_kwargs": {"alpha": 0.5},
},
),
("plot_components_contributions", {}),
],
)
def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None:
func = plotting_mmm.__getattribute__(func_plot_name)
assert isinstance(func(**kwargs_plot), plt.Figure)
plt.close("all")

@pytest.mark.parametrize(
"channels, match",
[
(["invalid_channel"], "subset"),
(["channel_1", "channel_1"], "unique"),
([], "Number of rows must be a positive"),
],

@pytest.fixture(scope="module")
def mock_mmm():
return DelayedSaturatedMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
)
def test_plot_direct_contribution_curves_error(self, plotting_mmm, channels, match):
with pytest.raises(ValueError, match=match):
plotting_mmm.plot_direct_contribution_curves(channels=channels)


def mock_fit(model: DelayedSaturatedMMM, X: pd.DataFrame, y: np.ndarray, **kwargs):
model.build_model(X=X, y=y)

with model.model:
idata = pm.sample_prior_predictive(random_seed=rng, **kwargs)

model.preprocess("X", X)
model.preprocess("y", y)

idata.add_groups(
{
"posterior": idata.prior,
"fit_data": pd.concat(
[X, pd.Series(y, index=X.index, name="y")], axis=1
).to_xarray(),
}
)
model.idata = idata
model.set_idata_attrs(idata=idata)

return model


@pytest.fixture(scope="module")
def mock_fitted_mmm(mock_mmm, toy_X, toy_y):
return mock_fit(mock_mmm, toy_X, toy_y)


@pytest.mark.parametrize(
argnames="func_plot_name, kwargs_plot",
argvalues=[
# Only part of DelayedSaturatedMMM now
("plot_direct_contribution_curves", {}),
("plot_direct_contribution_curves", {"same_axes": True}),
("plot_direct_contribution_curves", {"channels": ["channel_2"]}),
("plot_channel_parameter", {"param_name": "adstock_alpha"}),
],
)
def test_delayed_saturated_mmm_plots(
mock_fitted_mmm, func_plot_name, kwargs_plot
) -> None:
func = mock_fitted_mmm.__getattribute__(func_plot_name)
assert isinstance(func(**kwargs_plot), plt.Figure)
plt.close("all")


@pytest.mark.parametrize(
"channels, match",
[
(["invalid_channel"], "subset"),
(["channel_1", "channel_1"], "unique"),
([], "Number of rows must be a positive"),
],
)
def test_plot_direct_contribution_curves_error(mock_fitted_mmm, channels, match):
with pytest.raises(ValueError, match=match):
mock_fitted_mmm.plot_direct_contribution_curves(channels=channels)