forked from blackjax-devs/blackjax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
48 lines (45 loc) · 942 Bytes
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .kernels import (
adaptive_tempered_smc,
elliptical_slice,
ghmc,
hmc,
irmh,
mala,
meads,
mgrad_gaussian,
nuts,
orbital_hmc,
pathfinder_adaptation,
rmh,
sghmc,
sgld,
tempered_smc,
window_adaptation,
)
from .optimizers import dual_averaging, lbfgs
# Public API of the package: the names exported by a star-import of this
# module. Every entry is bound by one of the imports above; `ess` and `rhat`
# are the aliases given there to `effective_sample_size` and
# `potential_scale_reduction`. Each trailing comment marks the first entry of
# a group of related names.
__all__ = [
"dual_averaging", # optimizers
"lbfgs",
"hmc", # mcmc
"mala",
"mgrad_gaussian",
"nuts",
"orbital_hmc",
"rmh",
"irmh",
"elliptical_slice",
"ghmc",
"meads",
"sgld", # stochastic gradient mcmc
"sghmc",
"window_adaptation", # mcmc adaptation
"pathfinder_adaptation",
"adaptive_tempered_smc", # smc
"tempered_smc",
"ess", # diagnostics
"rhat",
]
# Sibling `_version` module, imported here rather than with the imports at the
# top of the file.
# NOTE(review): `_version` is not visible from this file; it looks like a
# generated version module (e.g. versioneer) -- confirm.
from . import _version
# Package version: the value stored under the "version" key of whatever
# `_version.get_versions()` returns.
__version__ = _version.get_versions()["version"]