# AutoDiff-Inference (Bijax)
This repository implements Automatic Differentiation Variational Inference (ADVI) and several variants of the Laplace approximation, each based on the corresponding research papers.
- ADVI Implementation
- Laplace Approximation: Implementation of Laplace Approximation for constrained variables, inspired by Automatic Differentiation Variational Inference (ADVI).
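
As a sketch of the constrained variant (the symbols $T$, $z$, $\hat{z}$, and $H$ are notation assumed here for illustration, not names from the code): let $T$ be a bijector that maps an unconstrained variable $z$ to the constrained parameter $\theta = T(z)$. The Laplace approximation is then fitted to the log joint density in $z$, which includes the log-determinant of the Jacobian of $T$:

$$
\log \tilde{p}(z) = \log p(\mathcal{D} \mid T(z)) + \log p(T(z)) + \log \left| \det J_T(z) \right|
$$

$$
\hat{z} = \arg\max_z \log \tilde{p}(z), \qquad H = -\nabla_z^2 \log \tilde{p}(z) \Big|_{z = \hat{z}}, \qquad q(z) = \mathcal{N}(z \mid \hat{z}, H^{-1})
$$

The approximate posterior on $\theta$ is the push-forward of $q(z)$ through $T$, so it respects the constraint by construction. With the identity bijector the Jacobian term vanishes and this reduces to the standard Laplace approximation.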
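```python
import jax
# JAX substrate of TFP (assumed here, since the seeds are JAX PRNG keys)
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

# Create the dataset for the Laplace approximation: 100 coin tosses with P(1) = 0.7
data_dist = tfd.Bernoulli(probs=0.7)
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))
prior_theta = [3.0, 5.0]  # parameters of the Beta prior


# Bernoulli log-likelihood of the data given theta
def likelihood_fn(theta, data):
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()


# Parameters of the exact (conjugate) Beta posterior
alpha = prior_theta[0] + data.sum()
beta = prior_theta[1] + len(data) - data.sum()
```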
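The `alpha` and `beta` above follow from Beta-Bernoulli conjugacy. As a short derivation, writing $a_0 = 3$ and $b_0 = 5$ for the prior parameters, $n$ for the number of tosses, and $s = \sum_i x_i$ for the number of ones (symbols introduced here for illustration):

$$
p(\theta \mid x_{1:n}) \;\propto\; \underbrace{\theta^{s} (1 - \theta)^{n - s}}_{\text{Bernoulli likelihood}} \; \underbrace{\theta^{a_0 - 1} (1 - \theta)^{b_0 - 1}}_{\text{Beta prior}} \;=\; \theta^{a_0 + s - 1} (1 - \theta)^{b_0 + n - s - 1}
$$

This is the kernel of $\mathrm{Beta}(a_0 + s,\; b_0 + n - s)$, the exact posterior that the approximations are plotted against.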
## Normal Laplace Approximation
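```python
import matplotlib.pyplot as plt

# Identity bijector: standard Laplace approximation, fitted directly on theta.
# LaplaceApproximation is the class provided by this repository.
la = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Identity(),
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)  # exact conjugate posterior
fig = la.plot_approx_posterior(true_posterior=true_posterior)
plt.xlim(-0.5, 1.5)
# Save before opening a new figure; otherwise an empty figure is written
plt.savefig("plots/la_coin_toss.png")
```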
![image](https://private-user-images.githubusercontent.com/76394914/251864792-a808f463-2e2d-49a1-a122-813a3b0fd756.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwNzc4MjMsIm5iZiI6MTczOTA3NzUyMywicGF0aCI6Ii83NjM5NDkxNC8yNTE4NjQ3OTItYTgwOGY0NjMtMmUyZC00OWExLWExMjItODEzYTNiMGZkNzU2LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA5VDA1MDUyM1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTUwYTNmMzk0M2VmYWIxMWU1Y2JkODY0OGJiNGRjZDJiYzk2MjE1Zjg1NzI5ZGFjN2YwYTJmMmFiZWY0YmM2NGUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.O2g7-_S-8abR-2C8VWWhQ1C55So0TbfjxzZe7A_wFP4)
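To make the computation concrete, here is a minimal, dependency-free sketch of a Laplace approximation for the coin-toss posterior. It is not the repository's `LaplaceApproximation` class: the counts (70 ones out of 100) are made up for illustration, the function names are hypothetical, and the derivatives are written by hand where an autodiff library would compute them.

```python
import math

# Hypothetical counts for illustration (not the sample drawn above)
heads, tails = 70, 30
a = 3.0 + heads  # posterior alpha under the Beta(3, 5) prior
b = 5.0 + tails  # posterior beta under the Beta(3, 5) prior


def grad_log_post(theta):
    """First derivative of the unnormalized log posterior in theta."""
    return (a - 1.0) / theta - (b - 1.0) / (1.0 - theta)


def hess_log_post(theta):
    """Second derivative of the unnormalized log posterior in theta."""
    return -(a - 1.0) / theta**2 - (b - 1.0) / (1.0 - theta) ** 2


def laplace_identity(theta=0.5, iters=50, tol=1e-12):
    """Find the mode with Newton's method, then fit a Gaussian there."""
    for _ in range(iters):
        step = grad_log_post(theta) / hess_log_post(theta)
        # Newton update, clamped so the iterate stays strictly inside (0, 1)
        theta = min(max(theta - step, 1e-6), 1.0 - 1e-6)
        if abs(step) < tol:
            break
    # Gaussian variance is the negative inverse curvature at the mode
    variance = -1.0 / hess_log_post(theta)
    return theta, math.sqrt(variance)


mode, std = laplace_identity()
print(f"q(theta) = Normal(mean={mode:.4f}, std={std:.4f})")
```

For these counts the mode equals the closed-form value $(a - 1)/(a + b - 2) = 72/106$. Because this Gaussian is defined on the whole real line, it assigns some (here very small) probability to values of $\theta$ outside $[0, 1]$.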
## Autodiff Laplace Approximation
```python
# Sigmoid bijector: Laplace approximation for a variable constrained to (0, 1)
la_cov = LaplaceApproximation(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.Sigmoid(),
    likelihood=likelihood_fn,
)
true_posterior = tfd.Beta(alpha, beta)
fig_cov = la_cov.plot_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/la_cov_coin_toss.png")

plt.figure()  # start a fresh figure for the log-density plot
fig = la_cov.plot_log_approx_posterior(true_posterior=true_posterior)
plt.savefig("plots/log_la_cov_coin_toss.png")
```
![image](https://private-user-images.githubusercontent.com/76394914/251865201-14390528-dfc3-4b6c-9c41-fa86b59b586e.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwNzc4MjMsIm5iZiI6MTczOTA3NzUyMywicGF0aCI6Ii83NjM5NDkxNC8yNTE4NjUyMDEtMTQzOTA1MjgtZGZjMy00YjZjLTljNDEtZmE4NmI1OWI1ODZlLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA5VDA1MDUyM1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWQzYTI5MmIxNWNjMTk4OTg3MmQwNDFhYjkyMGQwMDJhZjViODlhY2NjMDI1OTJkMDhmZDdhYmM5ZTM1NDNhNGMmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.8nwvqU32_wuUr4M1mG9NAm111WV2qCGbbY0ng8c1QgE)
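The unconstrained-space computation can be sketched as follows. This is a dependency-free illustration rather than the repository's implementation: the counts are hypothetical, the names are made up, and the derivatives are hand-written. With $\theta = \mathrm{sigmoid}(z)$ the log-Jacobian is $\log \theta + \log(1 - \theta)$, so the unnormalized log density in $z$ is $a \log \theta + b \log(1 - \theta)$, where $a$ and $b$ are the Beta posterior parameters.

```python
import math

heads, tails = 70, 30  # hypothetical counts for illustration
a = 3.0 + heads  # posterior alpha under the Beta(3, 5) prior
b = 5.0 + tails  # posterior beta under the Beta(3, 5) prior


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def grad_log_post_z(z):
    """d/dz of a*log(sigmoid(z)) + b*log(1 - sigmoid(z))."""
    return a - (a + b) * sigmoid(z)


def hess_log_post_z(z):
    """Second derivative of the same log density in z."""
    theta = sigmoid(z)
    return -(a + b) * theta * (1.0 - theta)


def laplace_sigmoid(z=0.0, iters=50, tol=1e-12):
    """Newton's method for the mode in z, then a Gaussian fitted at that mode."""
    for _ in range(iters):
        step = grad_log_post_z(z) / hess_log_post_z(z)
        z -= step
        if abs(step) < tol:
            break
    return z, math.sqrt(-1.0 / hess_log_post_z(z))


def approx_density(theta, z_hat, z_std):
    """Density on theta obtained by pushing the Gaussian on z through sigmoid."""
    z = math.log(theta / (1.0 - theta))  # inverse sigmoid (logit)
    u = (z - z_hat) / z_std
    normal_pdf = math.exp(-0.5 * u * u) / (z_std * math.sqrt(2.0 * math.pi))
    return normal_pdf / (theta * (1.0 - theta))  # |dz/dtheta| = 1 / (theta (1 - theta))


z_hat, z_std = laplace_sigmoid()

# Midpoint-rule check that the density on theta integrates to about 1 over (0, 1)
n = 20000
total = sum(approx_density((i + 0.5) / n, z_hat, z_std) for i in range(n)) / n
print(f"z_hat={z_hat:.4f}, z_std={z_std:.4f}, integral over (0, 1) = {total:.4f}")
```

All of the approximate posterior's mass lies inside $(0, 1)$, which a Gaussian fitted directly on $\theta$ cannot guarantee.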
Besides the Laplace approximation library, the repository includes two notebooks that demonstrate the diagonal and low-rank Laplace approximations.
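The ADVI example uses the same coin-toss data and Beta prior, this time with a `NormalCDF` bijector:

```python
import jax
# JAX substrate of TFP (assumed here, since the seeds are JAX PRNG keys)
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

data_dist = tfd.Bernoulli(probs=0.7)
data = data_dist.sample(sample_shape=(100,), seed=jax.random.PRNGKey(3))
prior_theta = [3.0, 5.0]  # parameters of the Beta prior


def likelihood_fn(theta, data):
    return tfd.Bernoulli(probs=theta).log_prob(data).sum()


# NormalCDF maps the real line to (0, 1), the support of theta.
# ADVI is the class provided by this repository.
advi = ADVI(
    prior=tfd.Beta(prior_theta[0], prior_theta[1]),
    bijector=tfp.bijectors.NormalCDF(),
    likelihood=likelihood_fn,
)
appx_post = advi.approx_posterior(data)
```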
![image](https://private-user-images.githubusercontent.com/76394914/251867418-d3d00ebc-9655-4d04-a5b8-2096dc759059.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkwNzc4MjMsIm5iZiI6MTczOTA3NzUyMywicGF0aCI6Ii83NjM5NDkxNC8yNTE4Njc0MTgtZDNkMDBlYmMtOTY1NS00ZDA0LWE1YjgtMjA5NmRjNzU5MDU5LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA5VDA1MDUyM1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTkxZmJmYjBlM2M5ZDQ0NmZjOWRlOWU3Nzg2MDhkYWYxMjJhOTk1YjY4NzcwNTVmZDQ3MzI4MDFlOWMwMmQ2YTQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.QROhTIZZE5oDCBqZiZJaVynyB0ahA95o-aIgcZND_SI)
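For intuition about what an ADVI fit involves, the following is a simplified, dependency-free sketch, not the repository's `ADVI` class. It assumes hypothetical counts (70 ones out of 100), a Gaussian $q(z) = \mathcal{N}(\mu, \sigma^2)$ in the unconstrained space with $\theta = \Phi(z)$ (the normal CDF), hand-written gradients in place of autodiff, and plain fixed-step stochastic gradient ascent on the ELBO using the reparameterization $z = \mu + \sigma \epsilon$. All function names are made up for this example.

```python
import math
import random

heads, tails = 70, 30  # hypothetical counts for illustration
a = 3.0 + heads  # Beta posterior parameters under the Beta(3, 5) prior
b = 5.0 + tails


def normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z):
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def grad_log_joint(z):
    """d/dz of (a-1)*log(cdf) + (b-1)*log(1-cdf) + log(pdf); the last term is the log-Jacobian."""
    pdf = normal_pdf(z)
    cdf = normal_cdf(z)
    sf = normal_cdf(-z)  # 1 - cdf, computed without cancellation
    return (a - 1.0) * pdf / cdf - (b - 1.0) * pdf / sf - z


def fit_advi(steps=1500, n_samples=32, lr=0.005, seed=0):
    """Stochastic gradient ascent on the ELBO with reparameterized samples."""
    rng = random.Random(seed)
    mu, omega = 0.0, math.log(0.1)  # omega = log(sigma)
    for _step in range(steps):
        sigma = math.exp(omega)
        g_mu, g_omega = 0.0, 0.0
        for _sample in range(n_samples):
            eps = rng.gauss(0.0, 1.0)
            g = grad_log_joint(mu + sigma * eps)  # z = mu + sigma * eps
            g_mu += g
            g_omega += g * eps * sigma
        # Monte Carlo ELBO gradients; the +1 is the gradient of the entropy term log(sigma)
        mu += lr * g_mu / n_samples
        omega += lr * (g_omega / n_samples + 1.0)
    return mu, math.exp(omega)


mu, sigma = fit_advi()
print(f"q(z) = Normal({mu:.3f}, {sigma:.3f}); Phi(mu) = {normal_cdf(mu):.3f}")
```

In practice the gradients would come from automatic differentiation, and an adaptive optimizer would typically replace the fixed step size used here.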