Skip to content

Commit

Permalink
Add batch-mode KG (qKG)
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 27, 2023
1 parent 15b19b6 commit 065a8e9
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 9 deletions.
4 changes: 2 additions & 2 deletions gpax/acquisition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .acquisition import UCB, EI, POI, UE, Thompson, KG
from .batch_acquisition import qEI, qPOI, qUCB
from .batch_acquisition import qEI, qPOI, qUCB, qKG

__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB"]
__all__ = ["UCB", "EI", "POI", "UE", "KG", "Thompson", "qEI", "qPOI", "qUCB", "qKG"]
53 changes: 48 additions & 5 deletions gpax/acquisition/batch_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..models.gp import ExactGP
from ..utils import random_sample_dict
from .base_acq import ei, ucb, poi
from .base_acq import ei, ucb, poi, kg


def compute_batch_acquisition(acquisition_type: Callable,
Expand All @@ -31,7 +31,7 @@ def compute_batch_acquisition(acquisition_type: Callable,
"""
if model.mcmc is None:
raise ValueError("The model needs to be fully Bayesian")

X = X[:, None] if X.ndim < 2 else X

samples = random_sample_dict(model.get_samples(), subsample_size)
Expand Down Expand Up @@ -67,7 +67,7 @@ def qEI(model: Type[ExactGP],
**kwargs) -> jnp.ndarray:
"""
Batch-mode Expected Improvement
Args:
model: trained model
X: new inputs
Expand Down Expand Up @@ -109,7 +109,7 @@ def qUCB(model: Type[ExactGP],
**kwargs) -> jnp.ndarray:
"""
Batch-mode Upper Confidence Bound
Args:
model: trained model
X: new inputs
Expand Down Expand Up @@ -178,4 +178,47 @@ def qPOI(model: Type[ExactGP],
poi, model, X, xi, maximize, noiseless,
maximize_distance=maximize_distance,
n_evals=n_evals, subsample_size=subsample_size,
indices=indices, **kwargs)
indices=indices, **kwargs)


def qKG(model: Type[ExactGP],
X: jnp.ndarray,
n: int = 10,
maximize: bool = False,
noiseless: bool = False,
maximize_distance: bool = False,
n_evals: int = 1,
subsample_size: int = 1,
indices: Optional[jnp.ndarray] = None,
**kwargs) -> jnp.ndarray:
"""
Batch-mode Knowledge Gradient
Args:
model: trained model
X: new inputs
n: number of simulated samples for each point in X
maximize: If True, assumes that BO is solving maximization problem
noiseless:
Noise-free prediction. It is set to False by default as new/unseen data is assumed
to follow the same distribution as the training data. Hence, since we introduce a model noise
for the training data, we also want to include that noise in our prediction.
maximize_distance:
Selects a subsample with a maximum distance between acq.argmax() points
n_evals:
Number of evaluations (how many times a ramdom subsample is drawn)
when maximizing distance between maxima of different EIs in a batch.
subsample_size:
Size of the subsample from the GP model's MCMC samples.
indices:
Indices of the input points.
Returns:
The computed batch Knowledge Gradient values at the provided input points X.
"""

return compute_batch_acquisition(
kg, model, X, n, maximize, noiseless,
maximize_distance=maximize_distance,
n_evals=n_evals, subsample_size=subsample_size,
indices=indices, **kwargs)
4 changes: 2 additions & 2 deletions tests/test_acq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from gpax.utils import get_keys
from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg
from gpax.acquisition import EI, UCB, UE, Thompson, KG
from gpax.acquisition import qEI, qPOI, qUCB
from gpax.acquisition import qEI, qPOI, qUCB, qKG
from gpax.acquisition.penalties import compute_penalty


Expand Down Expand Up @@ -124,7 +124,7 @@ def test_acq_dkl(acq):


@pytest.mark.parametrize("q", [1, 3])
@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB])
@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB, qKG])
def test_batched_acq(acq, q):
rng_key = get_keys()[0]
X = onp.random.randn(8,)
Expand Down

0 comments on commit 065a8e9

Please sign in to comment.