Refitting (arviz-devs#1373)
* update sampling wrapper base to better evaluate likelihoods

* add 2 pymc3 refitting examples

* fix typo

* add pystan+xarray refitting notebook

* Add numpyro refitting examples

* black

* add not working pyro example

* reorganize refitting notebooks inside user_guide

* add refitting examples to docs

* add to changelog
OriolAbril authored Jan 13, 2021
1 parent 3878e17 commit fff2b2f
Showing 10 changed files with 16,828 additions and 53 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
@@ -25,13 +25,13 @@
* Fix `pair_plot` for mixed discrete and continuous variables ([1434](https://github.com/arviz-devs/arviz/pull/1434))
* Fix in-sample deviance in `plot_compare` ([1435](https://github.com/arviz-devs/arviz/pull/1435))
* Fix computation of weights in compare ([1438](https://github.com/arviz-devs/arviz/pull/1438))
* Avoid repeated warning in summary ([1442](https://github.com/arviz-devs/arviz/pull/1442))
* Fix hdi failure with boolean array ([1444](https://github.com/arviz-devs/arviz/pull/1444))
* Automatically get the current axes instance for `plot_kde`, `plot_dist` and `plot_hdi` ([1452](https://github.com/arviz-devs/arviz/pull/1452))
* Add grid argument to manually specify the number of rows and columns ([1459](https://github.com/arviz-devs/arviz/pull/1459))
* Switch to `compact=True` by default in our plots ([1468](https://github.com/arviz-devs/arviz/issues/1468))
* `plot_elpd`, avoid modifying the input dict ([1477](https://github.com/arviz-devs/arviz/issues/1477))
* Do not plot divergences in `plot_trace` when `kind=rank_vlines` or `kind=rank_bars` ([1476](https://github.com/arviz-devs/arviz/issues/1476))


### Deprecation
@@ -42,6 +42,10 @@
* Switch to [MyST](https://myst-parser.readthedocs.io/en/latest/) and [MyST-NB](https://myst-nb.readthedocs.io/en/latest/index.html)
for markdown/notebook parsing in docs ([1406](https://github.com/arviz-devs/arviz/pull/1406))
* Incorporated `input_core_dims` in `hdi` and `plot_hdi` docstrings ([1410](https://github.com/arviz-devs/arviz/pull/1410))
* Add documentation pages about experimental `SamplingWrapper`s usage ([1373](https://github.com/arviz-devs/arviz/pull/1373))

### Experimental
* Modified `SamplingWrapper` base API ([1373](https://github.com/arviz-devs/arviz/pull/1373))

## v0.10.0 (2020 Sep 24)
### New features
113 changes: 66 additions & 47 deletions arviz/wrappers/base.py
@@ -1,5 +1,5 @@
"""Base class for sampling wrappers."""
import numpy as np
from xarray import apply_ufunc

# from ..data import InferenceData
from ..stats import wrap_xarray_ufunc as _wrap_xarray_ufunc
@@ -18,40 +18,75 @@ class SamplingWrapper:
----------
model
The model object used for sampling.
idata_orig: InferenceData, optional
idata_orig : InferenceData, optional
Original InferenceData object.
log_like_fun: callable, optional
log_lik_fun : callable, optional
For simple cases where the pointwise log likelihood is a Python function, this
function will be used to calculate the log likelihood. Otherwise,
``point_log_likelihood`` method must be implemented.
sample_kwargs: dict, optional
the ``log_likelihood__i`` method must be overridden. Its call signature must be
``log_lik_fun(*args, **log_lik_kwargs)``, and it will be called using
:func:`wrap_xarray_ufunc` or :func:`xarray:xarray.apply_ufunc` depending
on the value of `is_ufunc`.
For more details on ``args`` and ``log_lik_kwargs``, see the notes and the
``posterior_vars`` and ``log_lik_kwargs`` parameters.
is_ufunc : bool, default True
If True, call ``log_lik_fun`` using :func:`xarray:xarray.apply_ufunc`; otherwise
use :func:`wrap_xarray_ufunc`.
posterior_vars : list of str, optional
List of variable names to unpack as ``args`` for ``log_lik_fun``. Each string in
the list is used to retrieve a DataArray from the Dataset in the posterior group,
which is then passed to ``log_lik_fun``.
sample_kwargs : dict, optional
Sampling kwargs are stored as class attributes for their usage in the ``sample``
method.
idata_kwargs: dict, optional
idata_kwargs : dict, optional
kwargs are stored as class attributes to be used in the ``get_inference_data`` method.
log_lik_kwargs : dict, optional
Keyword arguments passed to ``log_lik_fun``.
apply_ufunc_kwargs : dict, optional
Passed to :func:`xarray:xarray.apply_ufunc` or :func:`wrap_xarray_ufunc`.

Warnings
--------
Sampling wrappers are an experimental feature in a very early stage. Please use them
with caution.

Notes
-----
Example of ``log_lik_fun`` usage.
"""

def __init__(
self, model, idata_orig=None, log_like_fun=None, sample_kwargs=None, idata_kwargs=None
self,
model,
idata_orig=None,
log_lik_fun=None,
is_ufunc=True,
posterior_vars=None,
sample_kwargs=None,
idata_kwargs=None,
log_lik_kwargs=None,
apply_ufunc_kwargs=None,
):
self.model = model

# if not isinstance(idata_orig, InferenceData) or idata_orig is not None:
# raise TypeError("idata_orig must be of InferenceData type or None")
self.idata_orig = idata_orig

if log_like_fun is None or callable(log_like_fun):
self.log_like_fun = log_like_fun
if log_lik_fun is None or callable(log_lik_fun):
self.log_lik_fun = log_lik_fun
self.is_ufunc = is_ufunc
self.posterior_vars = posterior_vars
else:
raise TypeError("log_like_fun must be a callable object or None")

self.sample_kwargs = {} if sample_kwargs is None else sample_kwargs
self.idata_kwargs = {} if idata_kwargs is None else idata_kwargs
self.log_lik_kwargs = {} if log_lik_kwargs is None else log_lik_kwargs
self.apply_ufunc_kwargs = {} if apply_ufunc_kwargs is None else apply_ufunc_kwargs

def sel_observations(self, idx):
"""Select a subset of the observations in idata_orig.
@@ -109,29 +144,6 @@ def get_inference_data(self, fitted_model):
"""
raise NotImplementedError("get_inference_data method must be implemented for each subclass")

def point_log_likelihood(self, observation, parameters):
"""Pointwise log likelihood function.
Parameters
----------
observation
Pointwise observation on which to calculate the log likelihood
parameters
Parameters on which the log likelihood is conditioned.
Returns
-------
point_log_likelihood: float
Value of the log likelihood of ``observation`` given ``parameters``
according to ``self.model``
"""
if self.log_like_fun is None:
raise NotImplementedError(
"If log_like_fun is None, point_log_likelihood method must "
"be implemented for each subclass"
)
return self.log_like_fun(observation, parameters)

def log_likelihood__i(self, excluded_obs, idata__i):
r"""Get the log likelilhood samples :math:`\log p_{post(-i)}(y_i)`.
@@ -141,26 +153,36 @@ def log_likelihood__i(self, excluded_obs, idata__i):
Parameters
----------
excluded_obs
Observations for which to calculate their log likelihood
Observations whose log likelihood is to be calculated. The second item of
the tuple returned by `sel_observations` is passed as this argument.
idata__i: InferenceData
Inference results of refitting the data excluding some observations.
Inference results of refitting the data excluding some observations. The
result of `get_inference_data` is used as this argument.
Returns
-------
log_likelihood: xr.DataArray
Log likelihood of ``excluded_obs`` evaluated at each of the posterior samples
stored in ``idata__i``.
"""
ndraws = idata__i.posterior.dims["draw"]
nchains = idata__i.posterior.dims["chain"]
log_like_idx = _wrap_xarray_ufunc(
lambda pars: self.point_log_likelihood(excluded_obs, pars),
idata__i.posterior.to_array(),
func_kwargs={"out": np.empty((nchains, ndraws))},
ufunc_kwargs={"n_dims": 1, "ravel": False},
input_core_dims=[["variable"]],
if self.log_lik_fun is None:
raise NotImplementedError(
"When `log_like_fun` is not set during class initialization "
"log_likelihood__i method must be overwritten"
)
posterior = idata__i.posterior
arys = (*excluded_obs, *[posterior[var_name] for var_name in self.posterior_vars])
if self.is_ufunc:
ufunc_applier = apply_ufunc
else:
ufunc_applier = _wrap_xarray_ufunc
log_lik_idx = ufunc_applier(
self.log_lik_fun,
*arys,
kwargs=self.log_lik_kwargs,
**self.apply_ufunc_kwargs,
)
return log_like_idx
return log_lik_idx
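
To trace the new data flow, the standalone sketch below imitates, with made-up names and sizes, the call that ``log_likelihood__i`` performs when ``is_ufunc=True``: the excluded observations come first, followed by one posterior DataArray per entry in ``posterior_vars``, and the result is broadcast over the ``chain`` and ``draw`` dimensions. Unlike the old implementation, which stacked the posterior with ``to_array`` and a ``variable`` core dimension, each posterior variable is now passed as its own named DataArray.

```python
import numpy as np
import xarray as xr
from scipy import stats

rng = np.random.default_rng(0)
# hypothetical posterior for a scalar mean "mu": 4 chains x 500 draws
mu = xr.DataArray(rng.normal(size=(4, 500)), dims=("chain", "draw"))
# one held-out observation, as sel_observations would (conceptually) return it
y_i = xr.DataArray(1.3)

# excluded observation first, then the DataArrays listed in posterior_vars
log_lik_i = xr.apply_ufunc(stats.norm.logpdf, y_i, mu, kwargs={"scale": 1.0})
print(log_lik_i.dims)  # ('chain', 'draw'): one log likelihood value per posterior draw
```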

def _check_method_is_implemented(self, method, *args):
"""Check a given method is implemented."""
@@ -193,10 +215,7 @@ def check_implemented_methods(self, methods):
"sample",
"get_inference_data",
)
supported_methods_2args = (
"point_log_likelihood",
"log_likelihood__i",
)
supported_methods_2args = ("log_likelihood__i",)
supported_methods = [*supported_methods_1arg, *supported_methods_2args]
bad_methods = [method for method in methods if method not in supported_methods]
if bad_methods:
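
As a short usage sketch for the trimmed ``check_implemented_methods`` helper: instantiating the bare base class is done here only for illustration; in practice a fully implemented subclass would be checked before handing it to a refitting utility such as ``az.reloo``, if available in the installed ArviZ version.

```python
from arviz import SamplingWrapper

wrapper = SamplingWrapper(model=None)  # stand-in; a real subclass would be used
required = ["sel_observations", "sample", "get_inference_data", "log_likelihood__i"]
missing = wrapper.check_implemented_methods(required)  # names not yet implemented
print(missing)  # the bare base class reports all four as not implemented
# once nothing is missing, e.g. az.reloo(wrapper, loo_orig=...) can drive the refits
```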
1 change: 1 addition & 0 deletions doc/source/user_guide/index.rst
@@ -4,6 +4,7 @@ User Guide
==========

.. toctree::
:maxdepth: 2

data_structures
computation
2,772 changes: 2,772 additions & 0 deletions doc/source/user_guide/numpyro_refitting.ipynb

Large diffs are not rendered by default.

4,789 changes: 4,789 additions & 0 deletions doc/source/user_guide/numpyro_refitting_xr_lik.ipynb

Large diffs are not rendered by default.

592 changes: 592 additions & 0 deletions doc/source/user_guide/pymc3_refitting.ipynb

Large diffs are not rendered by default.

5,164 changes: 5,164 additions & 0 deletions doc/source/user_guide/pymc3_refitting_xr_lik.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion doc/source/user_guide/pystan_refitting.ipynb
@@ -452,7 +452,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
"version": "3.8.5"
}
},
"nbformat": 4,
3,423 changes: 3,423 additions & 0 deletions doc/source/user_guide/pystan_refitting_xr_lik.ipynb

Large diffs are not rendered by default.

17 changes: 14 additions & 3 deletions doc/source/user_guide/sampling_wrappers.md
@@ -1,8 +1,19 @@
# Sampling wrappers
Sampling wrappers allow ArviZ to call PPLs, using only a limited subset of their
capabilities, in order to compute stats and diagnostics that require refitting
the model on different data.

Their implementation is still experimental and may change in the future. There are
currently two possible approaches to creating sampling wrappers: the first delegates
all calculations to the PPL, whereas the second externalizes the computation of the
pointwise log likelihood to the user, who is expected to write it with xarray/NumPy.
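
A rough sketch of that contrast is given below; every name in it is hypothetical and
the bodies are deliberately left as stubs rather than ArviZ-provided implementations.

```python
from scipy import stats

from arviz import SamplingWrapper

# Approach 1: delegate the computation to the PPL by overriding log_likelihood__i
class PPLDelegatingWrapper(SamplingWrapper):
    def log_likelihood__i(self, excluded_obs, idata__i):
        # ask the PPL to evaluate log p(y_i) at every posterior sample in idata__i
        raise NotImplementedError("filled in with PPL-specific code")

# Approach 2: externalize the pointwise log likelihood as a NumPy/xarray function,
# passed at construction time together with posterior_vars
def pointwise_log_lik(y, mu, sigma):
    return stats.norm.logpdf(y, loc=mu, scale=sigma)

# SamplingWrapper(model, log_lik_fun=pointwise_log_lik, posterior_vars=["mu", "sigma"])
```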

```{toctree}
pystan_refitting
pymc3_refitting
numpyro_refitting
pystan_refitting_xr_lik
pymc3_refitting_xr_lik
numpyro_refitting_xr_lik
```

Examples of sampling wrappers for other libraries will be added to this section soon.
