# AutoDiff-Inference (Bijax)
This repository implements Automatic Differentiation Variational Inference (ADVI) and several variants of the Laplace approximation, based on the corresponding research papers.

- ADVI: an implementation of Automatic Differentiation Variational Inference.
- Laplace approximation: a Laplace approximation for constrained variables, inspired by ADVI's approach of mapping constrained parameters to an unconstrained space with a bijector.
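
```python
import jax
import tensorflow_probability.substrates.jax as tfp  # TFP on the JAX backend

tfd = tfp.distributions

## Creation of the dataset for Laplace Approximation
data_dist = tfd.Bernoulli(probs=0.7)
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))
prior_theta = [3.0, 5.0]  # Beta(3, 5) prior on theta

## Bernoulli likelihood function
def likelihood_fn(theta, data):
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()

# Parameters of the exact posterior (the Beta prior is conjugate to the Bernoulli likelihood)
alpha = prior_theta[0] + data.sum()
beta = prior_theta[1] + len(data) - data.sum()
```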
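The `alpha` and `beta` lines rely on Beta-Bernoulli conjugacy: a Beta(a, b) prior combined with k successes in n Bernoulli trials gives a Beta(a + k, b + n - k) posterior. The dependency-free sketch below computes the same update together with the posterior mean and mode; the mode is where a Laplace approximation in θ-space is centred. The helper names are hypothetical and not part of this repository.

```python
def beta_bernoulli_posterior(a, b, data):
    """Hypothetical helper: conjugate update of a Beta(a, b) prior with 0/1 observations."""
    k = sum(data)  # number of successes (heads)
    n = len(data)  # number of trials
    return a + k, b + n - k

def beta_mean(alpha, beta):
    return alpha / (alpha + beta)

def beta_mode(alpha, beta):
    # The mode lies strictly inside (0, 1) only when alpha > 1 and beta > 1.
    if alpha <= 1 or beta <= 1:
        raise ValueError("interior mode requires alpha > 1 and beta > 1")
    return (alpha - 1) / (alpha + beta - 2)

# Example: Beta(3, 5) prior, 7 heads out of 10 tosses.
alpha, beta = beta_bernoulli_posterior(3.0, 5.0, [1, 1, 1, 0, 1, 1, 0, 1, 0, 1])
print(alpha, beta)              # 10.0 8.0
print(beta_mean(alpha, beta))   # 0.5555555555555556
print(beta_mode(alpha, beta))   # 0.5625
```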
## Normal Laplace Approximation
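
```python
import os
import matplotlib.pyplot as plt

os.makedirs("plots", exist_ok=True)  # make sure the output directory exists

## Using Identity bijector for normal Laplace Approximation
# LaplaceApproximation is provided by this repository's library
la = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Identity(),
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)  ## True posterior
fig = la.plot_approx_posterior(true_posterior=true_posterior)
plt.xlim(-0.5, 1.5)  # a Gaussian in theta-space can put mass outside [0, 1]
plt.savefig("plots/la_coin_toss.png")  # save before opening a new figure
plt.figure()  # start a fresh figure for the next plot
```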
![image](https://private-user-images.githubusercontent.com/76394914/251864792-a808f463-2e2d-49a1-a122-813a3b0fd756.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwNjg1MTYsIm5iZiI6MTczOTA2ODIxNiwicGF0aCI6Ii83NjM5NDkxNC8yNTE4NjQ3OTItYTgwOGY0NjMtMmUyZC00OWExLWExMjItODEzYTNiMGZkNzU2LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA5VDAyMzAxNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTAyOTExNjc5Y2UyYTEzOTA0MWZmOWFjNDA5MzA2MDRkY2I4MzIyZDYwOGZiZTI1NjFkYjQ2OGNjZmYwNTYzODcmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.dO_Ra22oZfWHm1ZqpafLCvuioW0sQTIzPlL_q9pUjHM)
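With the `Identity` bijector, the Laplace approximation is a Gaussian in θ itself, centred at the posterior mode, with variance equal to the inverse of the negative second derivative of the log posterior at that mode. For a Beta(α, β) posterior this has a closed form, θ* = (α - 1)/(α + β - 2) and σ² = θ*(1 - θ*)/(α + β - 2), which the sketch below checks against a finite-difference second derivative. It is a hand-derived illustration assuming α, β > 1, not the library's autodiff-based implementation, and the function names are hypothetical. Because the Gaussian has unbounded support, part of its mass falls outside [0, 1], which is why the example widens the x-axis with `plt.xlim(-0.5, 1.5)`.

```python
import math

def log_beta_unnorm(theta, alpha, beta):
    # Unnormalised log density of Beta(alpha, beta) on (0, 1).
    return (alpha - 1) * math.log(theta) + (beta - 1) * math.log(1 - theta)

def laplace_identity(alpha, beta):
    """Hypothetical helper: closed-form Laplace approximation of Beta(alpha, beta) in theta-space."""
    if alpha <= 1 or beta <= 1:
        raise ValueError("needs alpha > 1 and beta > 1 for an interior mode")
    n = alpha + beta - 2
    mode = (alpha - 1) / n
    # -d^2/dtheta^2 log p at the mode simplifies to n / (mode * (1 - mode)).
    var = mode * (1 - mode) / n
    return mode, var

def neg_second_derivative(f, x, h=1e-4):
    # Central finite-difference estimate of -f''(x).
    return -(f(x + h) - 2 * f(x) + f(x - h)) / h**2

mode, var = laplace_identity(3.0, 5.0)  # prior-only case: mode = 1/3, var = 1/27
curv = neg_second_derivative(lambda t: log_beta_unnorm(t, 3.0, 5.0), mode)
print(mode, var, 1 / curv)  # 1 / curv should be close to var
```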
## Autodiff Laplace Approximation

```python
## Using Sigmoid bijector for constrained Laplace Approximation
la_cov = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Sigmoid(),
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)
fig_cov = la_cov.plot_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/la_cov_coin_toss.png")  # save the current plot first
plt.figure()  # then start a fresh figure for the log-density plot
fig = la_cov.plot_log_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/log_la_cov_coin_toss.png")
```
![image](https://private-user-images.githubusercontent.com/76394914/251865201-14390528-dfc3-4b6c-9c41-fa86b59b586e.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwNjg1MTYsIm5iZiI6MTczOTA2ODIxNiwicGF0aCI6Ii83NjM5NDkxNC8yNTE4NjUyMDEtMTQzOTA1MjgtZGZjMy00YjZjLTljNDEtZmE4NmI1OWI1ODZlLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA5VDAyMzAxNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTdlYzdhYmUxOTU5ZmQ3Y2RkMDQwYTRmMmE3ZDk2MzViMzFkZjFlNjVmOGM3Zjg2NzBkYTUxM2RiNTc5OTg3OTAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.kBeUOTDBFqrVB7vRraLOcyf4__b0jwCuRko5XXNJ3pY)
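With the `Sigmoid` bijector, the approximation is built for the unconstrained variable z = logit(θ) instead. Following the ADVI recipe, the log density in z-space picks up the log-Jacobian term log σ(z) + log(1 - σ(z)); for a Beta(α, β) posterior the total is α log σ(z) + β log(1 - σ(z)), whose mode is z* = log(α/β) with negative curvature αβ/(α + β) there. Mapping the Gaussian back through the sigmoid gives a logit-normal density on (0, 1), so no mass falls outside the support. The sketch below is a hand-derived illustration that assumes the library includes the Jacobian term as ADVI does; the function names are hypothetical.

```python
import math

def laplace_logit(alpha, beta):
    """Hypothetical helper: Laplace approximation of Beta(alpha, beta) in z = logit(theta).

    Log density in z, including the log-Jacobian of theta = sigmoid(z):
        alpha * log(sigmoid(z)) + beta * log(1 - sigmoid(z))
    Its derivative alpha - (alpha + beta) * sigmoid(z) vanishes at z* = log(alpha / beta),
    and the negative second derivative there is alpha * beta / (alpha + beta).
    """
    mu = math.log(alpha / beta)
    var = (alpha + beta) / (alpha * beta)  # equals 1/alpha + 1/beta
    return mu, var

def logit_normal_pdf(theta, mu, var):
    # Density on (0, 1) of theta = sigmoid(z) with z ~ N(mu, var).
    z = math.log(theta / (1 - theta))
    normal = math.exp(-((z - mu) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)
    return normal / (theta * (1 - theta))  # |dz/dtheta| = 1 / (theta * (1 - theta))

mu, var = laplace_logit(3.0, 5.0)
# Midpoint-rule check that the back-transformed density integrates to about 1 on (0, 1).
n = 10_000
total = sum(logit_normal_pdf((i + 0.5) / n, mu, var) for i in range(n)) / n
print(mu, var, total)
```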
Besides the Laplace approximation library, the repository includes two notebooks that demonstrate the diagonal and low-rank Laplace approximations.
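The diagonal and low-rank variants differ in how they approximate the posterior precision H (the Hessian of the negative log posterior at the mode) before inverting it to get a covariance. The NumPy sketch below shows one common construction of each: the diagonal variant keeps only diag(H), and the low-rank variant keeps the k largest eigenpairs of H plus an isotropic term δI (standing in for a prior precision) so the result stays invertible. The function names and the δI term are illustrative assumptions; the notebooks may build these approximations differently.

```python
import numpy as np

def diagonal_laplace_cov(H):
    """Hypothetical helper: covariance from the diagonal of the precision matrix H."""
    return np.diag(1.0 / np.diag(H))

def lowrank_laplace_cov(H, k, delta):
    """Hypothetical helper: covariance from a rank-k approximation of H plus delta * I."""
    eigvals, eigvecs = np.linalg.eigh(H)    # eigenvalues in ascending order
    U, lam = eigvecs[:, -k:], eigvals[-k:]  # keep the k largest eigenpairs
    H_k = U @ np.diag(lam) @ U.T + delta * np.eye(H.shape[0])
    return np.linalg.inv(H_k)

# A small symmetric positive-definite "Hessian" for illustration.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
H = A @ A.T + 5 * np.eye(5)

full = np.linalg.inv(H)
diag = diagonal_laplace_cov(H)
low = lowrank_laplace_cov(H, k=2, delta=1.0)
print(np.diag(full), np.diag(diag), np.diag(low), sep="\n")
```

With k equal to the dimension and δ = 0, the low-rank construction recovers the full Laplace covariance, which makes a convenient sanity check.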
## ADVI

```python
import jax
import tensorflow_probability.substrates.jax as tfp  # TFP on the JAX backend

tfd = tfp.distributions

data_dist = tfd.Bernoulli(probs=0.7)
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))
prior_theta = [3.0, 5.0]

def likelihood_fn(theta, data):
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()

# ADVI is provided by this repository's library; NormalCDF maps the real line onto (0, 1)
advi = ADVI(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.NormalCDF(),
    likelihood=likelihood_fn,
)
appx_post = advi.approx_posterior(data)
```
![image](https://private-user-images.githubusercontent.com/76394914/251867418-d3d00ebc-9655-4d04-a5b8-2096dc759059.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwNjg1MTYsIm5iZiI6MTczOTA2ODIxNiwicGF0aCI6Ii83NjM5NDkxNC8yNTE4Njc0MTgtZDNkMDBlYmMtOTY1NS00ZDA0LWE1YjgtMjA5NmRjNzU5MDU5LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA5VDAyMzAxNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTQzYzg0MDBhMjcyZTc0YWFhZGExNjNlMzdhNTI2MWYxMWNmYmE0ZDBmNTgwMzBlM2JlOWVmMDg5NjFlOGJhOTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.3_1c_sly0KvDGFhLu1AgLYq_j-_VCfFq7XzpdCURRbY)
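ADVI replaces the mode-and-curvature recipe with an optimisation: it maps θ to an unconstrained z, posits a Gaussian q(z) = N(μ, s²), and maximises the ELBO, E_q[log p(data, θ(z)) + log |dθ/dz|] + H[q], using reparameterised samples z = μ + s·ε. The NumPy sketch below runs this loop for the coin-toss posterior with hand-written gradients. It assumes a sigmoid (logit) transform instead of the `NormalCDF` bijector used above, because the sigmoid's log-Jacobian has a simple closed form, and it takes the posterior parameters α and β directly, which is valid here because the Beta prior is conjugate (log prior + log likelihood is a Beta log density up to a constant). The function name, step size, and sample counts are illustrative choices, not the repository's settings.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def advi_beta_posterior(alpha, beta, steps=3000, n_samples=500, lr=0.01, seed=0):
    """Hypothetical helper: Gaussian ADVI for a Beta(alpha, beta) posterior in z = logit(theta).

    Log joint in z (log prior + log likelihood + log-Jacobian, up to a constant):
        f(z) = alpha * log(sigmoid(z)) + beta * log(1 - sigmoid(z))
    with derivative f'(z) = alpha - (alpha + beta) * sigmoid(z).
    """
    rng = np.random.default_rng(seed)
    mu, omega = 0.0, 0.0  # q(z) = N(mu, exp(omega)**2)
    for _ in range(steps):
        eps = rng.standard_normal(n_samples)
        s = np.exp(omega)
        z = mu + s * eps  # reparameterisation trick
        grad_f = alpha - (alpha + beta) * sigmoid(z)
        grad_mu = grad_f.mean()                        # Monte Carlo d ELBO / d mu
        grad_omega = (grad_f * eps * s).mean() + 1.0   # +1 from the Gaussian entropy term
        mu += lr * grad_mu  # stochastic gradient ascent on the ELBO
        omega += lr * grad_omega
    return mu, np.exp(omega)

# Beta(73, 35): the posterior for 70 heads in 100 tosses under the Beta(3, 5) prior.
mu, s = advi_beta_posterior(73.0, 35.0)
theta_samples = sigmoid(mu + s * np.random.default_rng(1).standard_normal(100_000))
print(mu, s, theta_samples.mean())
```

Setting the μ-gradient to zero shows that at the optimum E_q[σ(z)] = α/(α + β), so the printed sample mean should sit close to the exact posterior mean 73/108 ≈ 0.676, while μ stays near the logit-space mode log(73/35).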