forked from blackjax-devs/blackjax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,275 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "alpine-armstrong", | ||
"metadata": {}, | ||
"source": [ | ||
"# A quick introduction to Blackjax\n", | ||
"BlackJAX is an MCMC sampling library based on Jax. Its explicitly designed to be modular allowing ease of access to any part of the algorithm, making it easy to introspect into how the sampler works or to extend it.\n", | ||
"\n", | ||
"In this notebook we provide a simple example based on basic Hamiltonian Monte Carlo to showcase the architecture and interfaces in the library" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 32, | ||
"id": "critical-reading", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import functools as ft\n", | ||
"import blackjax.hmc as hmc\n", | ||
"import jax.numpy as jnp\n", | ||
"import jax.scipy.stats as stats\n", | ||
"import jax" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "judicial-brief", | ||
"metadata": {}, | ||
"source": [ | ||
"## Generate simulated observed data\n", | ||
"We'll generate observations from a distribution of known parameters to test if we can recover the parameters in sampling." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"id": "advanced-kidney", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"loc, scale = 10, 20\n", | ||
"observed = np.random.normal(loc, scale, size=100)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "computational-petersburg", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create a potential function" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 34, | ||
"id": "ceramic-accessory", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def potential_fn(loc, scale, observed=observed):\n", | ||
" \"\"\"Univariate Normal\"\"\"\n", | ||
" logpdf = stats.norm.logpdf(observed, loc, scale)\n", | ||
" return -jnp.sum(logpdf)\n", | ||
"\n", | ||
"potential = lambda x: potential_fn(**x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "heated-nightmare", | ||
"metadata": {}, | ||
"source": [ | ||
"## Set an initial state" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 35, | ||
"id": "transparent-passage", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"HMCState(position={'loc': 1.0, 'scale': 2.0}, potential_energy=DeviceArray(6243.8994, dtype=float32), potential_energy_grad={'loc': DeviceArray(-245.63074, dtype=float32), 'scale': DeviceArray(-6032.691, dtype=float32)})" | ||
] | ||
}, | ||
"execution_count": 35, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"initial_position = {\"loc\": 1.0, \"scale\": 2.0}\n", | ||
"initial_state = hmc.new_state(initial_position, potential)\n", | ||
"initial_state" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "stylish-haven", | ||
"metadata": {}, | ||
"source": [ | ||
"## Set some sampler parameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 36, | ||
"id": "lucky-electricity", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([1., 1.])" | ||
] | ||
}, | ||
"execution_count": 36, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"inv_mass_matrix = np.array([1.0, 1.0])\n", | ||
"inv_mass_matrix" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 37, | ||
"id": "deluxe-lesson", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"params = hmc.HMCParameters(\n", | ||
" num_integration_steps=90, step_size=1e-3, inv_mass_matrix=inv_mass_matrix\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "applied-translator", | ||
"metadata": {}, | ||
"source": [ | ||
"## Combine both into a kernel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 38, | ||
"id": "destroyed-ballet", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"kernel = hmc.kernel(potential, params)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "solved-arrest", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create an inference loop" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 39, | ||
"id": "comparative-trinity", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def inference_loop(rng_key, kernel, initial_state, num_samples):\n", | ||
" def one_step(state, rng_key):\n", | ||
" state, _ = kernel(rng_key, state)\n", | ||
" return state, state\n", | ||
"\n", | ||
" keys = jax.random.split(rng_key, num_samples)\n", | ||
" _, states = jax.lax.scan(one_step, initial_state, keys)\n", | ||
"\n", | ||
" return states" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "understanding-click", | ||
"metadata": {}, | ||
"source": [ | ||
"## Run Sampling" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 40, | ||
"id": "fixed-fisher", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rng_key = jax.random.PRNGKey(19)\n", | ||
"states = inference_loop(rng_key, kernel, initial_state, 20_000)\n", | ||
"\n", | ||
"loc_samples = states.position[\"loc\"][5000:]\n", | ||
"scale_samples = states.position[\"scale\"][5000:]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 41, | ||
"id": "substantial-prophet", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"DeviceArray(10.872661, dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 41, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"np.mean(loc_samples)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 42, | ||
"id": "grand-syndication", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"DeviceArray(20.161758, dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 42, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"np.mean(scale_samples)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |