diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index bdb63285c9d..b3c23b7e008 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -61,6 +61,10 @@ def __getattr__(name): warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning) return sum_to_1 + if name == "sum_to_1": + warnings.warn("sum_to_1 has been deprecated, use simplex instead.", FutureWarning) + return simplex + if name == "RVTransform": warnings.warn("RVTransform has been renamed to Transform", FutureWarning) return Transform @@ -334,6 +338,7 @@ def log_jac_det(self, value, *rv_inputs): Instantiation of :class:`pymc.logprob.transforms.LogTransform` for use in the ``transform`` argument of a random variable.""" +# Deprecated sum_to_1 = SumTo1() sum_to_1.__doc__ = """ Instantiation of :class:`pymc.distributions.transforms.SumTo1` diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 8196c2623cf..3bf0c8baf1f 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -675,3 +675,6 @@ def test_deprecated_ndim_supp_transforms(): with pytest.warns(FutureWarning, match="deprecated"): assert tr.multivariate_sum_to_1 == tr.sum_to_1 + + with pytest.warns(FutureWarning, match="deprecated"): + assert tr.sum_to_1 == tr.simplex