Skip to content

Commit

Permalink
Merge pull request optuna#2904 from toshihikoyanase/add-unit-test-for…
Browse files Browse the repository at this point in the history
…-conditional-objective-function

Add test case of samplers for conditional objective function
  • Loading branch information
keisuke-umezawa authored Nov 18, 2021
2 parents bc5a49c + b974dcc commit 2b0be2c
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/samplers_tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,40 @@ def sample() -> List[float]:
assert isinstance(param_value, float)


@parametrize_sampler
def test_conditional_sample_independent(sampler_class: Callable[[], BaseSampler]) -> None:
# This test case reproduces the error reported in #2734.
# See https://github.com/optuna/optuna/pull/2734#issuecomment-857649769.

study = optuna.study.create_study(sampler=sampler_class())
categorical_distribution = CategoricalDistribution(choices=["x", "y"])
dependent_distribution = CategoricalDistribution(choices=["a", "b"])

study.add_trial(
optuna.create_trial(
params={"category": "x", "x": "a"},
distributions={"category": categorical_distribution, "x": dependent_distribution},
value=0.1,
)
)

study.add_trial(
optuna.create_trial(
params={"category": "y", "y": "b"},
distributions={"category": categorical_distribution, "y": dependent_distribution},
value=0.1,
)
)

_trial = _create_new_trial(study)
category = study.sampler.sample_independent(
study, _trial, "category", categorical_distribution
)
assert category in ["x", "y"]
value = study.sampler.sample_independent(study, _trial, category, dependent_distribution)
assert value in ["a", "b"]


def _create_new_trial(study: Study) -> FrozenTrial:

trial_id = study._storage.create_new_trial(study._study_id)
Expand Down

0 comments on commit 2b0be2c

Please sign in to comment.