diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 6c01192825..2ab2f20280 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -66,6 +66,8 @@
 __all__ = (
     "get_jaxified_graph",
     "get_jaxified_logp",
+    "sample_blackjax_adjusted_mclmc",
+    "sample_blackjax_mclmc",
     "sample_blackjax_nuts",
     "sample_numpyro_nuts",
 )
@@ -259,8 +261,14 @@ def _blackjax_inference_loop(
         algorithm = blackjax.nuts
     elif algorithm_name == "hmc":
         algorithm = blackjax.hmc
+    elif algorithm_name == "mclmc":
+        algorithm = blackjax.mclmc  # NOTE(review): window_adaptation below tunes HMC-family kernels (step size + mass matrix); MCLMC is normally tuned via blackjax.mclmc_find_L_and_step_size -- confirm this path works
+    elif algorithm_name == "adjusted_mclmc":
+        algorithm = blackjax.adjusted_mclmc_dynamic  # NOTE(review): same window_adaptation concern as for "mclmc" above
     else:
-        raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.")
+        raise ValueError(
+            "Only supporting 'nuts', 'hmc', 'mclmc', or 'adjusted_mclmc' as algorithm to draw samples."
+        )
 
     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
@@ -726,3 +734,44 @@
 
 sample_numpyro_nuts = partial(sample_jax_nuts, nuts_sampler="numpyro")
 sample_blackjax_nuts = partial(sample_jax_nuts, nuts_sampler="blackjax")
+
+
+# Custom partial functions for the MCLMC samplers
+def sample_blackjax_mclmc(*args, **kwargs):
+    """
+    Draw samples from the posterior using the MCLMC (Microcanonical Langevin Monte Carlo) method.
+
+    From the ``blackjax`` library.
+
+    Parameters are the same as for sample_jax_nuts.
+
+    MCLMC is based on https://arxiv.org/abs/2212.08549. It numerically integrates a specialized SDE
+    to produce samples from the target distribution.
+
+    This implementation uses the unadjusted MCLMC algorithm, which may have some asymptotic bias
+    but is typically faster than adjusted MCLMC or NUTS for many high-dimensional problems.
+    """
+    kwargs.setdefault("nuts_kwargs", {}).setdefault("algorithm", "mclmc")  # sample_jax_nuts's kwarg is "nuts_kwargs"; "nuts_sampler_kwargs" would raise TypeError
+    # MCLMC has different default target_accept
+    if "target_accept" not in kwargs:
+        kwargs["target_accept"] = 0.9
+    return sample_jax_nuts(*args, nuts_sampler="blackjax", **kwargs)
+
+
+def sample_blackjax_adjusted_mclmc(*args, **kwargs):
+    """
+    Draw samples from the posterior using the adjusted MCLMC (Microcanonical Langevin Monte Carlo) method.
+
+    From the ``blackjax`` library.
+
+    Parameters are the same as for sample_jax_nuts.
+
+    Adjusted MCLMC adds a Metropolis-Hastings correction step to the unadjusted MCLMC algorithm,
+    ensuring asymptotic unbiasedness at the cost of some computational efficiency.
+    It is recommended when exact asymptotic convergence to the posterior is required.
+    """
+    kwargs.setdefault("nuts_kwargs", {}).setdefault("algorithm", "adjusted_mclmc")  # see kwarg-name note in sample_blackjax_mclmc
+    # Adjusted MCLMC also has different default target_accept
+    if "target_accept" not in kwargs:
+        kwargs["target_accept"] = 0.9
+    return sample_jax_nuts(*args, nuts_sampler="blackjax", **kwargs)
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index f2dfa6e9c2..d8a96acb7b 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -292,6 +292,105 @@ def all_continuous(vars):
     return True
 
 
+def _sample_mclmc(
+    sampler: Literal["blackjax"],
+    draws: int,
+    tune: int,
+    chains: int,
+    target_accept: float,
+    random_seed: RandomState | None,
+    initvals: StartDict | Sequence[StartDict | None] | None,
+    model: Model,
+    var_names: Sequence[str] | None,
+    progressbar: bool,
+    idata_kwargs: dict | None,
+    compute_convergence_checks: bool,
+    mclmc_sampler_kwargs: dict | None,
+    adjusted: bool = False,
+    **kwargs,  # NOTE(review): currently unused -- leftover kwargs from pymc.sample are silently dropped here
+):
+    """
+    Sample using the MCLMC sampler from BlackJax.
+
+    Parameters
+    ----------
+    sampler : {"blackjax"}
+        Name of the external sampler to use.
+    draws : int
+        Number of samples to draw.
+    tune : int
+        Number of tuning steps.
+    chains : int
+        Number of chains to run in parallel.
+    target_accept : float
+        Target acceptance rate.
+    random_seed : RandomState, optional
+        Random seed.
+    initvals : dict or list of dict, optional
+        Initial values for variables.
+    model : Model
+        PyMC model.
+    var_names : list of str, optional
+        List of variable names to sample.
+    progressbar : bool
+        Whether to display a progress bar.
+    idata_kwargs : dict, optional
+        Keyword arguments for InferenceData conversion.
+    compute_convergence_checks : bool
+        Whether to compute convergence checks.
+    mclmc_sampler_kwargs : dict, optional
+        Keyword arguments for the MCLMC sampler.
+    adjusted : bool, default=False
+        Whether to use the adjusted MCLMC algorithm.
+
+    Returns
+    -------
+    InferenceData
+        ArviZ InferenceData object with sampling results.
+    """
+    if mclmc_sampler_kwargs is None:
+        mclmc_sampler_kwargs = {}
+
+    if sampler == "blackjax":
+        import pymc.sampling.jax as pymc_jax
+
+        if adjusted:
+            idata = pymc_jax.sample_blackjax_adjusted_mclmc(
+                draws=draws,
+                tune=tune,
+                chains=chains,
+                target_accept=target_accept,
+                random_seed=random_seed,
+                initvals=initvals,
+                model=model,
+                var_names=var_names,
+                progressbar=progressbar,
+                idata_kwargs=idata_kwargs,
+                compute_convergence_checks=compute_convergence_checks,
+                **mclmc_sampler_kwargs,
+            )
+        else:
+            idata = pymc_jax.sample_blackjax_mclmc(
+                draws=draws,
+                tune=tune,
+                chains=chains,
+                target_accept=target_accept,
+                random_seed=random_seed,
+                initvals=initvals,
+                model=model,
+                var_names=var_names,
+                progressbar=progressbar,
+                idata_kwargs=idata_kwargs,
+                compute_convergence_checks=compute_convergence_checks,
+                **mclmc_sampler_kwargs,
+            )
+        return idata
+    else:
+        raise ValueError(
+            f"Sampler {sampler} not found. Currently only 'blackjax' is supported for MCLMC."
+        )
+
+
 def _sample_external_nuts(
     sampler: Literal["nutpie", "numpyro", "blackjax"],
     draws: int,
@@ -429,6 +528,7 @@
     step=None,
     var_names: Sequence[str] | None = None,
     nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
+    mclmc_sampler: Literal["blackjax"] | None = None,
     initvals: StartDict | Sequence[StartDict | None] | None = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
@@ -440,6 +540,7 @@
     return_inferencedata: Literal[True] = True,
     idata_kwargs: dict[str, Any] | None = None,
     nuts_sampler_kwargs: dict[str, Any] | None = None,
+    mclmc_sampler_kwargs: dict[str, Any] | None = None,
     callback=None,
     mp_ctx=None,
     blas_cores: int | None | Literal["auto"] = "auto",
@@ -461,6 +562,7 @@
     step=None,
     var_names: Sequence[str] | None = None,
     nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
+    mclmc_sampler: Literal["blackjax"] | None = None,
     initvals: StartDict | Sequence[StartDict | None] | None = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
@@ -472,6 +574,7 @@
     return_inferencedata: Literal[False],
     idata_kwargs: dict[str, Any] | None = None,
     nuts_sampler_kwargs: dict[str, Any] | None = None,
+    mclmc_sampler_kwargs: dict[str, Any] | None = None,
     callback=None,
     mp_ctx=None,
     model: Model | None = None,
@@ -493,6 +596,7 @@
     step=None,
     var_names: Sequence[str] | None = None,
     nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
+    mclmc_sampler: Literal["blackjax"] | None = None,
     initvals: StartDict | Sequence[StartDict | None] | None = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
@@ -504,6 +608,7 @@
     return_inferencedata: bool = True,
     idata_kwargs: dict[str, Any] | None = None,
     nuts_sampler_kwargs: dict[str, Any] | None = None,
+    mclmc_sampler_kwargs: dict[str, Any] | None = None,
     callback=None,
     mp_ctx=None,
     blas_cores: int | None | Literal["auto"] = "auto",
@@ -823,6 +928,33 @@ def joined_blas_limiter():
             **kwargs,
         )
 
+    if mclmc_sampler is not None:
+        if not exclusive_nuts:
+            raise ValueError(
+                "Model can not be sampled with MCLMC alone. It either has discrete variables or a non-differentiable log-probability."
+            )
+
+        adjusted = mclmc_sampler_kwargs.pop("adjusted", False) if mclmc_sampler_kwargs else False
+
+        with joined_blas_limiter():
+            return _sample_mclmc(
+                sampler=mclmc_sampler,
+                draws=draws,
+                tune=tune,
+                chains=chains,
+                target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.9),  # NOTE(review): any other "nuts" step kwargs are discarded here
+                random_seed=random_seed,
+                initvals=initvals,
+                model=model,
+                var_names=var_names,
+                progressbar=progress_bool,
+                idata_kwargs=idata_kwargs,
+                compute_convergence_checks=compute_convergence_checks,
+                mclmc_sampler_kwargs=mclmc_sampler_kwargs,
+                adjusted=adjusted,
+                **kwargs,
+            )
+
     if exclusive_nuts and not provided_steps:
         # Special path for NUTS initialization
         if "nuts" in kwargs:
diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py
index 2d32277061..b5a44c1d6e 100644
--- a/tests/sampling/test_mcmc_external.py
+++ b/tests/sampling/test_mcmc_external.py
@@ -86,3 +86,34 @@ def test_step_args():
     )
 
     npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
+
+
+def test_mclmc_sampling():
+    """Test that blackjax MCLMC sampling works."""
+    pytest.importorskip("blackjax")
+    with Model():
+        a = Normal("a", 0, 1)
+        trace = sample(
+            draws=50,
+            tune=50,
+            random_seed=345,
+            mclmc_sampler="blackjax",
+            progressbar=False,
+        )
+    assert "a" in trace.posterior.data_vars
+
+
+def test_adjusted_mclmc_sampling():
+    """Test that blackjax adjusted MCLMC sampling works."""
+    pytest.importorskip("blackjax")
+    with Model():
+        a = Normal("a", 0, 1)
+        trace = sample(
+            draws=50,
+            tune=50,
+            random_seed=345,
+            mclmc_sampler="blackjax",
+            mclmc_sampler_kwargs={"adjusted": True},
+            progressbar=False,
+        )
+    assert "a" in trace.posterior.data_vars