Skip to content

Commit

Permalink
Add introduction notebook (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
canyon289 authored Mar 18, 2021
1 parent 43a6e30 commit bd2a529
Showing 1 changed file with 275 additions and 0 deletions.
275 changes: 275 additions & 0 deletions notebooks/Introduction.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "alpine-armstrong",
"metadata": {},
"source": [
"# A quick introduction to Blackjax\n",
"BlackJAX is an MCMC sampling library based on Jax. Its explicitly designed to be modular allowing ease of access to any part of the algorithm, making it easy to introspect into how the sampler works or to extend it.\n",
"\n",
"In this notebook we provide a simple example based on basic Hamiltonian Monte Carlo to showcase the architecture and interfaces in the library"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "critical-reading",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import functools as ft\n",
"import blackjax.hmc as hmc\n",
"import jax.numpy as jnp\n",
"import jax.scipy.stats as stats\n",
"import jax"
]
},
{
"cell_type": "markdown",
"id": "judicial-brief",
"metadata": {},
"source": [
"## Generate simulated observed data\n",
"We'll generate observations from a distribution of known parameters to test if we can recover the parameters in sampling."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "advanced-kidney",
"metadata": {},
"outputs": [],
"source": [
"loc, scale = 10, 20\n",
"observed = np.random.normal(loc, scale, size=100)"
]
},
{
"cell_type": "markdown",
"id": "computational-petersburg",
"metadata": {},
"source": [
"## Create a potential function"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "ceramic-accessory",
"metadata": {},
"outputs": [],
"source": [
"def potential_fn(loc, scale, observed=observed):\n",
" \"\"\"Univariate Normal\"\"\"\n",
" logpdf = stats.norm.logpdf(observed, loc, scale)\n",
" return -jnp.sum(logpdf)\n",
"\n",
"potential = lambda x: potential_fn(**x)"
]
},
{
"cell_type": "markdown",
"id": "heated-nightmare",
"metadata": {},
"source": [
"## Set an initial state"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "transparent-passage",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"HMCState(position={'loc': 1.0, 'scale': 2.0}, potential_energy=DeviceArray(6243.8994, dtype=float32), potential_energy_grad={'loc': DeviceArray(-245.63074, dtype=float32), 'scale': DeviceArray(-6032.691, dtype=float32)})"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"initial_position = {\"loc\": 1.0, \"scale\": 2.0}\n",
"initial_state = hmc.new_state(initial_position, potential)\n",
"initial_state"
]
},
{
"cell_type": "markdown",
"id": "stylish-haven",
"metadata": {},
"source": [
"## Set some sampler parameters"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "lucky-electricity",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1., 1.])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_mass_matrix = np.array([1.0, 1.0])\n",
"inv_mass_matrix"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "deluxe-lesson",
"metadata": {},
"outputs": [],
"source": [
"params = hmc.HMCParameters(\n",
" num_integration_steps=90, step_size=1e-3, inv_mass_matrix=inv_mass_matrix\n",
")"
]
},
{
"cell_type": "markdown",
"id": "applied-translator",
"metadata": {},
"source": [
"## Combine both into a kernel"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "destroyed-ballet",
"metadata": {},
"outputs": [],
"source": [
"kernel = hmc.kernel(potential, params)"
]
},
{
"cell_type": "markdown",
"id": "solved-arrest",
"metadata": {},
"source": [
"## Create an inference loop"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "comparative-trinity",
"metadata": {},
"outputs": [],
"source": [
"def inference_loop(rng_key, kernel, initial_state, num_samples):\n",
" def one_step(state, rng_key):\n",
" state, _ = kernel(rng_key, state)\n",
" return state, state\n",
"\n",
" keys = jax.random.split(rng_key, num_samples)\n",
" _, states = jax.lax.scan(one_step, initial_state, keys)\n",
"\n",
" return states"
]
},
{
"cell_type": "markdown",
"id": "understanding-click",
"metadata": {},
"source": [
"## Run Sampling"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "fixed-fisher",
"metadata": {},
"outputs": [],
"source": [
"rng_key = jax.random.PRNGKey(19)\n",
"states = inference_loop(rng_key, kernel, initial_state, 20_000)\n",
"\n",
"loc_samples = states.position[\"loc\"][5000:]\n",
"scale_samples = states.position[\"scale\"][5000:]"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "substantial-prophet",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(10.872661, dtype=float32)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(loc_samples)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "grand-syndication",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(20.161758, dtype=float32)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(scale_samples)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit bd2a529

Please sign in to comment.