Skip to content

Commit

Permalink
Merge pull request optuna#3348 from belldandyxtq/nan_report
Browse files Browse the repository at this point in the history
Accept `nan` in `trial.report`
  • Loading branch information
himkt authored Apr 19, 2022
2 parents cd70021 + 04f4b36 commit 638bdef
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion optuna/storages/_rdb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ class TrialIntermediateValueModel(BaseModel):
trial_intermediate_value_id = Column(Integer, primary_key=True)
trial_id = Column(Integer, ForeignKey("trials.trial_id"), nullable=False)
step = Column(Integer, nullable=False)
intermediate_value = Column(Float, nullable=False)
intermediate_value = Column(Float, nullable=True)

trial = orm.relationship(
TrialModel, backref=orm.backref("intermediate_values", cascade="all, delete-orphan")
Expand Down
21 changes: 16 additions & 5 deletions optuna/storages/_rdb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,13 @@ def get_trial_param(self, trial_id: int, param_name: str) -> float:

return param_value

@staticmethod
def _ensure_not_nan(value: float) -> Optional[float]:
if np.isnan(value):
return None
else:
return value

@staticmethod
def _ensure_numerical_limit(value: float) -> float:

Expand All @@ -777,12 +784,14 @@ def _ensure_numerical_limit(value: float) -> float:
return float(np.clip(value, _RDB_MIN_FLOAT, _RDB_MAX_FLOAT))

@staticmethod
def _lift_numerical_limit(value: float) -> float:
def _lift_numerical_limit(value: Optional[float]) -> float:

# Floats can't be compared for equality because they are
# approximate and not stored as exact values.
# https://dev.mysql.com/doc/refman/8.0/en/problems-with-float.html
if np.isclose(value, _RDB_MIN_FLOAT) or np.isclose(value, _RDB_MAX_FLOAT):
if value is None:
return float("nan")
elif np.isclose(value, _RDB_MIN_FLOAT) or np.isclose(value, _RDB_MAX_FLOAT):
return float(np.sign(value) * float("inf"))
return value

Expand Down Expand Up @@ -858,18 +867,20 @@ def _set_trial_intermediate_value_without_commit(

trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
self.check_trial_is_updatable(trial_id, trial.state)
intermediate_value = self._ensure_numerical_limit(intermediate_value)
_intermediate_value = self._ensure_not_nan(intermediate_value)
if _intermediate_value is not None:
_intermediate_value = self._ensure_numerical_limit(_intermediate_value)

trial_intermediate_value = models.TrialIntermediateValueModel.find_by_trial_and_step(
trial, step, session
)
if trial_intermediate_value is None:
trial_intermediate_value = models.TrialIntermediateValueModel(
trial_id=trial_id, step=step, intermediate_value=intermediate_value
trial_id=trial_id, step=step, intermediate_value=_intermediate_value
)
session.add(trial_intermediate_value)
else:
trial_intermediate_value.intermediate_value = intermediate_value
trial_intermediate_value.intermediate_value = _intermediate_value

def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:

Expand Down
6 changes: 5 additions & 1 deletion tests/storages_tests/test_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from unittest.mock import Mock
from unittest.mock import patch

import numpy as np
import pytest

import optuna
Expand Down Expand Up @@ -710,6 +711,7 @@ def test_set_trial_intermediate_value(storage_mode: str) -> None:
trial_id_1 = storage.create_new_trial(study_id)
trial_id_2 = storage.create_new_trial(study_id)
trial_id_3 = storage.create_new_trial(storage.create_new_study())
trial_id_4 = storage.create_new_trial(study_id)

# Test setting new values.
storage.set_trial_intermediate_value(trial_id_1, 0, 0.3)
Expand All @@ -718,6 +720,7 @@ def test_set_trial_intermediate_value(storage_mode: str) -> None:
storage.set_trial_intermediate_value(trial_id_3, 1, 0.4)
storage.set_trial_intermediate_value(trial_id_3, 2, 0.5)
storage.set_trial_intermediate_value(trial_id_3, 3, float("inf"))
storage.set_trial_intermediate_value(trial_id_4, 0, float("nan"))

assert storage.get_trial(trial_id_1).intermediate_values == {0: 0.3, 2: 0.4}
assert storage.get_trial(trial_id_2).intermediate_values == {}
Expand All @@ -727,12 +730,13 @@ def test_set_trial_intermediate_value(storage_mode: str) -> None:
2: 0.5,
3: float("inf"),
}
assert np.isnan(storage.get_trial(trial_id_4).intermediate_values[0])

# Test setting existing step.
storage.set_trial_intermediate_value(trial_id_1, 0, 0.2)
assert storage.get_trial(trial_id_1).intermediate_values == {0: 0.2, 2: 0.4}

non_existent_trial_id = max(trial_id_1, trial_id_2, trial_id_3) + 1
non_existent_trial_id = max(trial_id_1, trial_id_2, trial_id_3, trial_id_4) + 1
with pytest.raises(KeyError):
storage.set_trial_intermediate_value(non_existent_trial_id, 0, 0.2)

Expand Down

0 comments on commit 638bdef

Please sign in to comment.