Skip to content

Implement nutpie as an external sampler #7719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pymc/step_methods/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""External samplers integration for PyMC."""

from pymc.step_methods.external.base import ExternalSampler
from pymc.step_methods.external.nutpie import NutPie

__all__ = ["ExternalSampler", "NutPie"]
124 changes: 124 additions & 0 deletions pymc/step_methods/external/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from arviz import InferenceData

from pymc.step_methods.compound import BlockedStep, Competence


class ExternalSampler(BlockedStep, ABC):
    """Abstract base for samplers that run their own MCMC loop.

    Concrete subclasses implement :meth:`sample`, which performs the complete
    sampling run and hands back the result. The per-draw :meth:`step` entry
    point is deliberately unusable on these samplers.

    Attributes
    ----------
    is_external : bool
        Class-level flag marking the step method as an external sampler.
    """

    # NOTE(review): a maintainer asked whether this flag is needed at all,
    # since being an ``ExternalSampler`` subclass already conveys the same
    # information. It is kept here so existing attribute checks keep working.
    is_external = True

    def __init__(
        self,
        vars=None,
        model=None,
        **kwargs,
    ):
        """Store the constructor arguments for later use by ``sample``.

        Parameters
        ----------
        vars : list, optional
            Variables to be sampled
        model : Model, optional
            PyMC model
        **kwargs
            Sampler-specific arguments, kept untouched in ``self._kwargs``
        """
        self._kwargs = kwargs
        self._vars = vars
        self.model = model

    @abstractmethod
    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        **kwargs,
    ) -> InferenceData:
        """Run the external sampler and return its results.

        Subclasses must override this method; the base implementation has no
        body and returns ``None``.

        Parameters
        ----------
        draws : int
            Number of draws per chain
        tune : int
            Number of tuning draws per chain
        chains : int
            Number of chains to sample
        random_seed : int or sequence, optional
            Random seed(s) for reproducibility
        initvals : dict or list of dict, optional
            Initial values for variables
        progressbar : bool
            Whether to display progress bar
        cores : int, optional
            Number of CPU cores to use
        **kwargs
            Additional sampler-specific parameters

        Returns
        -------
        InferenceData
            ArviZ InferenceData object with sampling results
        """

    def step(self, point):
        """Reject per-draw stepping; always raises.

        Raises
        ------
        NotImplementedError
            Unconditionally, because external samplers are driven through
            ``sample`` instead of ``step``.
        """
        message = "External samplers use their own sampling loop rather than PyMC's step() method."
        raise NotImplementedError(message)

    @staticmethod
    def competence(var, has_grad):
        """Report how suitable this sampler is for ``var``.

        The base implementation ignores both arguments.

        Parameters
        ----------
        var : Variable
            Variable to be sampled
        has_grad : bool
            Whether gradient information is available

        Returns
        -------
        Competence
            Always ``Competence.COMPATIBLE``
        """
        return Competence.COMPATIBLE
242 changes: 242 additions & 0 deletions pymc/step_methods/external/nutpie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings

from typing import Literal

from arviz import InferenceData

from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
from pymc.model import Model
from pymc.step_methods.compound import Competence
from pymc.step_methods.external.base import ExternalSampler
from pymc.vartypes import continuous_types

logger = logging.getLogger("pymc")

try:
import nutpie
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nutpie should only be imported if the step method is used/instantiated, so as to avoid import time penalty


# Check if it's actually installed and not just an empty mock module
NUTPIE_AVAILABLE = hasattr(nutpie, "compile_pymc_model")
except ImportError:
NUTPIE_AVAILABLE = False


class NutPie(ExternalSampler):
    """No-U-Turn sampling delegated to the external ``nutpie`` package.

    The PyMC model is compiled with ``nutpie.compile_pymc_model`` and then
    sampled with ``nutpie.sample``; PyMC's own step loop is not involved.

    Parameters
    ----------
    vars : list, optional
        Variables to be sampled
    model : Model, optional
        PyMC model. When omitted, the model is looked up from the context
        stack when ``sample`` is called.
    backend : {"numba", "jax"}, default="numba"
        Forwarded as ``backend`` to ``nutpie.compile_pymc_model``
    target_accept : float, default=0.8
        Forwarded as ``target_accept`` to ``nutpie.sample``
    max_treedepth : int, default=10
        Forwarded as ``maxdepth`` to ``nutpie.sample``
    **kwargs
        Additional keyword arguments forwarded to ``nutpie.sample``

    Notes
    -----
    Requires the nutpie package to be installed:
        pip install nutpie
    """

    # NOTE(review): a maintainer asked for the spelling ``Nutpie`` ("pie" is
    # not capitalized in the nutpie docs). The class name is left unchanged
    # here because the package ``__init__`` imports it as ``NutPie``.
    name = "nutpie"

    def __init__(
        self,
        vars=None,
        *,
        model=None,
        backend: Literal["numba", "jax"] = "numba",
        target_accept: float = 0.8,
        max_treedepth: int = 10,
        **kwargs,
    ):
        """Initialize NutPie sampler.

        Raises
        ------
        ImportError
            If the module-level availability check for nutpie failed.
        """
        if not NUTPIE_AVAILABLE:
            raise ImportError("nutpie not found. Install it with: pip install nutpie")

        super().__init__(vars=vars, model=model)

        self.backend = backend
        self.target_accept = target_accept
        self.max_treedepth = max_treedepth
        self.nutpie_kwargs = kwargs

    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        idata_kwargs=None,
        compute_convergence_checks=True,
        **kwargs,
    ) -> InferenceData:
        """Run NutPie sampler and return results as InferenceData.

        Parameters
        ----------
        draws : int
            Number of draws per chain
        tune : int
            Number of tuning draws per chain
        chains : int
            Number of chains to sample
        random_seed : int or sequence, optional
            Forwarded to ``nutpie.sample`` as ``seed`` when not ``None``
        initvals : dict or list of dict, optional
            Not forwarded to nutpie; a ``UserWarning`` is emitted when given
        progressbar : bool
            Forwarded to ``nutpie.sample`` as ``progress_bar`` when not ``None``
        cores : int, optional
            Accepted for signature compatibility; not used
        idata_kwargs : dict, optional
            Accepted for signature compatibility; not used and never modified
        compute_convergence_checks : bool
            Accepted for signature compatibility; no checks are run here
        **kwargs
            Additional keyword arguments forwarded to ``nutpie.sample``.
            ``model`` and ``vars`` are consumed here and not forwarded.

        Returns
        -------
        InferenceData
            The object returned by ``nutpie.sample``, with
            ``posterior.attrs["tuning_steps"]`` set when possible
        """
        model = kwargs.pop("model", self.model)
        if model is None:
            model = Model.get_context()

        # ``vars`` must not leak into nutpie.sample(); the value itself is not
        # needed because the whole model is compiled.
        kwargs.pop("vars", None)

        # NOTE(review): a maintainer suggested it may be better to just call
        # nutpie.sample here.
        logger.info("Compiling NutPie model")
        nutpie_model = nutpie.compile_pymc_model(
            model,
            backend=self.backend,
        )

        if initvals is not None:
            warnings.warn(
                "`initvals` are currently not passed to nutpie sampler. "
                "Use `init_mean` kwarg following nutpie specification instead.",
                UserWarning,
                stacklevel=2,
            )

        # Precedence, lowest to highest: constructor defaults, constructor
        # kwargs, call-time kwargs, then the explicit arguments of this method.
        sampler_kwargs = {
            "target_accept": self.target_accept,
            "maxdepth": self.max_treedepth,
            **self.nutpie_kwargs,
            **kwargs,
        }
        if random_seed is not None:
            sampler_kwargs["seed"] = random_seed

        # nutpie spells the option ``progress_bar``; translate a PyMC-style
        # ``progressbar`` that came in through the constructor kwargs.
        if "progressbar" in sampler_kwargs:
            sampler_kwargs["progress_bar"] = sampler_kwargs.pop("progressbar")
        if progressbar is not None:
            sampler_kwargs["progress_bar"] = progressbar

        logger.info(
            "Running NutPie sampler with %s chains, %s tuning steps, and %s draws",
            chains,
            tune,
            draws,
        )
        idata = nutpie.sample(
            nutpie_model,
            draws=draws,
            tune=tune,
            chains=chains,
            **sampler_kwargs,
        )

        # The result is returned as delivered by nutpie; no coords/dims or
        # constant/observed data are gathered here (a reviewer noted that the
        # conversion is handled by nutpie).
        try:
            idata.posterior.attrs["tuning_steps"] = tune
        except (AttributeError, KeyError):
            logger.warning("Could not set tuning_steps attribute on InferenceData")

        return idata

    @staticmethod
    def competence(var, has_grad):
        """Determine competence level for sampling var.

        Parameters
        ----------
        var : Variable
            Variable to be sampled
        has_grad : bool
            Whether gradient information is available

        Returns
        -------
        Competence
            ``COMPATIBLE`` for continuous variables with gradients,
            ``INCOMPATIBLE`` otherwise
        """
        if var.dtype in continuous_types and has_grad:
            # Review feedback: this sampler must not report IDEAL.
            return Competence.COMPATIBLE
        return Competence.INCOMPATIBLE
Loading
Loading