Remove schedule_fn interface from the high-level interface
The schedule_fn interface proved impractical in practice.
rlouf committed Nov 20, 2022
1 parent 9cb829e commit 9422b80
Showing 5 changed files with 66 additions and 111 deletions.
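In short: the step size is no longer produced by a schedule stored inside the kernel, it is passed explicitly on every call to `step`, and `init` takes an initial minibatch for the gradient-estimator state. A minimal before/after sketch of the sgld interface (`grad_estimator`, `position`, `minibatch`, and `rng_key` are placeholders taken from the docstring examples below):

```python
# Before this commit: the schedule (or a constant learning rate) was bound at construction time.
# sgld = blackjax.sgld(grad_estimator_fn, lambda step: 1e-3)
# new_state = sgld.step(rng_key, state, minibatch)

# After this commit: the caller supplies the step size with each call.
sgld = blackjax.sgld(grad_estimator)
state = sgld.init(position, minibatch)
new_state = sgld.step(rng_key, state, minibatch, 1e-3)
```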
110 changes: 41 additions & 69 deletions blackjax/kernels.py
@@ -462,10 +462,10 @@ def step_fn(rng_key: PRNGKey, state, delta: float):
class sgld:
"""Implements the (basic) user interface for the SGLD kernel.
The general sgld kernel (:meth:`blackjax.mcmc.sgld.kernel`, alias `blackjax.sgld.kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
parameters at initialization time, we provide a helper function that
specializes the general kernel.
The general sgld kernel (:meth:`blackjax.mcmc.sgld.kernel`, alias
`blackjax.sgld.kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.
Example
-------
@@ -476,35 +476,36 @@ class sgld:
.. code::
schedule_fn = lambda _: 1e-3
grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
We can now initialize the sgld kernel and the state:
.. code::
sgld = blackjax.sgld(grad_fn, schedule_fn)
sgld = blackjax.sgld(grad_fn)
state = sgld.init(position, minibatch)
Assuming we have an iterator `batches` that yields batches of data we can perform one step:
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:
.. code::
data_batch = next(batches)
new_state = sgld.step(rng_key, state, data_batch)
step_size = 1e-3
minibatch = next(batches)
new_state = sgld.step(rng_key, state, minibatch, step_size)
Kernels are not jit-compiled by default so you will need to do it manually:
.. code::
step = jax.jit(sgld.step)
new_state, info = step(rng_key, state)
new_state, info = step(rng_key, state, minibatch, step_size)
Parameters
----------
gradient_estimator_fn
A function which, given a position and a batch of data, returns an estimation
of the value of the gradient of the log-posterior distribution at this position.
gradient_estimator
A tuple of functions that initialize and update the gradient estimation
state.
schedule_fn
A function which returns a step size given a step number.
@@ -519,42 +520,27 @@ class sgld:

def __new__( # type: ignore[misc]
cls,
grad_estimator_fn: Callable,
learning_rate: Union[Callable[[int], float], float],
grad_estimator: sgmcmc.gradients.GradientEstimator,
) -> MCMCSamplingAlgorithm:

step = cls.kernel(grad_estimator_fn)

if callable(learning_rate):
learning_rate_fn = learning_rate
elif float(learning_rate):
step = cls.kernel(grad_estimator)

def learning_rate_fn(_):
return learning_rate
def init_fn(position: PyTree, minibatch: PyTree):
return cls.init(position, minibatch, grad_estimator)

else:
raise TypeError(
"The learning rate must either be a float (which corresponds to a constant learning rate) "
f"or a function of the index of the current iteration. Got {type(learning_rate)} instead."
)

def init_fn(position: PyTree, data_batch: PyTree):
return cls.init(position, data_batch, grad_estimator_fn)

def step_fn(rng_key: PRNGKey, state, data_batch: PyTree):
step_size = learning_rate_fn(state.step)
return step(rng_key, state, data_batch, step_size)
def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return step(rng_key, state, minibatch, step_size)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
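Assembled into a loop, the new high-level interface reads roughly as below; the `batches` iterator, the number of iterations, and the constant 1e-3 step size are assumptions for illustration, following the updated docstring and the SGMCMC example further down:

```python
import jax

step_fn = jax.jit(sgld.step)  # jit once; the step size is now an ordinary argument

state = sgld.init(position, next(batches))
for _ in range(1_000):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    state = step_fn(rng_key, state, minibatch, 1e-3)
```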


class sghmc:
"""Implements the (basic) user interface for the SGHMC kernel.
The general sghmc kernel (:meth:`blackjax.mcmc.sghmc.kernel`, alias `blackjax.sghmc.kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
parameters at initialization time, we provide a helper function that
specializes the general kernel.
The general sghmc kernel (:meth:`blackjax.mcmc.sghmc.kernel`, alias
`blackjax.sghmc.kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.
Example
-------
@@ -565,35 +551,36 @@ class sghmc:
.. code::
schedule_fn = lambda _: 1e-3
grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
We can now initialize the sghmc kernel and the state. Like HMC, SGHMC needs the user to specify a number of integration steps.
.. code::
sghmc = blackjax.sghmc(grad_fn, schedule_fn, num_integration_steps)
sghmc = blackjax.sghmc(grad_estimator, num_integration_steps)
state = sghmc.init(position, minibatch)
Assuming we have an iterator `batches` that yields batches of data we can perform one step:
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:
.. code::
data_batch = next(batches)
new_state = sghmc.step(rng_key, state, data_batch)
step_size = 1e-3
minibatch = next(batches)
new_state = sghmc.step(rng_key, state, minibatch, step_size)
Kernels are not jit-compiled by default so you will need to do it manually:
.. code::
step = jax.jit(sghmc.step)
new_state, info = step(rng_key, state)
new_state, info = step(rng_key, state, minibatch, step_size)
Parameters
----------
gradient_estimator_fn
A function which, given a position and a batch of data, returns an estimation
of the value of the gradient of the log-posterior distribution at this position.
gradient_estimator
A tuple of functions that initialize and update the gradient estimation
state.
schedule_fn
A function which returns a step size given a step number.
@@ -608,32 +595,17 @@ class sghmc:

def __new__( # type: ignore[misc]
cls,
grad_estimator_fn: Callable,
learning_rate: Union[Callable[[int], float], float],
grad_estimator: sgmcmc.gradients.GradientEstimator,
num_integration_steps: int = 10,
) -> MCMCSamplingAlgorithm:

step = cls.kernel(grad_estimator_fn)

if callable(learning_rate):
learning_rate_fn = learning_rate
elif float(learning_rate):

def learning_rate_fn(_):
return learning_rate

else:
raise TypeError(
"The learning rate must either be a float (which corresponds to a constant learning rate) "
f"or a function of the index of the current iteration. Got {type(learning_rate)} instead."
)
step = cls.kernel(grad_estimator)

def init_fn(position: PyTree, data_batch: PyTree):
return cls.init(position, data_batch, grad_estimator_fn)
def init_fn(position: PyTree, minibatch: PyTree):
return cls.init(position, minibatch, grad_estimator)

def step_fn(rng_key: PRNGKey, state, data_batch: PyTree):
step_size = learning_rate_fn(state.step)
return step(rng_key, state, data_batch, step_size, num_integration_steps)
def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return step(rng_key, state, minibatch, step_size, num_integration_steps)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
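The corresponding sghmc sketch, where the number of integration steps is still fixed at construction time while the step size travels with each call (the 9e-6 value is borrowed from the SGMCMC example further down and is only illustrative):

```python
sghmc = blackjax.sghmc(grad_estimator, num_integration_steps=10)
state = sghmc.init(position, next(batches))

minibatch = next(batches)
new_state = sghmc.step(rng_key, state, minibatch, 9e-6)
```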

12 changes: 8 additions & 4 deletions blackjax/sgmcmc/sghmc.py
@@ -19,7 +19,11 @@ def kernel(
integrator = sghmc(alpha, beta)

def one_step(
rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float, L: int
rng_key: PRNGKey,
state: SGLDState,
minibatch: PyTree,
step_size: float,
num_integration_steps: int,
) -> SGLDState:
def body_fn(state, rng_key):
position, momentum, grad_estimator_state = state
@@ -34,14 +38,14 @@ def body_fn(state, rng_key):
(position, grad_estimator_state),
)

step, position, grad_estimator_state = state
position, grad_estimator_state = state
momentum = generate_gaussian_noise(rng_key, position, step_size)
init_diffusion_state = (position, momentum, grad_estimator_state)

keys = jax.random.split(rng_key, L)
keys = jax.random.split(rng_key, num_integration_steps)
last_state, _ = jax.lax.scan(body_fn, init_diffusion_state, keys)
position, _, grad_estimator_state = last_state

return SGLDState(step + 1, position, grad_estimator_state)
return SGLDState(position, grad_estimator_state)

return one_step
12 changes: 3 additions & 9 deletions blackjax/sgmcmc/sgld.py
@@ -9,19 +9,13 @@


class SGLDState(NamedTuple):
step: int
position: PyTree
grad_estimator_state: GradientState


# We can compute the gradient at the begining of the kernel step
# This allows to get rid of much of the init function, AND
# Prevents a last useless gradient computation at the last step


def init(position: PyTree, minibatch, gradient_estimator: GradientEstimator):
grad_estimator_state = gradient_estimator.init(minibatch)
return SGLDState(0, position, grad_estimator_state)
return SGLDState(position, grad_estimator_state)


def kernel(gradient_estimator: GradientEstimator) -> Callable:
@@ -32,12 +26,12 @@ def one_step(
rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float
):

step, position, grad_estimator_state = state
position, grad_estimator_state = state
logprob_grad, grad_estimator_state = gradient_estimator.estimate(
grad_estimator_state, position, minibatch
)
new_position = integrator(rng_key, position, logprob_grad, step_size, minibatch)

return SGLDState(step + 1, new_position, grad_estimator_state)
return SGLDState(new_position, grad_estimator_state)

return one_step
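Because `SGLDState` no longer carries a step counter, the tuple now unpacks to `(position, grad_estimator_state)`, and the same state type is what the low-level kernel consumes. A sketch of driving the module-level functions directly (treat the exact import path and the `grad_estimator` construction as assumptions based on the diff above):

```python
from blackjax.sgmcmc import gradients, sgld as sgld_mod

grad_estimator = gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
one_step = sgld_mod.kernel(grad_estimator)

state = sgld_mod.init(position, minibatch, grad_estimator)
new_state = one_step(rng_key, state, minibatch, 1e-3)
new_position, grad_estimator_state = new_state  # no leading step counter anymore
```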
8 changes: 4 additions & 4 deletions examples/SGMCMC.md
@@ -161,7 +161,7 @@ init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
# Build the SGLD kernel with a constant learning rate
grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sgld = blackjax.sgld(grad_fn, lambda _: step_size)
sgld = blackjax.sgld(grad_fn)
state = sgld.init(init_positions, next(batches))
@@ -172,7 +172,7 @@ steps = []
for step in progress_bar(range(num_samples + num_warmup)):
_, rng_key = jax.random.split(rng_key)
batch = next(batches)
state = jax.jit(sgld.step)(rng_key, state, batch)
state = jax.jit(sgld.step)(rng_key, state, batch, step_size)
if step % 100 == 0:
accuracy = compute_accuracy(state.position, X_test, y_test)
accuracies.append(accuracy)
@@ -208,7 +208,7 @@ We can also use SGHMC to sample from this model
# Build the SGHMC kernel with a constant learning rate
step_size = 9e-6
grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sghmc = blackjax.sghmc(grad_fn, lambda _: step_size)
sghmc = blackjax.sghmc(grad_fn)
# Batch the data
state = sghmc.init(init_positions, next(batches))
@@ -220,7 +220,7 @@ steps = []
for step in progress_bar(range(num_samples + num_warmup)):
_, rng_key = jax.random.split(rng_key)
batch = next(batches)
state = jax.jit(sghmc.step)(rng_key, state, batch)
state = jax.jit(sghmc.step)(rng_key, state, batch, step_size)
if step % 100 == 0:
sghmc_accuracy = compute_accuracy(state.position, X_test, y_test)
sghmc_accuracies.append(sghmc_accuracy)
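With the schedule gone from the kernel, a decaying step size is simply a value computed inside the training loop. One possible variant of the example loop above (the polynomial decay law is an arbitrary illustration, not something this commit prescribes):

```python
step_fn = jax.jit(sgld.step)

for step in progress_bar(range(num_samples + num_warmup)):
    _, rng_key = jax.random.split(rng_key)
    batch = next(batches)
    step_size = 1e-3 * (step + 1) ** (-0.55)  # user-chosen decay, applied outside the kernel
    state = step_fn(rng_key, state, batch, step_size)
```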