Bayesian Probabilistic Numerical Integration with Tree-Based Models (using Bayesian Additive Regression Trees)

ImperialCollegeLondon/BART-Int

Bayesian Probabilistic Numerical Integration with Tree-Based Models

Bayesian Probabilistic Numerical Integration with Tree-Based models (to appear in NeurIPS 2020)

Authors: Harrison Zhu, François-Xavier Briol, Xing Liu, Ruya Kang, Zhichao Shen, and Seth Flaxman

Code directory

The causal-inference branch contains code relating to "Some Links Between Causal Inference and Bayesian Probabilistic Numerical Integration".

.
├── README.md
├── data: Scripts used to scrape the survey design data, and where the data is stored.
│   ├── extract.R: Extracts the PUMS survey dataset
│   └── sample_full.R: Creates two CSVs, train2.csv and candidate2.csv, that contain the features needed for the design and candidate points
├── Figures: Folder to hold results plots.
│   ├── genz
│   │   ├── 1: Genz function results
│   │   ├── 2: Genz function results
│   │   ├── 3: Genz function results
│   │   ├── 4: Genz function results
│   │   ├── 5: Genz function results
│   │   ├── 6: Genz function results
│   │   ├── 7: Genz function results
│   │   └── 9: Genz function results
│   └── survey_design
├── results: Folder to hold results plots. This is where we store the results generated by our integral approximation functions, as well as the analytical integrals of the benchmark test functions.
│   ├── genz
│   │   ├── 1: Genz function results
│   │   ├── 2: Genz function results
│   │   ├── 3: Genz function results
│   │   ├── 4: Genz function results
│   │   ├── 5: Genz function results
│   │   ├── 6: Genz function results
│   │   ├── 7: Genz function results
│   │   └── 9: Genz function results
│   └── survey_design
├── figures_code
│   ├── draw_step.R: Draws Figure 1 in the paper
│   ├── plot_binary_response.R: Draws Figure 3 in the paper
│   ├── plot_computational_complexity.R: Draws Figure 2 in the paper
│   ├── plot_high_dimensionality.r: Computes the results in Table 2 in the paper
│   ├── plot_posterior_example.R: Draws Figure 4 in the paper/appendix
│   ├── step_design.R: Draws Figure 5 in the paper
│   └── compute_CV.R: Computes the results in Table 1 in the paper
├── python: Python code for hyperparameter tuning for the GP
│   └── gp_tune.py: Class for GP regression with marginal likelihood maximisation
├── src
│   ├── genz: Genz functions and their integrals
│   │   ├── analyticalIntegrals.R
│   │   └── genz.R
│   ├── BARTInt.R: Implementation of BART-Int
│   ├── GPBQ.R: Implementation of Bayesian Quadrature with Gaussian processes (GP-BQ)
│   ├── monteCarloIntegration.R: Main class of Monte Carlo integration
│   ├── optimise_gp.R: Source file used to optimise the lengthscale using PyTorch with reticulate
│   └── meanPopulationStudy: Source files used for Bayesian survey design
│       ├── bartMean.R
│       └── gpMean.R
├── integrationMain.R: Main script to run BART-Int, GP-BQ and Monte Carlo integration. Tweak your Genz functions and parameters here
├── poptMean_trained_bin.R: Computes the ground truth proportions for the survey design problem
├── saveComputeIntegrals.R: Computes the exact integrals for the Genz functions
└── bart_compute_groundtruth.R: Computes the ground truth for the survey design

Dependencies

The experiments were tested under Ubuntu 18.04 and OS X. Docker images will be published in due course to ensure wider reproducibility.

R dependencies.

    MASS
    cubature
    lhs
    data.tree
    dbarts (version 0.9-8)
    matrixStats
    mvtnorm
    doParallel
    kernlab
    msm
    MCMCglmm
    caret
    reticulate
    rdist

Python dependencies

torch 
gpytorch

Numerical Experiments: Genz Functions

  1. Install all the necessary packages
install.packages(c("MASS", "cubature", "lhs", "data.tree", "matrixStats", "mvtnorm", "doParallel", "kernlab", "msm", "MCMCglmm", "caret", "reticulate", "rdist"))

# an old version of dbarts
packageurl <- "https://cran.r-project.org/src/contrib/Archive/dbarts/dbarts_0.9-8.tar.gz"
install.packages(packageurl, repos=NULL, type="source")

Note that the older version of dbarts is needed because later releases made significant changes to the class files for the data structures.

The Python dependencies are the following:

gpytorch
torch

The Python packages are installed in src/optimise_gp.R, which creates a virtualenv with the function install_python_env() using reticulate.

  2. Save the computed integrals
Rscript saveComputeIntegrals.R

This will store the ground truth in results/genz/integrals.csv

  3. To reproduce the benchmark tests, run integrationMain.R with customised inputs. There are 8 arguments in total, of which the last three are optional. The penultimate argument should only be specified when the step function is used (genz_function_number = 7), and is set to 1 if not specified. The usage is:
Rscript integrationMain.R dimension num_iterations genz_function_number kernel_name sequential_flag (measure) (number_of_jumps_for_step_function) (save_posterior)

where genz_function_number follows the indexing in this documentation for the Genz families. The results will be stored in results, where you can find the .csv and .RData files containing the numerical values and the automatically generated graphs. For example, one could run this

Rscript integrationMain.R 1 2 1 matern32 1 uniform 1 1

and get results for the continuous Genz function (genz_function_number = 1) in 1 dimension, using 2 iterations of sequential design with the matern32 kernel. The fifth argument, 1, switches sequential design on; uniform selects the uniform measure; the next 1 is a placeholder for the number of jumps of the step function; and the final 1 is the flag controlling whether the BART posterior samples are saved at each iteration.
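Because all eight arguments are positional, it is easy to swap two of them by accident. The following is a minimal sketch of a wrapper that assembles the command in the documented order and only prints it. The helper name genz_cmd is hypothetical and not part of the repository, and requiring the measure explicitly is a choice made for this sketch rather than a requirement of integrationMain.R.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): builds the
# integrationMain.R command line in the documented positional order.
genz_cmd() {
  if [ "$#" -lt 6 ]; then
    echo "usage: genz_cmd dim iters genz kernel sequential measure [jumps] [save_posterior]" >&2
    return 1
  fi
  local dim="$1" iters="$2" genz="$3" kernel="$4" sequential="$5" measure="$6"
  local jumps="${7:-1}"          # the README states this defaults to 1
  local save_posterior="${8:-}"  # passed through only when given

  # Indices listed in this README: 1 to 7, and 9.
  case "$genz" in
    [1-7]|9) ;;
    *) echo "unknown genz_function_number: $genz" >&2; return 1 ;;
  esac

  # The number of jumps is only meaningful for the step function (index 7).
  if [ "$genz" != "7" ] && [ "$jumps" != "1" ]; then
    echo "number_of_jumps_for_step_function only applies to genz 7" >&2
    return 1
  fi

  local cmd="Rscript integrationMain.R $dim $iters $genz $kernel $sequential $measure $jumps"
  if [ -n "$save_posterior" ]; then
    cmd="$cmd $save_posterior"
  fi
  echo "$cmd"
}

# Prints: Rscript integrationMain.R 1 2 1 matern32 1 uniform 1 1
genz_cmd 1 2 1 matern32 1 uniform 1 1
```

To execute rather than print, the output can be passed to the shell, for example with `$(genz_cmd 1 2 1 matern32 1 uniform 1 1)`.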

As another example, with a Gaussian measure and the step function defined in the appendix of the paper, one could run

Rscript integrationMain.R 1 2 7 matern32 1 gaussian 1 1

For more information about each input, check the first few lines of integrationMain.R.

Results will also be stored in results/genz and Figures/genz.

We ran the following to generate the results for Table 1

for dim in 1 10
do
    for genz in 1 2 3 4 5 6 7
    do
        Rscript integrationMain.R $dim 20 $genz matern32 1 uniform 1 1
    done
done

You can also rewrite integrationMain.R to parallelise over the seeds.
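Separately from parallelising over seeds inside integrationMain.R, the (dimension, Genz function) configurations of the Table 1 loop are independent and can be launched concurrently from the shell. The sketch below does this with xargs. DRY_RUN and JOBS are hypothetical variables introduced for this example, the argument order follows the usage line documented above, and it assumes that different configurations do not write to the same output files, which is worth checking before running it for real.

```shell
#!/usr/bin/env bash
# Sketch: run independent (dimension, Genz function) configurations concurrently.
# DRY_RUN and JOBS are hypothetical knobs for this example only.
DRY_RUN="${DRY_RUN:-1}"   # 1 = only print the commands, 0 = execute them
JOBS="${JOBS:-4}"         # number of concurrent Rscript processes

# Emit one "dim genz" pair per line, matching the Table 1 loop.
list_configs() {
  local dim genz
  for dim in 1 10; do
    for genz in 1 2 3 4 5 6 7; do
      echo "$dim $genz"
    done
  done
}

# Print or run a single configuration.
run_config() {
  local dim="$1" genz="$2"
  local cmd="Rscript integrationMain.R $dim 20 $genz matern32 1 uniform 1 1"
  if [ "$DRY_RUN" = "1" ]; then
    echo "$cmd"
  else
    $cmd
  fi
}
export -f run_config
export DRY_RUN

# -n 2 passes one "dim genz" pair per invocation; -P sets the concurrency.
list_configs | xargs -P "$JOBS" -n 2 bash -c 'run_config "$1" "$2"' _
```

With DRY_RUN=1 (the default here) the script only prints the 14 commands, in no guaranteed order; set DRY_RUN=0 to execute them.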

For Table 2:

for dim in 1 10 20 100
do
    Rscript test_additive.R $dim 1 9 matern32 uniform 0
done

As for the graphs, we provide the scripts in figures_code.

Navigating the results

Check the Genz Functions and see the preprint for more information. The Genz function indices are:

  • 1: Continuous
  • 2: Corner Peak
  • 3: Discontinuous
  • 4: Gaussian Peak
  • 5: Oscillatory
  • 6: Product Peak
  • 7: Step function
  • 9: Additive Gaussian
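For orientation, the first six families in the list above are commonly written as follows in the Genz test-function literature. This is a reference sketch rather than a statement of what src/genz/genz.R implements: the symbols a (shape parameters), u (location parameters) and d (dimension) are notation assumed here, and the parameter values used in the paper may differ. The step function (7) and the additive Gaussian (9) are specific to the paper and are not reproduced.

```latex
\begin{align*}
\text{Continuous:}     \quad & f(x) = \exp\Big(-\sum_{i=1}^{d} a_i \,\lvert x_i - u_i \rvert\Big) \\
\text{Corner peak:}    \quad & f(x) = \Big(1 + \sum_{i=1}^{d} a_i x_i\Big)^{-(d+1)} \\
\text{Discontinuous:}  \quad & f(x) =
  \begin{cases}
    0 & \text{if } x_1 > u_1 \text{ or } x_2 > u_2, \\
    \exp\big(\sum_{i=1}^{d} a_i x_i\big) & \text{otherwise}
  \end{cases} \\
\text{Gaussian peak:}  \quad & f(x) = \exp\Big(-\sum_{i=1}^{d} a_i^2 (x_i - u_i)^2\Big) \\
\text{Oscillatory:}    \quad & f(x) = \cos\Big(2\pi u_1 + \sum_{i=1}^{d} a_i x_i\Big) \\
\text{Product peak:}   \quad & f(x) = \prod_{i=1}^{d} \big(a_i^{-2} + (x_i - u_i)^2\big)^{-1}
\end{align*}
```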

Numerical Experiments: Bayesian Survey Design

  1. Install the dependencies in R. Make sure you are using R 3.5.0 or higher.

  2. Download and process the dataset

Rscript data/extract.R
cd data;unzip 2016csv_pil.zip; cd ..; Rscript data/sample_full.R

This will create train2.csv and candidate2.csv, which store the possible initial design points and the candidate points.

  3. [Optional] To compute the BART ground truth, run

Rscript bart_compute_groundtruth.R num_cv_start num_cv_end num_data num_design 1

where num_cv_start and num_cv_end define a loop over the random seeds num_cv_start, num_cv_start+1, ..., num_cv_end; num_data is the number of candidate points; and num_design is the number of initial design points. The ground truth can also be computed by any other method the user prefers. The ground truths will be stored in results/survey_design. Alternatively, you can simply take the mean of the entire dataset of 454,816 points, which yields very similar results.
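The alternative mentioned above, taking the mean of the response over the whole dataset, can be sketched with awk. The helper name column_mean, the temporary file, and the column name "response" are hypothetical placeholders; substitute the actual file and response column of the processed PUMS data. The sketch assumes a simple CSV with a header row and no quoted fields containing commas.

```shell
#!/usr/bin/env bash
# Sketch: mean of one numeric column of a CSV file with a header row.
# column_mean and the column name "response" are hypothetical placeholders.
column_mean() {
  local file="$1" column="$2"
  awk -F',' -v col="$column" '
    NR == 1 {
      # Locate the requested column in the header.
      for (i = 1; i <= NF; i++) if ($i == col) idx = i
      if (!idx) { print "column not found: " col > "/dev/stderr"; exit 1 }
      next
    }
    $idx != "" { sum += $idx; n++ }   # skip rows with a missing value
    END { if (n > 0) printf "%.4f\n", sum / n }
  ' "$file"
}

# Tiny stand-in dataset with a binary response.
tmp="$(mktemp)"
printf 'id,response\n1,1\n2,0\n3,1\n4,1\n' > "$tmp"
column_mean "$tmp" response   # prints 0.7500
rm -f "$tmp"
```

For a binary response, this mean is the population proportion.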

  4. To run the experiments, first open src/survey_design/gpMean.R and set the jitter/nugget term to a value you consider appropriate. We set it to the value returned by the maximum marginal likelihood estimator, obtained using line 84 in poptMean_trained_bin.R. Note that some small jitter is always needed for numerical stability during the kernel matrix inversion for GP-BQ. Then run
Rscript poptMean_trained_bin.R num_new_surveys num_cv_start num_cv_end num_data num_design

where num_new_surveys is the number of new surveys to query.

For example, we used num_new_surveys=200, num_data=10000, num_design=20.

This will generate and store the results in results/survey_design and Figures/survey_design, where you can find the .csv and .RData files containing the numerical values and the automatically generated graphs.
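The role of the jitter/nugget term in step 4 can be seen from the standard Bayesian quadrature formulas under a zero-mean GP prior. The notation below is assumed for illustration and is not taken from the code: K is the kernel matrix on the n design points x_1, ..., x_n, y the vector of observed responses, \pi the integration measure, and \lambda > 0 the jitter.

```latex
\begin{align*}
z_i &= \int k(x, x_i)\, \mathrm{d}\pi(x), \qquad i = 1, \dots, n, \\
\mathbb{E}\big[\Pi[f] \mid y\big] &= z^{\top} (K + \lambda I)^{-1} y, \\
\mathbb{V}\big[\Pi[f] \mid y\big] &= \iint k(x, x')\, \mathrm{d}\pi(x)\, \mathrm{d}\pi(x') - z^{\top} (K + \lambda I)^{-1} z.
\end{align*}
```

When design points lie close together, K becomes nearly singular; adding \lambda I raises every eigenvalue of K by \lambda, so the linear solve stays well conditioned. A larger \lambda gives more stability at the cost of treating the observations as noisier.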

Note that the experiments can also be easily run using other BART packages such as BART or bartMachine, provided that src/survey_design/bartMean.R is edited so that the call to dbarts::bart is replaced.

To get the results for Table 3:

for num_cv in $(seq 1 20)
do
    echo $num_cv
    Rscript src/meanPopulationStudy/poptMean_trained_bin.R 200 $num_cv $num_cv 10000 20
done

License

MIT License

Copyright (c) 2020 Imperial College London

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
