This repository contains the code to reproduce the experiments in the paper "What Are Bayesian Neural Network Posteriors Really Like?" by Pavel Izmailov, Sharad Vikram, Matthew D. Hoffman and Andrew Gordon Wilson.
In the paper, we use full-batch Hamiltonian Monte Carlo (HMC) to investigate foundational questions in Bayesian deep learning. We show that
- BNNs can achieve significant performance gains over standard training and deep ensembles;
- a single long HMC chain can provide a comparable representation of the posterior to multiple shorter chains;
- in contrast to recent studies, we find posterior tempering is not needed for near-optimal performance, with little evidence for a "cold posterior" effect, which we show is largely an artifact of data augmentation;
- BMA performance is robust to the choice of prior scale, and relatively similar for diagonal Gaussian, mixture of Gaussian, and logistic priors;
- Bayesian neural networks show surprisingly poor generalization under domain shift;
- while cheaper alternatives such as deep ensembles and SGMCMC can provide good generalization, their predictive distributions are distinct from those of HMC. Notably, deep ensemble predictive distributions are about as close to HMC as those of standard SGLD, and closer than those of standard variational inference.
In this repository we provide JAX code for reproducing results in the paper.
We provide a requirements.txt file that can be used to create a conda environment to run the code in this repo:
conda create --name <env> --file requirements.txt
Example set-up using pip:
pip install tensorflow
pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.65+cuda112 -f \
https://storage.googleapis.com/jax-releases/jax_releases.html
pip install git+https://github.com/deepmind/dm-haiku
pip install tensorflow_datasets
pip install tabulate
pip install optax
Please see the JAX repo for the latest instructions on how to install JAX on your hardware.
.
+-- core/
| +-- hmc.py (The Hamiltonian Monte Carlo algorithm)
| +-- sgmcmc.py (SGMCMC methods as optax optimizers)
| +-- vi.py (Mean field variational inference)
+-- utils/ (Utility functions used by the training scripts)
| +-- train_utils.py (The training epochs and update rules)
| +-- models.py (Models used in the experiments)
| +-- losses.py (Prior and likelihood functions)
| +-- data_utils.py (Loading and pre-processing the data)
| +-- optim_utils.py (Optimizers and learning rate schedules)
| +-- ensemble_utils.py (Implementation of ensembling of predictions)
| +-- metrics.py (Metrics used in evaluation)
| +-- cmd_args_utils.py (Common command line arguments)
| +-- script_utils.py (Common functionality of the training scripts)
| +-- checkpoint_utils.py (Saving and loading checkpoints)
| +-- logging_utils.py (Utilities for logging and printing the results)
| +-- precision_utils.py (Controlling the numerical precision)
| +-- tree_utils.py (Common operations on pytree objects)
+-- run_hmc.py (HMC training script)
+-- run_sgd.py (SGD training script)
+-- run_sgmcmc.py (SGMCMC training script)
+-- run_vi.py (MFVI training script)
+-- make_posterior_surface_plot.py (script to visualize posterior density)
Common command line arguments:
- seed: random seed
- dir: training directory for saving the checkpoints and tensorboard logs
- dataset_name: name of the dataset, e.g. cifar10, cifar100, imdb; for the UCI datasets, the name is specified as <UCI dataset name>_<random seed>, e.g. yacht_2, where the seed determines the train-test split
- subset_train_to: number of datapoints to use from the dataset; by default, the full dataset is used
- model_name: name of the neural network architecture, e.g. lenet, resnet20_frn_swish, cnn_lstm, mlp_regression_small
- weight_decay: weight decay; for Bayesian methods, weight decay determines the prior variance (prior_var = 1 / weight_decay)
- temperature: posterior temperature (default: 1)
- init_checkpoint: path to the checkpoint to use for initialization (optional)
- tabulate_freq: frequency of tabulate table header logging
- use_float64: use float64 precision (does not work on TPUs and some GPUs); by default, we use float32 precision
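The relation between weight_decay, the prior variance and the temperature can be illustrated with a short Python sketch. It assumes a zero-mean isotropic Gaussian prior and the common convention that the tempered posterior is proportional to p(w | D)^(1 / T); the function names are hypothetical and are not taken from utils/losses.py.

```python
import math


def prior_variance(weight_decay):
    """Prior variance implied by the weight decay: prior_var = 1 / weight_decay."""
    return 1.0 / weight_decay


def gaussian_log_prior(weights, weight_decay):
    """Log-density of a zero-mean isotropic Gaussian prior over a flat list of weights."""
    var = prior_variance(weight_decay)
    sq_norm = sum(w * w for w in weights)
    return -0.5 * sq_norm / var - 0.5 * len(weights) * math.log(2.0 * math.pi * var)


def tempered_log_posterior(log_likelihood, log_prior, temperature=1.0):
    """Unnormalized log-density of the tempered posterior (assumed convention)."""
    return (log_likelihood + log_prior) / temperature


# Illustrative values only: weight_decay=40 corresponds to prior_var=0.025.
weights = [0.1, -0.2, 0.05]
log_prior = gaussian_log_prior(weights, weight_decay=40.0)
log_post = tempered_log_posterior(log_likelihood=-3.0, log_prior=log_prior, temperature=1.0)
```

With temperature below 1 the log-density is scaled up, which concentrates the distribution around its modes.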
To run HMC, you can use the run_hmc.py training script. Arguments:
- step_size: HMC step size
- trajectory_len: HMC trajectory length
- num_iterations: total number of HMC iterations
- max_num_leapfrog_steps: maximum number of leapfrog steps allowed; meant as a sanity check and should be greater than trajectory_len / step_size
- num_burn_in_iterations: number of burn-in iterations (default: 0)
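To make the relation between step_size, trajectory_len and max_num_leapfrog_steps concrete, here is a minimal Python sketch of a leapfrog trajectory together with the sanity check described above. It is a generic textbook leapfrog integrator, not the code in core/hmc.py; the helper names and the rounding of the step count are assumptions.

```python
def leapfrog_budget_ok(trajectory_len, step_size, max_num_leapfrog_steps):
    """Sanity check: the budget must exceed trajectory_len / step_size."""
    return max_num_leapfrog_steps > trajectory_len / step_size


def leapfrog(position, momentum, grad_log_prob, step_size, trajectory_len):
    """Integrate Hamiltonian dynamics for a scalar parameter (illustrative only)."""
    # Assumption: the number of steps is the rounded ratio of length to step size.
    n_steps = max(1, round(trajectory_len / step_size))
    # Initial half step for the momentum.
    momentum = momentum + 0.5 * step_size * grad_log_prob(position)
    for step in range(n_steps):
        # Full step for the position.
        position = position + step_size * momentum
        # Full momentum step, except for a half step at the very end.
        scale = 1.0 if step < n_steps - 1 else 0.5
        momentum = momentum + scale * step_size * grad_log_prob(position)
    return position, momentum


# Standard normal target: log p(x) = -x^2 / 2, so the gradient is -x.
x, p = leapfrog(1.0, 0.0, lambda x: -x, step_size=0.01, trajectory_len=1.0)
```

For example, the first IMDB command below uses trajectory_len=0.24 and step_size=1e-5, which gives about 24000 leapfrog steps, within the budget of 30000.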
CNN-LSTM on IMDB:
# Temperature = 1
python3 run_hmc.py --seed=1 --weight_decay=40. --temperature=1. \
--dir=runs/hmc/imdb/ --dataset_name=imdb --model_name=cnn_lstm \
--use_float64 --step_size=1e-5 --trajectory_len=0.24 \
--max_num_leapfrog_steps=30000
# Temperature = 0.3
python3 run_hmc.py --seed=1 --weight_decay=40. --temperature=0.3 \
--dir=runs/hmc/imdb/ --dataset_name=imdb --model_name=cnn_lstm \
--use_float64 --step_size=3e-6 --trajectory_len=0.136 \
--max_num_leapfrog_steps=46000
# Temperature = 0.1
python3 run_hmc.py --seed=1 --weight_decay=40. --temperature=0.1 \
--dir=runs/hmc/imdb/ --dataset_name=imdb --model_name=cnn_lstm \
--use_float64 --step_size=1e-6 --trajectory_len=0.078 \
--max_num_leapfrog_steps=90000
# Temperature = 0.03
python3 run_hmc.py --seed=1 --weight_decay=40. --temperature=0.03 \
--dir=runs/hmc/imdb/ --dataset_name=imdb --model_name=cnn_lstm \
--use_float64 --step_size=1e-6 --trajectory_len=0.043 \
--max_num_leapfrog_steps=45000
We ran these commands on a machine with 8 NVIDIA Tesla V100 GPUs.
MLP on a subset of 160 datapoints from MNIST:
python3 run_hmc.py --seed=0 --weight_decay=1. --temperature=1. \
--dir=runs/hmc/mnist_subset160 --dataset_name=mnist \
--model_name=mlp_classification --step_size=3.e-5 --trajectory_len=1.5 \
--num_iterations=100 --max_num_leapfrog_steps=50000 \
--num_burn_in_iterations=10 --subset_train_to=160
This script can be run on a single GPU.
Note: we run HMC on CIFAR-10 on a TPU pod with 512 TPU devices with a modified version of the code that we will release soon.
To run SGD, you can use the run_sgd.py training script. Arguments:
- init_step_size: initial SGD step size; we use a cosine schedule
- num_epochs: total number of SGD epochs
- batch_size: batch size
- eval_freq: frequency of evaluation (epochs)
- save_freq: frequency of checkpointing (epochs)
- momentum_decay: momentum decay parameter for SGD
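The cosine schedule mentioned for init_step_size can be sketched as follows. This is the standard cosine annealing formula and assumes the step size decays from init_step_size to zero over training; the exact schedule in utils/optim_utils.py may differ, and the function name is hypothetical.

```python
import math


def cosine_step_size(step, total_steps, init_step_size):
    """Cosine annealing from init_step_size at step 0 to zero at total_steps."""
    # Clamp progress to [0, 1] so the schedule stays at zero after the last step.
    progress = min(max(step / total_steps, 0.0), 1.0)
    return 0.5 * init_step_size * (1.0 + math.cos(math.pi * progress))


# Illustrative values: 500 epochs starting from a step size of 3e-7.
schedule = [cosine_step_size(epoch, 500, 3e-7) for epoch in range(501)]
```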
ResNet-20-FRN on CIFAR-10:
python3 run_sgd.py --seed=1 --weight_decay=10 --dir=runs/sgd/cifar10/ \
--dataset_name=cifar10 --model_name=resnet20_frn_swish \
--init_step_size=3e-7 --num_epochs=500 --eval_freq=10 --batch_size=80 \
--save_freq=500 --subset_train_to=40960
ResNet-20-FRN on CIFAR-100:
python3 run_sgd.py --seed=1 --weight_decay=10 --dir=runs/sgd/cifar100/ \
--dataset_name=cifar100 --model_name=resnet20_frn_swish \
--init_step_size=1e-6 --num_epochs=500 --eval_freq=10 --batch_size=80 \
--save_freq=500 --subset_train_to=40960
CNN-LSTM on IMDB:
python3 run_sgd.py --seed=1 --weight_decay=3. --dir=runs/sgd/imdb/ \
--dataset_name=imdb --model_name=cnn_lstm --init_step_size=3e-7 \
--num_epochs=500 --eval_freq=10 --batch_size=80 --save_freq=500
To train a deep ensemble, we simply train multiple copies of SGD with different random seeds.
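Once several SGD runs with different seeds are available, their predictions can be combined. A minimal Python sketch, assuming the ensemble prediction is the average of the members' softmax probabilities (the usual choice for deep ensembles; the repository's utils/ensemble_utils.py may be organized differently, and the function names here are hypothetical):

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(member_logits):
    """Average class probabilities over ensemble members for a single input.

    member_logits holds one list of logits per ensemble member.
    """
    probs = [softmax(logits) for logits in member_logits]
    num_members = len(probs)
    num_classes = len(probs[0])
    return [sum(p[c] for p in probs) / num_members for c in range(num_classes)]


# Two members that disagree on a two-class problem.
avg = ensemble_predict([[math.log(3.0), 0.0], [0.0, math.log(3.0)]])
```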
To run SGMCMC variations, you can use the run_sgmcmc.py training script. It shares command line arguments with SGD, but also introduces the following arguments:
- preconditioner: choice of preconditioner (None or RMSprop; default: None)
- step_size_schedule: choice of step size schedule (constant or cyclical; default: constant); constant sets the step size to final_step_size after a cosine burn-in for num_burnin_epochs epochs; cyclical uses a constant burn-in for num_burnin_epochs epochs and then a cosine cyclical schedule
- num_burnin_epochs: number of epochs before the final step size is reached
- final_step_size: final step size (used only with the constant schedule; default: init_step_size)
- step_size_cycle_length_epochs: cycle length (epochs; used only with the cyclical schedule; default: 50)
- save_all_ensembled: save all the networks that are ensembled
- ensemble_freq: frequency of ensembling the iterates (epochs; default: 10)
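The two step size schedules described above can be sketched in Python as follows. This is one reading of the descriptions, under the assumptions that the cosine burn-in interpolates from init_step_size to final_step_size and that each cycle decays from init_step_size to zero; the function names are hypothetical and the code in utils/optim_utils.py may differ in detail.

```python
import math


def constant_schedule(epoch, init_step_size, final_step_size, num_burnin_epochs):
    """Cosine burn-in from init_step_size to final_step_size, then constant."""
    if epoch >= num_burnin_epochs:
        return final_step_size
    progress = epoch / num_burnin_epochs
    cos_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_step_size + (init_step_size - final_step_size) * cos_weight


def cyclical_schedule(epoch, init_step_size, num_burnin_epochs, cycle_length_epochs):
    """Constant burn-in at init_step_size, then repeated cosine cycles."""
    if epoch < num_burnin_epochs:
        return init_step_size
    # Position within the current cycle, in [0, 1).
    in_cycle = (epoch - num_burnin_epochs) % cycle_length_epochs
    progress = in_cycle / cycle_length_epochs
    return 0.5 * init_step_size * (1.0 + math.cos(math.pi * progress))


# Illustrative values: 1000 burn-in epochs, cycles of 50 epochs.
step_sizes = [cyclical_schedule(e, 3e-7, 1000, 50) for e in range(1000, 1100)]
```

Under this reading, the step size is smallest just before each cycle restarts, so ensembling once per cycle lines up with setting ensemble_freq equal to the cycle length, as in the SGHMC-CLR commands below.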
ResNet-20-FRN on CIFAR-10:
# SGLD
python3 run_sgmcmc.py --seed=1 --weight_decay=5. --dir=runs/sgmcmc/cifar10/ \
--dataset_name=cifar10 --model_name=resnet20_frn_swish --init_step_size=1e-6 \
--final_step_size=1e-6 --num_epochs=10000 --num_burnin_epochs=1000 \
--eval_freq=10 --batch_size=80 --save_freq=10 --momentum=0. \
--subset_train_to=40960
# SGHMC
python3 run_sgmcmc.py --seed=1 --weight_decay=5 --dir=runs/sgmcmc/cifar10/ \
--dataset_name=cifar10 --model_name=resnet20_frn_swish --init_step_size=3e-7 \
--final_step_size=3e-7 --num_epochs=10000 --num_burnin_epochs=1000 \
--eval_freq=10 --batch_size=80 --save_freq=10 --subset_train_to=40960 \
--momentum=0.9
# SGHMC-CLR
python3 run_sgmcmc.py --seed=1 --weight_decay=5 --dir=runs/sgmcmc/cifar10/ \
--dataset_name=cifar10 --model_name=resnet20_frn_swish --init_step_size=3e-7 \
--num_epochs=10000 --num_burnin_epochs=1000 --step_size_schedule=cyclical \
--step_size_cycle_length_epochs=50 --ensemble_freq=50 --eval_freq=10 \
--batch_size=80 --save_freq=1000 --subset_train_to=40960 \
--preconditioner=None --momentum=0.95 --save_all_ensembled
# SGHMC-CLR-Prec
python3 run_sgmcmc.py --seed=1 --weight_decay=5 --dir=runs/sghmc/cifar10/ \
--dataset_name=cifar10 --model_name=resnet20_frn_swish --init_step_size=3e-5 \
--num_epochs=10000 --num_burnin_epochs=1000 --step_size_schedule=cyclical \
--step_size_cycle_length_epochs=50 --ensemble_freq=50 --eval_freq=10 \
--batch_size=80 --save_freq=50 --subset_train_to=40960 \
--preconditioner=RMSprop --momentum=0.95 --save_all_ensembled
To run mean field variational inference (MFVI), you can use the run_vi.py training script. It shares command line arguments with SGD, but also introduces the following arguments:
- optimizer: choice of optimizer (SGD or Adam; default: SGD)
- vi_sigma_init: initial value of the standard deviation over the weights in MFVI (default: 1e-3)
- vi_ensemble_size: size of the ensemble sampled in the VI evaluation (default: 20)
- mean_init_checkpoint: SGD checkpoint to use for initialization of the mean of the MFVI approximation
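The roles of vi_sigma_init, vi_ensemble_size and mean_init_checkpoint can be illustrated with a small Python sketch of a mean-field Gaussian approximation. It assumes one independent Gaussian per weight, with means taken from an SGD solution and a shared initial standard deviation; the function names are hypothetical, and core/vi.py may parameterize the standard deviation differently.

```python
import random


def init_mfvi_params(mean_init, vi_sigma_init=1e-3):
    """Start the means at the given weights and all standard deviations at vi_sigma_init."""
    return list(mean_init), [vi_sigma_init] * len(mean_init)


def sample_weights(means, sigmas, rng):
    """Draw one weight vector via the reparameterization w = mean + sigma * eps."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(means, sigmas)]


def sample_ensemble(means, sigmas, vi_ensemble_size=20, seed=0):
    """Draw vi_ensemble_size weight vectors to be averaged at evaluation time."""
    rng = random.Random(seed)
    return [sample_weights(means, sigmas, rng) for _ in range(vi_ensemble_size)]


# Hypothetical SGD solution with three weights.
means, sigmas = init_mfvi_params([0.3, -1.2, 0.7], vi_sigma_init=0.01)
ensemble = sample_ensemble(means, sigmas, vi_ensemble_size=20)
```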
ResNet-20-FRN on CIFAR-10 or CIFAR-100:
python3 run_vi.py --seed=11 --weight_decay=5. --dir=runs/vi/cifar100/ \
--dataset_name=[cifar10 | cifar100] --model_name=resnet20_frn_swish \
--init_step_size=1e-4 --num_epochs=300 --eval_freq=10 --batch_size=80 \
--save_freq=300 --subset_train_to=40960 --optimizer=Adam \
--vi_sigma_init=0.01 --temperature=1. --vi_ensemble_size=20 \
--mean_init_checkpoint=<path-to-sgd-solution>
CNN-LSTM on IMDB:
python3 run_vi.py --seed=11 --weight_decay=5. --dir=runs/vi/imdb/ \
--dataset_name=imdb --model_name=cnn_lstm --init_step_size=1e-4 \
--num_epochs=500 --eval_freq=10 --batch_size=80 --save_freq=200 \
--optimizer=Adam --vi_sigma_init=0.01 --temperature=1. --vi_ensemble_size=20 \
--mean_init_checkpoint=<path-to-sgd-solution>
You can produce posterior density visualizations similar to the ones presented in the paper using the make_posterior_surface_plot.py script. Arguments:
- limit_bottom: limit of the loss surface visualization along the vertical direction at the bottom (default: -0.25)
- limit_top: limit of the loss surface visualization along the vertical direction at the top (default: -0.25)
- limit_left: limit of the loss surface visualization along the horizontal direction on the left (default: 1.25)
- limit_right: limit of the loss surface visualization along the horizontal direction on the right (default: 1.25)
- grid_size: number of grid points in each direction (default: 20)
- checkpoint1: path to the first checkpoint
- checkpoint2: path to the second checkpoint
- checkpoint3: path to the third checkpoint
The script visualizes the posterior log-density, log-likelihood and log-prior in the plane containing the three provided checkpoints.
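One common way to parameterize the plane through three checkpoints is to build an orthonormal basis with a Gram-Schmidt step and express grid coordinates in units of the distances between checkpoints. The Python sketch below follows that construction on flat weight vectors; it is an assumption about how such plots are typically made, not a description of make_posterior_surface_plot.py, and all names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def plane_basis(w1, w2, w3):
    """Orthonormal basis (u, v) of the plane through w1, w2, w3, plus axis scales."""
    u = [b - a for a, b in zip(w1, w2)]
    dx = math.sqrt(dot(u, u))
    u = [x / dx for x in u]
    v = [c - a for a, c in zip(w1, w3)]
    # Gram-Schmidt: remove the component of v along u.
    proj = dot(u, v)
    v = [y - proj * x for x, y in zip(u, v)]
    dy = math.sqrt(dot(v, v))
    v = [y / dy for y in v]
    return u, v, dx, dy


def plane_point(w1, u, v, dx, dy, a, b):
    """Weights at grid coordinates (a, b); (0, 0) is w1 and (1, 0) is w2."""
    return [w + a * dx * x + b * dy * y for w, x, y in zip(w1, u, v)]


def grid_coordinates(limit_low, limit_high, grid_size):
    """Evenly spaced coordinates between two limits, endpoints included."""
    step = (limit_high - limit_low) / (grid_size - 1)
    return [limit_low + i * step for i in range(grid_size)]


# Toy three-dimensional "checkpoints".
w1, w2, w3 = [0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 3.0, 0.0]
u, v, dx, dy = plane_basis(w1, w2, w3)
horizontal = grid_coordinates(-0.25, 1.25, 20)
```

The log-density, log-likelihood and log-prior would then be evaluated at plane_point(...) for every pair of grid coordinates.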
CNN-LSTM on IMDB:
python3 make_posterior_surface_plot.py --weight_decay=40 --temperature=1. \
--dir=runs/surface_plots/imdb/ --model_name=cnn_lstm --dataset_name=imdb \
--checkpoint1=<ckpt1> --checkpoint2=<ckpt2> --checkpoint3=<ckpt3> \
--limit_bottom=-0.75 --limit_left=-0.75 --limit_right=1.75 --limit_top=1.75 \
--grid_size=50