Skip to content

Commit

Permalink
Change arguments of BaseErrorEvaluator and classes that inherit from it (optuna#4607)
Browse files Browse the repository at this point in the history

* Fix argument of BaseErrorEvaluator.evaluate and its subclasses

* Fix tests/test_erroreval according to the argument change of ErrorEvaluator

* Update terminator according to the argument change of error_evaluator

* Update tests to cover the maximize pattern
  • Loading branch information
cross32768 authored Apr 21, 2023
1 parent 59d9a01 commit ec622ef
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 22 deletions.
32 changes: 26 additions & 6 deletions optuna/terminator/erroreval.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import abc
from typing import cast

import numpy as np

from optuna._experimental import experimental_class
from optuna.study.study import Study
from optuna.study import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import Trial
from optuna.trial._state import TrialState

Expand All @@ -15,16 +17,30 @@

class BaseErrorEvaluator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def evaluate(self, study: Study) -> float:
def evaluate(
self,
trials: list[FrozenTrial],
study_direction: StudyDirection,
) -> float:
pass


@experimental_class("3.2.0")
class CrossValidationErrorEvaluator(BaseErrorEvaluator):
def evaluate(self, study: Study) -> float:
assert len(study.get_trials(states=(TrialState.COMPLETE,))) > 0
def evaluate(
self,
trials: list[FrozenTrial],
study_direction: StudyDirection,
) -> float:
trials = [trial for trial in trials if trial.state == TrialState.COMPLETE]
assert len(trials) > 0

if study_direction == StudyDirection.MAXIMIZE:
best_trial = max(trials, key=lambda t: cast(float, t.value))
else:
best_trial = min(trials, key=lambda t: cast(float, t.value))

best_trial_attrs = study.best_trial.system_attrs
best_trial_attrs = best_trial.system_attrs
if _CROSS_VALIDATION_SCORES_KEY in best_trial_attrs:
cv_scores = best_trial_attrs[_CROSS_VALIDATION_SCORES_KEY]
else:
Expand Down Expand Up @@ -56,5 +72,9 @@ class StaticErrorEvaluator(BaseErrorEvaluator):
def __init__(self, constant: float) -> None:
self._constant = constant

def evaluate(self, study: Study) -> float:
def evaluate(
self,
trials: list[FrozenTrial],
study_direction: StudyDirection,
) -> float:
return self._constant
4 changes: 3 additions & 1 deletion optuna/terminator/terminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def should_terminate(self, study: Study) -> bool:
trials=study.trials,
study_direction=study.direction,
)
error = self._error_evaluator.evaluate(study)
error = self._error_evaluator.evaluate(
trials=study.trials, study_direction=study.direction
)
should_terminate = regret_bound < error

return should_terminate
38 changes: 23 additions & 15 deletions tests/terminator_tests/test_erroreval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,63 +22,71 @@ def _create_trial(value: float, cv_scores: list[float]) -> FrozenTrial:
)


def test_cross_validation_evaluator() -> None:
study = create_study(direction="minimize")
@pytest.mark.parametrize("direction", ["minimize", "maximize"])
def test_cross_validation_evaluator(direction: str) -> None:
study = create_study(direction=direction)
sign = 1 if direction == "minimize" else -1
study.add_trials(
[
_create_trial(value=2.0, cv_scores=[1.0, -1.0]), # Second best trial with 1.0 var.
_create_trial(value=1.0, cv_scores=[2.0, -2.0]), # Best trial with 4.0 var.
_create_trial(
value=sign * 2.0, cv_scores=[1.0, -1.0]
), # Second best trial with 1.0 var.
_create_trial(value=sign * 1.0, cv_scores=[2.0, -2.0]), # Best trial with 4.0 var.
]
)

evaluator = CrossValidationErrorEvaluator()
serror = evaluator.evaluate(study)
serror = evaluator.evaluate(study.trials, study.direction)

expected_scale = 1.5
assert serror == math.sqrt(4.0 * expected_scale)


def test_cross_validation_evaluator_without_cv_scores() -> None:
study = create_study(direction="minimize")
@pytest.mark.parametrize("direction", ["minimize", "maximize"])
def test_cross_validation_evaluator_without_cv_scores(direction: str) -> None:
study = create_study(direction=direction)
study.add_trial(
# Note that the CV score is not reported with the system attr.
create_trial(params={}, distributions={}, value=0.0)
)

evaluator = CrossValidationErrorEvaluator()
with pytest.raises(ValueError):
evaluator.evaluate(study)
evaluator.evaluate(study.trials, study.direction)


def test_report_cross_validation_scores() -> None:
@pytest.mark.parametrize("direction", ["minimize", "maximize"])
def test_report_cross_validation_scores(direction: str) -> None:
scores = [1.0, 2.0]

study = create_study(direction="minimize")
study = create_study(direction=direction)
trial = study.ask()
report_cross_validation_scores(trial, scores)
study.tell(trial, 0.0)

assert study.trials[0].system_attrs[_CROSS_VALIDATION_SCORES_KEY] == scores


def test_report_cross_validation_scores_with_illegal_scores_length() -> None:
@pytest.mark.parametrize("direction", ["minimize", "maximize"])
def test_report_cross_validation_scores_with_illegal_scores_length(direction: str) -> None:
scores = [1.0]

study = create_study(direction="minimize")
study = create_study(direction=direction)
trial = study.ask()
with pytest.raises(ValueError):
report_cross_validation_scores(trial, scores)


def test_static_evaluator() -> None:
study = create_study(direction="minimize")
@pytest.mark.parametrize("direction", ["minimize", "maximize"])
def test_static_evaluator(direction: str) -> None:
study = create_study(direction=direction)
study.add_trials(
[
_create_trial(value=2.0, cv_scores=[1.0, -1.0]),
]
)

evaluator = StaticErrorEvaluator(constant=100.0)
serror = evaluator.evaluate(study)
serror = evaluator.evaluate(study.trials, study.direction)

assert serror == 100.0

0 comments on commit ec622ef

Please sign in to comment.