From 82cd81e00a3941e59cc83ab7a30fc42c8cc70784 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Sun, 27 Feb 2022 08:37:50 +0000 Subject: [PATCH 01/13] accept nan in report --- optuna/storages/_rdb/models.py | 2 +- optuna/storages/_rdb/storage.py | 20 ++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/optuna/storages/_rdb/models.py b/optuna/storages/_rdb/models.py index 8781006f5ad..a310100c628 100644 --- a/optuna/storages/_rdb/models.py +++ b/optuna/storages/_rdb/models.py @@ -413,7 +413,7 @@ class TrialValueModel(BaseModel): trial_value_id = Column(Integer, primary_key=True) trial_id = Column(Integer, ForeignKey("trials.trial_id"), nullable=False) objective = Column(Integer, nullable=False) - value = Column(Float, nullable=False) + value = Column(Float, nullable=True) trial = orm.relationship( TrialModel, backref=orm.backref("values", cascade="all, delete-orphan") diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 5df92a014be..497b5e202d4 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -15,6 +15,7 @@ from typing import Set from typing import Tuple import uuid +import math import alembic.command import alembic.config @@ -49,6 +50,17 @@ _logger = optuna.logging.get_logger(__name__) +def _ensure_not_nan(value): + # Ensure the value is not Nan, which is not supported by MySQL + # if Nan, change it the None + if isinstance(value, tuple) or isinstance(value, list): + return type(value)([None if math.isnan(v) else v for v in value]) + elif isinstance(value, dict): + return {key: None if math.isnan(v) else v for key, v in value.iterms()} + else: + return None if math.isnan(value) else value + + @contextmanager def _create_scoped_session( scoped_session: orm.scoped_session, @@ -766,6 +778,7 @@ def _lift_numerical_limit(value: float) -> float: def set_trial_values(self, trial_id: int, values: Sequence[float]) -> None: + values = _ensure_not_nan(values) with _create_scoped_session(self.scoped_session) as session: trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) @@ -778,7 +791,8 @@ def _set_trial_value_without_commit( trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) - value = self._ensure_numerical_limit(value) + if value is not None: + value = self._ensure_numerical_limit(value) trial_value = models.TrialValueModel.find_by_trial_and_objective(trial, objective, session) if trial_value is None: @@ -793,6 +807,7 @@ def set_trial_intermediate_value( self, trial_id: int, step: int, intermediate_value: float ) -> None: + intermediate_value = _ensure_not_nan(intermediate_value) with _create_scoped_session(self.scoped_session, True) as session: self._set_trial_intermediate_value_without_commit( session, trial_id, step, intermediate_value @@ -804,7 +819,8 @@ def _set_trial_intermediate_value_without_commit( trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) - intermediate_value = self._ensure_numerical_limit(intermediate_value) + if intermediate_value is not None: + intermediate_value = self._ensure_numerical_limit(intermediate_value) trial_intermediate_value = models.TrialIntermediateValueModel.find_by_trial_and_step( trial, step, session From 9b9193cac62ab6649a7dc918274a5d1bc8b7a9b2 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Sun, 27 Feb 2022 08:38:05 +0000 Subject: [PATCH 02/13] add tests --- tests/storages_tests/test_storages.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/storages_tests/test_storages.py b/tests/storages_tests/test_storages.py index 18c23f201dc..8bbdda2be94 100644 --- a/tests/storages_tests/test_storages.py +++ b/tests/storages_tests/test_storages.py @@ -11,6 +11,7 @@ from typing import Tuple from unittest.mock import Mock from unittest.mock import patch +import numpy import pytest @@ -1330,3 +1331,13 @@ def test_read_trials_from_remote_storage(storage_mode: str) -> None: study_id = storage.create_new_study() storage.read_trials_from_remote_storage(study_id) + +@pytest.mark.parametrize("storage_mode", STORAGE_MODES) +def test_report_with_nan(storage_mode: str) -> None: + def objective(trial): + trial.report(float(numpy.nan), 1) + return 1 + + with StorageSupplier(storage_mode) as storage: + study = optuna.create_study(storage=storage) + study.optimize(objective, n_trials=1) From 9153889523a0804c6fc9b8826654ec8a97c7d151 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Sun, 27 Feb 2022 08:54:29 +0000 Subject: [PATCH 03/13] fix style --- optuna/storages/_rdb/storage.py | 2 +- tests/storages_tests/test_storages.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 497b5e202d4..a03be37cd47 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -4,6 +4,7 @@ from datetime import datetime import json import logging +import math import os from typing import Any from typing import Callable @@ -15,7 +16,6 @@ from typing import Set from typing import Tuple import uuid -import math import alembic.command import alembic.config diff --git a/tests/storages_tests/test_storages.py b/tests/storages_tests/test_storages.py index 8bbdda2be94..332f7bdc2ee 100644 --- a/tests/storages_tests/test_storages.py +++ b/tests/storages_tests/test_storages.py @@ -11,8 +11,8 @@ from typing import Tuple from unittest.mock import Mock from unittest.mock import patch -import numpy +import numpy import pytest import optuna @@ -1332,6 +1332,7 @@ def test_read_trials_from_remote_storage(storage_mode: str) -> None: study_id = storage.create_new_study() storage.read_trials_from_remote_storage(study_id) + @pytest.mark.parametrize("storage_mode", STORAGE_MODES) def test_report_with_nan(storage_mode: str) -> None: def objective(trial): From 63947c4f1fe32899cc999a87eeb78aba233dd25f Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Sun, 27 Feb 2022 08:57:49 +0000 Subject: [PATCH 04/13] simplify if --- optuna/storages/_rdb/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index a03be37cd47..6af61dc00b9 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -53,7 +53,7 @@ def _ensure_not_nan(value): # Ensure the value is not Nan, which is not supported by MySQL # if Nan, change it the None - if isinstance(value, tuple) or isinstance(value, list): + if isinstance(value, (tuple, list)): return type(value)([None if math.isnan(v) else v for v in value]) elif isinstance(value, dict): return {key: None if math.isnan(v) else v for key, v in value.iterms()} From b8a9c7533f4435a8486432b103cff58203bd392a Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Sun, 27 Feb 2022 09:19:30 +0000 Subject: [PATCH 05/13] add type hint --- optuna/storages/_rdb/storage.py | 2 +- tests/storages_tests/test_storages.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 6af61dc00b9..71fe7194e42 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -50,7 +50,7 @@ _logger = optuna.logging.get_logger(__name__) -def _ensure_not_nan(value): +def _ensure_not_nan(value: Any) -> Any: # Ensure the value is not Nan, which is not supported by MySQL # if Nan, change it the None if isinstance(value, (tuple, list)): diff --git a/tests/storages_tests/test_storages.py b/tests/storages_tests/test_storages.py index 332f7bdc2ee..6c367f94da7 100644 --- a/tests/storages_tests/test_storages.py +++ b/tests/storages_tests/test_storages.py @@ -1335,7 +1335,7 @@ def test_read_trials_from_remote_storage(storage_mode: str) -> None: @pytest.mark.parametrize("storage_mode", STORAGE_MODES) def test_report_with_nan(storage_mode: str) -> None: - def objective(trial): + def objective(trial: optuna.Trial) -> int: trial.report(float(numpy.nan), 1) return 1 From e66c4a752491fa34fc8ecee5790415e662ce0265 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Sun, 27 Feb 2022 09:22:39 +0000 Subject: [PATCH 06/13] fix typo --- optuna/storages/_rdb/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 71fe7194e42..0aff8240433 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -56,7 +56,7 @@ def _ensure_not_nan(value: Any) -> Any: if isinstance(value, (tuple, list)): return type(value)([None if math.isnan(v) else v for v in value]) elif isinstance(value, dict): - return {key: None if math.isnan(v) else v for key, v in value.iterms()} + return {key: None if math.isnan(v) else v for key, v in value.items()} else: return None if math.isnan(value) else value From c1db173cc80043a031d79b08f2991abb4d710f0a Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Mon, 28 Feb 2022 02:20:36 +0000 Subject: [PATCH 07/13] fix type --- optuna/storages/_rdb/storage.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 0aff8240433..2c129836114 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -50,13 +50,11 @@ _logger = optuna.logging.get_logger(__name__) -def _ensure_not_nan(value: Any) -> Any: +def _ensure_not_nan(value: Sequence[float]) -> Sequence[float]: # Ensure the value is not Nan, which is not supported by MySQL # if Nan, change it the None if isinstance(value, (tuple, list)): return type(value)([None if math.isnan(v) else v for v in value]) - elif isinstance(value, dict): - return {key: None if math.isnan(v) else v for key, v in value.items()} else: return None if math.isnan(value) else value From 81457d0a8ff3e65a0c791f5d2690006c2d6d4064 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Mon, 28 Feb 2022 02:33:08 +0000 Subject: [PATCH 08/13] fix type --- optuna/storages/_rdb/storage.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 2c129836114..668d0e4d717 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -15,6 +15,7 @@ from typing import Sequence from typing import Set from typing import Tuple +from typing import Union import uuid import alembic.command @@ -50,7 +51,7 @@ _logger = optuna.logging.get_logger(__name__) -def _ensure_not_nan(value: Sequence[float]) -> Sequence[float]: +def _ensure_not_nan(value: Union[float, Sequence[float]]) -> Optional[Union[float, Sequence[float]]]: # Ensure the value is not Nan, which is not supported by MySQL # if Nan, change it the None if isinstance(value, (tuple, list)): From 69bfe849f796121c76a659cadbebbfbff2ec2402 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Tue, 12 Apr 2022 07:16:23 +0000 Subject: [PATCH 09/13] reformat --- optuna/storages/_rdb/storage.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 2cac2b42d57..fba4bfe363b 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -51,7 +51,9 @@ _logger = optuna.logging.get_logger(__name__) -def _ensure_not_nan(value: Union[float, Sequence[float]]) -> Optional[Union[float, Sequence[float]]]: +def _ensure_not_nan( + value: Union[float, Sequence[float]] +) -> Optional[Union[float, Sequence[float]]]: # Ensure the value is not Nan, which is not supported by MySQL # if Nan, change it the None if isinstance(value, (tuple, list)): From 075c594f06aca0b7d4860c6ce481a5a3c4fb8b5f Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Tue, 12 Apr 2022 07:17:27 +0000 Subject: [PATCH 10/13] add union --- optuna/storages/_rdb/storage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index fba4bfe363b..e5a1c2ca489 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -15,6 +15,7 @@ from typing import Optional from typing import Sequence from typing import Set +from typing import Union import uuid import alembic.command From 134c18f8d34bababf0b007d4e1bb094e92e0b373 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Tue, 12 Apr 2022 07:55:30 +0000 Subject: [PATCH 11/13] clean up and add tests --- optuna/storages/_rdb/storage.py | 20 ++++++-------------- tests/storages_tests/test_storages.py | 7 ++++++- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index e5a1c2ca489..afaadeb6041 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -52,17 +52,6 @@ _logger = optuna.logging.get_logger(__name__) -def _ensure_not_nan( - value: Union[float, Sequence[float]] -) -> Optional[Union[float, Sequence[float]]]: - # Ensure the value is not Nan, which is not supported by MySQL - # if Nan, change it the None - if isinstance(value, (tuple, list)): - return type(value)([None if math.isnan(v) else v for v in value]) - else: - return None if math.isnan(value) else value - - @contextmanager def _create_scoped_session( scoped_session: orm.scoped_session, @@ -787,7 +776,10 @@ def _ensure_numerical_limit(value: float) -> float: # dialect. Most limiting one is MySQL which in current data # model will store floats as single precision (32 bit). # There is no support for +inf and -inf in this dialect. - return float(np.clip(value, _RDB_MIN_FLOAT, _RDB_MAX_FLOAT)) + if np.isnan(value): + return None + else: + return float(np.clip(value, _RDB_MIN_FLOAT, _RDB_MAX_FLOAT)) @staticmethod def _lift_numerical_limit(value: float) -> float: @@ -797,6 +789,8 @@ def _lift_numerical_limit(value: float) -> float: # https://dev.mysql.com/doc/refman/8.0/en/problems-with-float.html if np.isclose(value, _RDB_MIN_FLOAT) or np.isclose(value, _RDB_MAX_FLOAT): return float(np.sign(value) * float("inf")) + elif value is None: + return float("nan") return value @deprecated( @@ -806,7 +800,6 @@ def _lift_numerical_limit(value: float) -> float: ) def set_trial_values(self, trial_id: int, values: Sequence[float]) -> None: - values = _ensure_not_nan(values) with _create_scoped_session(self.scoped_session) as session: trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) @@ -862,7 +855,6 @@ def set_trial_intermediate_value( self, trial_id: int, step: int, intermediate_value: float ) -> None: - intermediate_value = _ensure_not_nan(intermediate_value) with _create_scoped_session(self.scoped_session, True) as session: self._set_trial_intermediate_value_without_commit( session, trial_id, step, intermediate_value diff --git a/tests/storages_tests/test_storages.py b/tests/storages_tests/test_storages.py index d9b3059aa08..179ad613e73 100644 --- a/tests/storages_tests/test_storages.py +++ b/tests/storages_tests/test_storages.py @@ -13,6 +13,7 @@ from unittest.mock import Mock from unittest.mock import patch +import numpy as np import pytest import optuna @@ -710,6 +711,7 @@ def test_set_trial_intermediate_value(storage_mode: str) -> None: trial_id_1 = storage.create_new_trial(study_id) trial_id_2 = storage.create_new_trial(study_id) trial_id_3 = storage.create_new_trial(storage.create_new_study()) + trial_id_4 = storage.create_new_trial(study_id) # Test setting new values. storage.set_trial_intermediate_value(trial_id_1, 0, 0.3) @@ -717,16 +719,19 @@ def test_set_trial_intermediate_value(storage_mode: str) -> None: storage.set_trial_intermediate_value(trial_id_3, 0, 0.1) storage.set_trial_intermediate_value(trial_id_3, 1, 0.4) storage.set_trial_intermediate_value(trial_id_3, 2, 0.5) + storage.set_trial_intermediate_value(trial_id_4, 0, float("nan")) assert storage.get_trial(trial_id_1).intermediate_values == {0: 0.3, 2: 0.4} assert storage.get_trial(trial_id_2).intermediate_values == {} assert storage.get_trial(trial_id_3).intermediate_values == {0: 0.1, 1: 0.4, 2: 0.5} + print(storage.get_trial(trial_id_4)) + assert np.isnan(storage.get_trial(trial_id_4).intermediate_values[0]) # Test setting existing step. storage.set_trial_intermediate_value(trial_id_1, 0, 0.2) assert storage.get_trial(trial_id_1).intermediate_values == {0: 0.2, 2: 0.4} - non_existent_trial_id = max(trial_id_1, trial_id_2, trial_id_3) + 1 + non_existent_trial_id = max(trial_id_1, trial_id_2, trial_id_3, trial_id_4) + 1 with pytest.raises(KeyError): storage.set_trial_intermediate_value(non_existent_trial_id, 0, 0.2) From 5b2b14337f31357b0da4b8f9b9b25e583e5d3f2f Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Tue, 12 Apr 2022 08:00:10 +0000 Subject: [PATCH 12/13] fix bugs --- optuna/storages/_rdb/models.py | 4 ++-- optuna/storages/_rdb/storage.py | 12 +++++------- tests/storages_tests/test_storages.py | 1 - 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/optuna/storages/_rdb/models.py b/optuna/storages/_rdb/models.py index a310100c628..3de3aeadeea 100644 --- a/optuna/storages/_rdb/models.py +++ b/optuna/storages/_rdb/models.py @@ -413,7 +413,7 @@ class TrialValueModel(BaseModel): trial_value_id = Column(Integer, primary_key=True) trial_id = Column(Integer, ForeignKey("trials.trial_id"), nullable=False) objective = Column(Integer, nullable=False) - value = Column(Float, nullable=True) + value = Column(Float, nullable=False) trial = orm.relationship( TrialModel, backref=orm.backref("values", cascade="all, delete-orphan") @@ -449,7 +449,7 @@ class TrialIntermediateValueModel(BaseModel): trial_intermediate_value_id = Column(Integer, primary_key=True) trial_id = Column(Integer, ForeignKey("trials.trial_id"), nullable=False) step = Column(Integer, nullable=False) - intermediate_value = Column(Float, nullable=False) + intermediate_value = Column(Float, nullable=True) trial = orm.relationship( TrialModel, backref=orm.backref("intermediate_values", cascade="all, delete-orphan") diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index afaadeb6041..5904c36c1e2 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -787,10 +787,10 @@ def _lift_numerical_limit(value: float) -> float: # Floats can't be compared for equality because they are # approximate and not stored as exact values. # https://dev.mysql.com/doc/refman/8.0/en/problems-with-float.html - if np.isclose(value, _RDB_MIN_FLOAT) or np.isclose(value, _RDB_MAX_FLOAT): - return float(np.sign(value) * float("inf")) - elif value is None: + if value is None: return float("nan") + elif np.isclose(value, _RDB_MIN_FLOAT) or np.isclose(value, _RDB_MAX_FLOAT): + return float(np.sign(value) * float("inf")) return value @deprecated( @@ -839,8 +839,7 @@ def _set_trial_value_without_commit( trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) - if value is not None: - value = self._ensure_numerical_limit(value) + value = self._ensure_numerical_limit(value) trial_value = models.TrialValueModel.find_by_trial_and_objective(trial, objective, session) if trial_value is None: @@ -866,8 +865,7 @@ def _set_trial_intermediate_value_without_commit( trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) - if intermediate_value is not None: - intermediate_value = self._ensure_numerical_limit(intermediate_value) + intermediate_value = self._ensure_numerical_limit(intermediate_value) trial_intermediate_value = models.TrialIntermediateValueModel.find_by_trial_and_step( trial, step, session diff --git a/tests/storages_tests/test_storages.py b/tests/storages_tests/test_storages.py index 179ad613e73..fb8221239c6 100644 --- a/tests/storages_tests/test_storages.py +++ b/tests/storages_tests/test_storages.py @@ -724,7 +724,6 @@ def test_set_trial_intermediate_value(storage_mode: str) -> None: assert storage.get_trial(trial_id_1).intermediate_values == {0: 0.3, 2: 0.4} assert storage.get_trial(trial_id_2).intermediate_values == {} assert storage.get_trial(trial_id_3).intermediate_values == {0: 0.1, 1: 0.4, 2: 0.5} - print(storage.get_trial(trial_id_4)) assert np.isnan(storage.get_trial(trial_id_4).intermediate_values[0]) # Test setting existing step. From 4a270c5d7fbd81b74c03ba545d1b36e9d85c8348 Mon Sep 17 00:00:00 2001 From: tianqi xu Date: Tue, 12 Apr 2022 08:17:59 +0000 Subject: [PATCH 13/13] fix for mypy --- optuna/storages/_rdb/storage.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/optuna/storages/_rdb/storage.py b/optuna/storages/_rdb/storage.py index 5904c36c1e2..15f9406cc86 100644 --- a/optuna/storages/_rdb/storage.py +++ b/optuna/storages/_rdb/storage.py @@ -4,7 +4,6 @@ from datetime import datetime import json import logging -import math import os from typing import Any from typing import Callable @@ -15,7 +14,6 @@ from typing import Optional from typing import Sequence from typing import Set -from typing import Union import uuid import alembic.command @@ -769,6 +767,13 @@ def get_trial_param(self, trial_id: int, param_name: str) -> float: return param_value + @staticmethod + def _ensure_not_nan(value: float) -> Optional[float]: + if np.isnan(value): + return None + else: + return value + @staticmethod def _ensure_numerical_limit(value: float) -> float: @@ -776,13 +781,10 @@ def _ensure_numerical_limit(value: float) -> float: # dialect. Most limiting one is MySQL which in current data # model will store floats as single precision (32 bit). # There is no support for +inf and -inf in this dialect. - if np.isnan(value): - return None - else: - return float(np.clip(value, _RDB_MIN_FLOAT, _RDB_MAX_FLOAT)) + return float(np.clip(value, _RDB_MIN_FLOAT, _RDB_MAX_FLOAT)) @staticmethod - def _lift_numerical_limit(value: float) -> float: + def _lift_numerical_limit(value: Optional[float]) -> float: # Floats can't be compared for equality because they are # approximate and not stored as exact values. @@ -865,18 +867,20 @@ def _set_trial_intermediate_value_without_commit( trial = models.TrialModel.find_or_raise_by_id(trial_id, session) self.check_trial_is_updatable(trial_id, trial.state) - intermediate_value = self._ensure_numerical_limit(intermediate_value) + _intermediate_value = self._ensure_not_nan(intermediate_value) + if _intermediate_value is not None: + _intermediate_value = self._ensure_numerical_limit(_intermediate_value) trial_intermediate_value = models.TrialIntermediateValueModel.find_by_trial_and_step( trial, step, session ) if trial_intermediate_value is None: trial_intermediate_value = models.TrialIntermediateValueModel( - trial_id=trial_id, step=step, intermediate_value=intermediate_value + trial_id=trial_id, step=step, intermediate_value=_intermediate_value ) session.add(trial_intermediate_value) else: - trial_intermediate_value.intermediate_value = intermediate_value + trial_intermediate_value.intermediate_value = _intermediate_value def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None: