Skip to content

Commit

Permalink
Fix GridSampler with RetryFailedTrialCallback or enqueue_trial
Browse files Browse the repository at this point in the history
  • Loading branch information
not522 committed Sep 22, 2021
1 parent 5e6ffe6 commit 1f27099
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
9 changes: 9 additions & 0 deletions optuna/samplers/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def sample_relative(
# object is hard to get at the beginning of trial, while we need the access to the object
# to validate the sampled value.

# When the trial is created by RetryFailedTrialCallback or enqueue_trial, we should not
# assign a new grid_id.
if "grid_id" in trial.system_attrs or "fixed_params" in trial.system_attrs:
return {}

target_grids = self._get_unvisited_grid_ids(study)

if len(target_grids) == 0:
Expand Down Expand Up @@ -158,6 +163,10 @@ def sample_independent(
param_distribution: BaseDistribution,
) -> Any:

if "grid_id" not in trial.system_attrs:
message = "You should specify all parameters in enqueue_trial when using GridSampler."
raise ValueError(message)

if param_name not in self._search_space:
message = "The parameter name, {}, is not found in the given grid.".format(param_name)
raise ValueError(message)
Expand Down
38 changes: 38 additions & 0 deletions tests/samplers_tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import optuna
from optuna import samplers
from optuna.samplers._grid import GridValueType
from optuna.storages import RetryFailedTrialCallback
from optuna.trial import Trial


Expand Down Expand Up @@ -166,3 +167,40 @@ def test_has_same_search_space() -> None:

assert not sampler._same_search_space({"x": [3, 2, 1, 0], "y": ["a", "b", "c"]})
assert not sampler._same_search_space({"x": [3, 2], "y": ["a", "b", "c"]})


def test_retried_trial() -> None:
sampler = samplers.GridSampler({"a": [0, 50]})
study = optuna.create_study(sampler=sampler)
trial = study.ask()
trial.suggest_int("a", 0, 100)

callback = RetryFailedTrialCallback()
callback(study, study.trials[0])

study.optimize(lambda trial: trial.suggest_int("a", 0, 100))

assert len(study.trials) == 3
assert study.trials[0].params["a"] == study.trials[1].params["a"]
assert study.trials[0].system_attrs["grid_id"] == study.trials[1].system_attrs["grid_id"]


def test_enqueued_trial() -> None:
sampler = samplers.GridSampler({"a": [0, 50]})
study = optuna.create_study(sampler=sampler)
study.enqueue_trial({"a": 100})

study.optimize(lambda trial: trial.suggest_int("a", 0, 100))

assert len(study.trials) == 3
assert study.trials[0].params["a"] == 100
assert sorted([study.trials[1].params["a"], study.trials[2].params["a"]]) == [0, 50]


def test_enqueued_insufficient_trial() -> None:
sampler = samplers.GridSampler({"a": [0, 50]})
study = optuna.create_study(sampler=sampler)
study.enqueue_trial({})

with pytest.raises(ValueError):
study.optimize(lambda trial: trial.suggest_int("a", 0, 100))

0 comments on commit 1f27099

Please sign in to comment.