forked from songlei00/bbo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: using nsgaii for BO acqf optimization
- Loading branch information
songlei
committed
Oct 29, 2024
1 parent
404912c
commit da0d042
Showing
6 changed files
with
224 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from botorch.acquisition import ( | ||
qExpectedImprovement, | ||
qUpperConfidenceBound, | ||
qProbabilityOfImprovement, | ||
qLogExpectedImprovement | ||
) | ||
|
||
|
||
def acqf_factory(acqf_type, model, train_X, train_Y): | ||
if acqf_type == 'qEI': | ||
acqf = qExpectedImprovement(model, train_Y.max()) | ||
elif acqf_type == 'qUCB': | ||
acqf = qUpperConfidenceBound(model, beta=0.18) | ||
elif acqf_type == 'qPI': | ||
acqf = qProbabilityOfImprovement(model, train_Y.max()) | ||
elif acqf_type == 'qlogEI': | ||
acqf = qLogExpectedImprovement(model, train_Y.max()) | ||
else: | ||
raise NotImplementedError | ||
return acqf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import List, Callable | ||
|
||
from torch import Tensor | ||
|
||
from bbo.benchmarks.experimenters.base import BaseExperimenter | ||
from bbo.utils.problem_statement import ProblemStatement | ||
from bbo.utils.converters.torch_converter import TorchTrialConverter | ||
from bbo.utils.trial import Trial | ||
|
||
|
||
class TorchExperimenter(BaseExperimenter): | ||
def __init__( | ||
self, | ||
impl: Callable[[Tensor], Tensor], | ||
problem_statement: ProblemStatement, | ||
): | ||
self._dim = problem_statement.search_space.num_parameters() | ||
self._impl = impl | ||
self._problem_statement = problem_statement | ||
|
||
self._converter = TorchTrialConverter.from_problem( | ||
problem_statement, scale=False, onehot_embed=False, | ||
) | ||
|
||
def evaluate(self, suggestions: List[Trial]): | ||
features = self._converter.to_features(suggestions) | ||
m = self._impl(features) | ||
metrics = self._converter.to_metrics(m) | ||
for suggestion, m in zip(suggestions, metrics): | ||
suggestion.complete(m) | ||
return suggestions | ||
|
||
def problem_statement(self) -> ProblemStatement: | ||
return self._problem_statement |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from typing import Dict, Sequence, List, Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from bbo.utils.converters.base import ( | ||
BaseInputConverter, | ||
BaseOutputConverter, | ||
BaseTrialConverter, | ||
) | ||
from bbo.utils.converters.converter import ArrayTrialConverter, NumpyArraySpec | ||
from bbo.utils.problem_statement import ProblemStatement | ||
from bbo.utils.metric_config import MetricInformation | ||
from bbo.utils.trial import ParameterDict, MetricDict, Trial | ||
|
||
|
||
class TorchTrialConverter(BaseTrialConverter): | ||
def __init__( | ||
self, | ||
input_converters: Sequence[BaseInputConverter], | ||
output_converters: Sequence[BaseOutputConverter], | ||
): | ||
self._impl = ArrayTrialConverter(input_converters, output_converters) | ||
|
||
@classmethod | ||
def from_problem( | ||
cls, | ||
problem: ProblemStatement, | ||
*, | ||
scale: bool = True, | ||
onehot_embed: bool = False, | ||
): | ||
converter = cls([], []) | ||
converter._impl = ArrayTrialConverter.from_problem( | ||
problem, scale=scale, onehot_embed=onehot_embed | ||
) | ||
return converter | ||
|
||
def convert(self, trials: Sequence[Trial]) -> Tuple[Tensor, Tensor]: | ||
return self.to_features(trials), self.to_labels(trials) | ||
|
||
def to_features(self, trials: Sequence[Trial]) -> Tensor: | ||
return torch.from_numpy(self._impl.to_features(trials)) | ||
|
||
def to_labels(self, trials: Sequence[Trial]) -> Tensor: | ||
return torch.from_numpy(self._impl.to_labels(trials)) | ||
|
||
def to_trials(self, features: Tensor, labels: Tensor=None) -> Sequence[Trial]: | ||
features = features.detach().numpy() | ||
if labels is not None: | ||
labels = labels.detach().numpy() | ||
return self._impl.to_trials(features, labels) | ||
|
||
def to_parameters(self, features: Tensor) -> List[ParameterDict]: | ||
return self._impl.to_parameters(features.detach().numpy()) | ||
|
||
def to_metrics(self, labels: Tensor) -> List[MetricDict]: | ||
return self._impl.to_metrics(labels.detach().numpy()) | ||
|
||
@property | ||
def output_spec(self) -> Dict[str, NumpyArraySpec]: | ||
return {k: v.output_spec for k, v in self._impl.input_converter_dict.items()} | ||
|
||
@property | ||
def metric_spec(self) -> Dict[str, MetricInformation]: | ||
return {k: v.metric_information for k, v in self._impl.output_converter_dict.items()} |