Skip to content

Commit

Permalink
Merge pull request optuna#2425 from jeromepatel/motpesampler-after-trial
Browse files Browse the repository at this point in the history
`after_trial` implemented for `MOTPESampler`
  • Loading branch information
not522 authored Mar 4, 2021
2 parents 4c96ec6 + c834e06 commit 925554d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
11 changes: 11 additions & 0 deletions optuna/samplers/_tpe/multi_objective_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

import numpy as np
Expand Down Expand Up @@ -564,6 +565,16 @@ def _calculate_weights_below(
weights_below = np.clip(contributions / np.max(contributions), 0, 1)
return weights_below

def after_trial(
self,
study: optuna.study.Study,
trial: optuna.trial.FrozenTrial,
state: optuna.trial.TrialState,
values: Optional[Sequence[float]],
) -> None:

self._mo_random_sampler.after_trial(study, trial, state, values)


def _calculate_nondomination_rank(loss_vals: np.ndarray) -> np.ndarray:
vecs = loss_vals.copy()
Expand Down
10 changes: 10 additions & 0 deletions tests/samplers_tests/tpe_tests/test_multi_objective_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,13 @@ def test_reseed_rng() -> None:
sampler.reseed_rng()
assert mock_object.call_count == 1
assert original_seed != sampler._rng.seed


def test_call_after_trial_of_mo_random_sampler() -> None:
sampler = MOTPESampler()
study = optuna.create_study(sampler=sampler)
with patch.object(
sampler._mo_random_sampler, "after_trial", wraps=sampler._mo_random_sampler.after_trial
) as mock_object:
study.optimize(lambda _: 1.0, n_trials=1)
assert mock_object.call_count == 1

0 comments on commit 925554d

Please sign in to comment.