Implement nutpie as an external sampler #7719
@@ -0,0 +1,20 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""External samplers integration for PyMC."""

from pymc.step_methods.external.base import ExternalSampler
from pymc.step_methods.external.nutpie import NutPie

__all__ = ["ExternalSampler", "NutPie"]
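A quick sketch of what the new package exports (illustrative snippet, not part of the diff; it assumes this PR's pymc.step_methods.external package is importable):

from pymc.step_methods.external import ExternalSampler, NutPie

# NutPie is a concrete ExternalSampler. Importing the class does not require
# the nutpie package itself; only instantiating NutPie does (see nutpie.py below).
assert issubclass(NutPie, ExternalSampler)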
@@ -0,0 +1,124 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from arviz import InferenceData

from pymc.step_methods.compound import BlockedStep, Competence


class ExternalSampler(BlockedStep, ABC):
    """Base class for external samplers.

    External samplers manage their own MCMC loop rather than using PyMC's.
    These samplers (like NutPie, BlackJax, etc.) are designed to run
    their own efficient loop inside their implementation.

    Attributes
    ----------
    is_external : bool
        Flag indicating that this is an external sampler that needs
        special handling in PyMC's sampling loops.
    """

    is_external = True

    def __init__(
        self,
        vars=None,
        model=None,
        **kwargs,
    ):
        """Initialize external sampler.

        Parameters
        ----------
        vars : list, optional
            Variables to be sampled
        model : Model, optional
            PyMC model
        **kwargs
            Sampler-specific arguments
        """
        self.model = model
        self._vars = vars
        self._kwargs = kwargs

    @abstractmethod
    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        **kwargs,
    ) -> InferenceData:
        """Run external sampler and return results as InferenceData.

        Parameters
        ----------
        draws : int
            Number of draws per chain
        tune : int
            Number of tuning draws per chain
        chains : int
            Number of chains to sample
        random_seed : int or sequence, optional
            Random seed(s) for reproducibility
        initvals : dict or list of dict, optional
            Initial values for variables
        progressbar : bool
            Whether to display progress bar
        cores : int, optional
            Number of CPU cores to use
        **kwargs
            Additional sampler-specific parameters

        Returns
        -------
        InferenceData
            ArviZ InferenceData object with sampling results
        """
        pass

    def step(self, point):
        """Do not use this method. External samplers use their own sampling loop.

        External samplers do not use PyMC's step() mechanism.
        """
        raise NotImplementedError(
            "External samplers use their own sampling loop rather than PyMC's step() method."
        )

    @staticmethod
    def competence(var, has_grad):
        """Determine competence level for sampling var.

        Parameters
        ----------
        var : Variable
            Variable to be sampled
        has_grad : bool
            Whether gradient information is available

        Returns
        -------
        Competence
            Enum indicating competence level for this variable
        """
        return Competence.COMPATIBLE
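To make the contract concrete, here is a minimal hypothetical subclass (the name DummyExternal and its fabricated draws are purely illustrative, not part of this PR): an external sampler only implements sample() and returns an InferenceData, while PyMC's step() mechanism stays unused.

import numpy as np

from arviz import InferenceData, from_dict

from pymc.step_methods.external.base import ExternalSampler


class DummyExternal(ExternalSampler):
    """Toy external sampler that fabricates draws instead of running MCMC."""

    name = "dummy_external"

    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        **kwargs,
    ) -> InferenceData:
        # A real external sampler would run its own tuning and sampling loop
        # here; this toy version just draws from a standard normal.
        rng = np.random.default_rng(random_seed)
        return from_dict(posterior={"x": rng.normal(size=(chains, draws))})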
@@ -0,0 +1,242 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings

from typing import Literal

from arviz import InferenceData

from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
from pymc.model import Model
from pymc.step_methods.compound import Competence
from pymc.step_methods.external.base import ExternalSampler
from pymc.vartypes import continuous_types

logger = logging.getLogger("pymc")
try:
    import nutpie

    # Check if it's actually installed and not just an empty mock module
    NUTPIE_AVAILABLE = hasattr(nutpie, "compile_pymc_model")
except ImportError:
    NUTPIE_AVAILABLE = False

Review comment (on the `import nutpie` line): nutpie should only be imported if the step method is used/instantiated, so as to avoid import time penalty.
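One possible shape for the deferral the reviewer suggests (a sketch under that assumption, not part of the diff): resolve the import inside a helper that the constructor calls, so importing pymc.step_methods.external never pays for importing nutpie.

def _import_nutpie():
    """Import nutpie lazily, raising a helpful error if it is missing."""
    try:
        import nutpie
    except ImportError as err:
        raise ImportError("nutpie not found. Install it with: pip install nutpie") from err
    return nutpie

NutPie.__init__ could then call _import_nutpie() and keep the module object on the instance, replacing the module-level NUTPIE_AVAILABLE check.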
class NutPie(ExternalSampler):
    """NutPie No-U-Turn Sampler.

    This class provides an interface to the NutPie sampler, a high-performance
    implementation of the No-U-Turn Sampler (NUTS). Unlike PyMC's native NUTS
    implementation, NutPie runs all chains inside a single process with its own
    compiled sampling loop, which can be more efficient for some models.

    Parameters
    ----------
    vars : list, optional
        Variables to be sampled
    model : Model, optional
        PyMC model
    backend : {"numba", "jax"}, default="numba"
        Which backend to use for computation
    target_accept : float, default=0.8
        Target acceptance rate for step size adaptation
    max_treedepth : int, default=10
        Maximum tree depth for NUTS (passed as 'maxdepth' to NutPie)
    **kwargs
        Additional parameters passed to nutpie.sample()

    Notes
    -----
    Requires the nutpie package to be installed:
        pip install nutpie
    """

    name = "nutpie"

    def __init__(
        self,
        vars=None,
        *,
        model=None,
        backend: Literal["numba", "jax"] = "numba",
        target_accept: float = 0.8,
        max_treedepth: int = 10,
        **kwargs,
    ):
        """Initialize NutPie sampler."""
        if not NUTPIE_AVAILABLE:
            raise ImportError("nutpie not found. Install it with: pip install nutpie")

        super().__init__(vars=vars, model=model)

        self.backend = backend
        self.target_accept = target_accept
        self.max_treedepth = max_treedepth
        self.nutpie_kwargs = kwargs
    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        idata_kwargs=None,
        compute_convergence_checks=True,
        **kwargs,
    ) -> InferenceData:
        """Run NutPie sampler and return results as InferenceData.

        Parameters
        ----------
        draws : int
            Number of draws per chain
        tune : int
            Number of tuning draws per chain
        chains : int
            Number of chains to sample
        random_seed : int or sequence, optional
            Random seed(s) for reproducibility
        initvals : dict or list of dict, optional
            Initial values for variables (currently not used by NutPie)
        progressbar : bool
            Whether to display progress bar
        cores : int, optional
            Number of CPU cores to use (ignored by NutPie)
        idata_kwargs : dict, optional
            Additional arguments for arviz.InferenceData conversion
        compute_convergence_checks : bool
            Whether to compute convergence diagnostics
        **kwargs
            Additional sampler-specific parameters

        Returns
        -------
        InferenceData
            ArviZ InferenceData object with sampling results
        """
        model = kwargs.pop("model", self.model)
        if model is None:
            model = Model.get_context()

        # Handle variables
        vars = kwargs.pop("vars", self._vars)
        if vars is None:
            vars = model.value_vars

        # Create a NutPie model
        logger.info("Compiling NutPie model")
        nutpie_model = nutpie.compile_pymc_model(
            model,
            backend=self.backend,
        )

Review comment (on the nutpie.compile_pymc_model call): I think it's better to just call […]
        # Set up sampling parameters - NutPie does this internally
        # Keep these for other nutpie parameters to pass
        nuts_kwargs = {
            **self.nutpie_kwargs,
            **kwargs,
        }

        if initvals is not None:
            warnings.warn(
                "`initvals` are currently not passed to nutpie sampler. "
                "Use `init_mean` kwarg following nutpie specification instead.",
                UserWarning,
            )

        # Set up random seed
        if random_seed is not None:
            nuts_kwargs["seed"] = random_seed

        # Run the sampler
        logger.info(
            f"Running NutPie sampler with {chains} chains, {tune} tuning steps, and {draws} draws"
        )

        # Add target acceptance and max tree depth
        nutpie_kwargs = {
            "target_accept": self.target_accept,
            "maxdepth": self.max_treedepth,
            **nuts_kwargs,
        }

        # Update parameter names to match NutPie's API
        if "progressbar" in nutpie_kwargs:
            nutpie_kwargs["progress_bar"] = nutpie_kwargs.pop("progressbar")

        # Pass progressbar from the sample function arguments
        if progressbar is not None:
            nutpie_kwargs["progress_bar"] = progressbar

        # Call NutPie's sample function
        nutpie_trace = nutpie.sample(
            nutpie_model,
            draws=draws,
            tune=tune,
            chains=chains,
            **nutpie_kwargs,
        )

        # Convert to InferenceData
        if idata_kwargs is None:
            idata_kwargs = {}

        # Extract relevant variables and data for InferenceData
        coords, dims = coords_and_dims_for_inferencedata(model)
        constants_data = find_constants(model)
        observed_data = find_observations(model)

        # Always include sampler stats
        if "include_sampler_stats" not in idata_kwargs:
            idata_kwargs["include_sampler_stats"] = True

        # NutPie already returns an InferenceData object
        idata = nutpie_trace

        # Set tuning steps attribute if possible
        try:
            idata.posterior.attrs["tuning_steps"] = tune
        except (AttributeError, KeyError):
            logger.warning("Could not set tuning_steps attribute on InferenceData")

        # Skip compute_convergence_checks for now
        # NutPie's InferenceData structure is different from PyMC's expectations

        return idata
Review comment on lines +197 to +222: I think all of this is handled by nutpie.
    @staticmethod
    def competence(var, has_grad):
        """Determine competence level for sampling var.

        Parameters
        ----------
        var : Variable
            Variable to be sampled
        has_grad : bool
            Whether gradient information is available

        Returns
        -------
        Competence
            Enum indicating competence level for this variable
        """
        if var.dtype in continuous_types and has_grad:
            return Competence.IDEAL
        return Competence.INCOMPATIBLE

Review comment (on the `return Competence.IDEAL` line): It shouldn't be IDEAL.
Review comment: The base class already tells us this is an ExternalSampler, no need for is_external?
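For reference, an end-to-end usage sketch of the sampler added in this PR (assuming nutpie is installed; the toy model and argument values are illustrative, and integration with pm.sample is not part of this diff):

import pymc as pm

from pymc.step_methods.external import NutPie

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.5])

    # NutPie compiles the model with the chosen backend and runs its own
    # MCMC loop, returning an ArviZ InferenceData directly.
    sampler = NutPie(model=model, backend="numba", target_accept=0.9)
    idata = sampler.sample(draws=1000, tune=1000, chains=4, random_seed=42)

print(idata.posterior["mu"].mean())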