Skip to content

Commit

Permalink
Merge branch 'master' into test-integration
Browse files Browse the repository at this point in the history
  • Loading branch information
not522 authored Nov 11, 2022
2 parents 2459816 + 8917a51 commit ff9f5b6
Show file tree
Hide file tree
Showing 22 changed files with 205 additions and 102 deletions.
16 changes: 12 additions & 4 deletions optuna/_hypervolume/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from optuna._hypervolume.base import BaseHypervolume # NOQA
from optuna._hypervolume.utils import _compute_2d # NOQA
from optuna._hypervolume.utils import _compute_2points_volume # NOQA
from optuna._hypervolume.wfg import WFG # NOQA
from optuna._hypervolume.base import BaseHypervolume
from optuna._hypervolume.utils import _compute_2d
from optuna._hypervolume.utils import _compute_2points_volume
from optuna._hypervolume.wfg import WFG


__all__ = [
"BaseHypervolume",
"_compute_2d",
"_compute_2points_volume",
"WFG",
]
5 changes: 4 additions & 1 deletion optuna/importance/_fanova/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from optuna.importance._fanova._evaluator import FanovaImportanceEvaluator # NOQA
from optuna.importance._fanova._evaluator import FanovaImportanceEvaluator


__all__ = ["FanovaImportanceEvaluator"]
96 changes: 63 additions & 33 deletions optuna/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,37 @@
}


__all__ = list(_import_structure.keys()) + sum(_import_structure.values(), [])


if TYPE_CHECKING:
from optuna.integration.allennlp import AllenNLPExecutor # NOQA
from optuna.integration.allennlp import AllenNLPPruningCallback # NOQA
from optuna.integration.botorch import BoTorchSampler # NOQA
from optuna.integration.catalyst import CatalystPruningCallback # NOQA
from optuna.integration.catboost import CatBoostPruningCallback # NOQA
from optuna.integration.chainer import ChainerPruningExtension # NOQA
from optuna.integration.chainermn import ChainerMNStudy # NOQA
from optuna.integration.cma import CmaEsSampler # NOQA
from optuna.integration.cma import PyCmaSampler # NOQA
from optuna.integration.fastaiv1 import FastAIV1PruningCallback # NOQA
from optuna.integration.fastaiv2 import FastAIPruningCallback # NOQA
from optuna.integration.fastaiv2 import FastAIV2PruningCallback # NOQA
from optuna.integration.keras import KerasPruningCallback # NOQA
from optuna.integration.lightgbm import LightGBMPruningCallback # NOQA
from optuna.integration.lightgbm import LightGBMTuner # NOQA
from optuna.integration.lightgbm import LightGBMTunerCV # NOQA
from optuna.integration.mlflow import MLflowCallback # NOQA
from optuna.integration.mxnet import MXNetPruningCallback # NOQA
from optuna.integration.pytorch_distributed import TorchDistributedTrial # NOQA
from optuna.integration.pytorch_ignite import PyTorchIgnitePruningHandler # NOQA
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback # NOQA
from optuna.integration.shap import ShapleyImportanceEvaluator # NOQA
from optuna.integration.sklearn import OptunaSearchCV # NOQA
from optuna.integration.skopt import SkoptSampler # NOQA
from optuna.integration.skorch import SkorchPruningCallback # NOQA
from optuna.integration.tensorboard import TensorBoardCallback # NOQA
from optuna.integration.tensorflow import TensorFlowPruningHook # NOQA
from optuna.integration.tfkeras import TFKerasPruningCallback # NOQA
from optuna.integration.wandb import WeightsAndBiasesCallback # NOQA
from optuna.integration.xgboost import XGBoostPruningCallback # NOQA
from optuna.integration.allennlp import AllenNLPExecutor
from optuna.integration.allennlp import AllenNLPPruningCallback
from optuna.integration.botorch import BoTorchSampler
from optuna.integration.catalyst import CatalystPruningCallback
from optuna.integration.catboost import CatBoostPruningCallback
from optuna.integration.chainer import ChainerPruningExtension
from optuna.integration.chainermn import ChainerMNStudy
from optuna.integration.cma import CmaEsSampler
from optuna.integration.cma import PyCmaSampler
from optuna.integration.fastaiv1 import FastAIV1PruningCallback
from optuna.integration.fastaiv2 import FastAIPruningCallback
from optuna.integration.fastaiv2 import FastAIV2PruningCallback
from optuna.integration.keras import KerasPruningCallback
from optuna.integration.lightgbm import LightGBMPruningCallback
from optuna.integration.lightgbm import LightGBMTuner
from optuna.integration.lightgbm import LightGBMTunerCV
from optuna.integration.mlflow import MLflowCallback
from optuna.integration.mxnet import MXNetPruningCallback
from optuna.integration.pytorch_distributed import TorchDistributedTrial
from optuna.integration.pytorch_ignite import PyTorchIgnitePruningHandler
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback
from optuna.integration.shap import ShapleyImportanceEvaluator
from optuna.integration.sklearn import OptunaSearchCV
from optuna.integration.skopt import SkoptSampler
from optuna.integration.skorch import SkorchPruningCallback
from optuna.integration.tensorboard import TensorBoardCallback
from optuna.integration.tensorflow import TensorFlowPruningHook
from optuna.integration.tfkeras import TFKerasPruningCallback
from optuna.integration.wandb import WeightsAndBiasesCallback
from optuna.integration.xgboost import XGBoostPruningCallback
else:

class _IntegrationModule(ModuleType):
Expand Down Expand Up @@ -107,3 +104,36 @@ def _get_module(self, module_name: str) -> ModuleType:
return importlib.import_module("." + module_name, self.__name__)

sys.modules[__name__] = _IntegrationModule(__name__)

__all__ = [
"AllenNLPExecutor",
"AllenNLPPruningCallback",
"BoTorchSampler",
"CatalystPruningCallback",
"CatBoostPruningCallback",
"ChainerPruningExtension",
"ChainerMNStudy",
"CmaEsSampler",
"PyCmaSampler",
"MLflowCallback",
"WeightsAndBiasesCallback",
"KerasPruningCallback",
"LightGBMPruningCallback",
"LightGBMTuner",
"LightGBMTunerCV",
"TorchDistributedTrial",
"PyTorchIgnitePruningHandler",
"PyTorchLightningPruningCallback",
"OptunaSearchCV",
"ShapleyImportanceEvaluator",
"SkorchPruningCallback",
"MXNetPruningCallback",
"SkoptSampler",
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
"XGBoostPruningCallback",
"FastAIV1PruningCallback",
"FastAIV2PruningCallback",
"FastAIPruningCallback",
]
10 changes: 6 additions & 4 deletions optuna/integration/_lightgbm_tuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from optuna.integration._lightgbm_tuner.optimize import _imports
from optuna.integration._lightgbm_tuner.optimize import LightGBMTuner
from optuna.integration._lightgbm_tuner.optimize import LightGBMTunerCV # NOQA
from optuna.integration._lightgbm_tuner.optimize import LightGBMTunerCV


if _imports.is_successful():
from optuna.integration._lightgbm_tuner.sklearn import LGBMClassifier # NOQA
from optuna.integration._lightgbm_tuner.sklearn import LGBMModel # NOQA
from optuna.integration._lightgbm_tuner.sklearn import LGBMRegressor # NOQA
from optuna.integration._lightgbm_tuner.sklearn import LGBMClassifier
from optuna.integration._lightgbm_tuner.sklearn import LGBMModel
from optuna.integration._lightgbm_tuner.sklearn import LGBMRegressor

__all__ = ["LightGBMTuner", "LightGBMTunerCV", "LGBMClassifier", "LGBMModel", "LGBMRegressor"]


def train(*args: Any, **kwargs: Any) -> Any:
Expand Down
9 changes: 6 additions & 3 deletions optuna/integration/allennlp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from optuna.integration.allennlp._dump_best_config import dump_best_config # NOQA
from optuna.integration.allennlp._executor import AllenNLPExecutor # NOQA
from optuna.integration.allennlp._pruner import AllenNLPPruningCallback # NOQA
from optuna.integration.allennlp._dump_best_config import dump_best_config
from optuna.integration.allennlp._executor import AllenNLPExecutor
from optuna.integration.allennlp._pruner import AllenNLPPruningCallback


__all__ = ["dump_best_config", "AllenNLPExecutor", "AllenNLPPruningCallback"]
2 changes: 1 addition & 1 deletion optuna/integration/allennlp/_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
):
_imports.check()

if version.parse(allennlp.__version__) < version.parse("2.0.0"):
if version.parse(allennlp.__version__) < version.parse("2.0.0"): # type: ignore
raise ImportError(
"`AllenNLPPruningCallback` requires AllenNLP>=v2.0.0."
"If you want to use a callback with an old version of AllenNLP, "
Expand Down
10 changes: 6 additions & 4 deletions optuna/integration/chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ def __init__(
)

@staticmethod
def _get_float_value(observation_value: Union[float, "chainer.Variable"]) -> float:
def _get_float_value(
observation_value: Union[float, "chainer.Variable"] # type: ignore
) -> float:

_imports.check()

try:
if isinstance(observation_value, chainer.Variable):
if isinstance(observation_value, chainer.Variable): # type: ignore
return float(observation_value.data) # type: ignore
else:
return float(observation_value)
Expand All @@ -79,11 +81,11 @@ def _get_float_value(observation_value: Union[float, "chainer.Variable"]) -> flo
"{} cannot be cast to float.".format(type(observation_value))
) from None

def _observation_exists(self, trainer: "chainer.training.Trainer") -> bool:
def _observation_exists(self, trainer: "chainer.training.Trainer") -> bool: # type: ignore

return self._pruner_trigger(trainer) and self._observation_key in trainer.observation

def __call__(self, trainer: "chainer.training.Trainer") -> None:
def __call__(self, trainer: "chainer.training.Trainer") -> None: # type: ignore

if not self._observation_exists(trainer):
return
Expand Down
12 changes: 7 additions & 5 deletions optuna/integration/lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@


with try_import() as _imports:
import lightgbm as lgb # NOQA
from lightgbm.callback import CallbackEnv # NOQA
import lightgbm as lgb
from lightgbm.callback import CallbackEnv

# Attach lightgbm API.
if _imports.is_successful():
# To pass tests/integration_tests/lightgbm_tuner_tests/test_optimize.py.
from lightgbm import Dataset # NOQA
from lightgbm import Dataset

from optuna.integration._lightgbm_tuner import LightGBMTuner # NOQA
from optuna.integration._lightgbm_tuner import LightGBMTunerCV # NOQA
from optuna.integration._lightgbm_tuner import LightGBMTuner
from optuna.integration._lightgbm_tuner import LightGBMTunerCV

_names_from_tuners = ["train", "LGBMModel", "LGBMClassifier", "LGBMRegressor"]

Expand All @@ -36,6 +36,8 @@
setattr(sys.modules[__name__], "LightGBMTuner", tuner.__dict__["LightGBMTuner"])
setattr(sys.modules[__name__], "LightGBMTunerCV", tuner.__dict__["LightGBMTunerCV"])

__all__ = ["Dataset", "LightGBMTuner", "LightGBMTunerCV"]


class LightGBMPruningCallback:
"""Callback for LightGBM to prune unpromising trials.
Expand Down
2 changes: 1 addition & 1 deletion optuna/integration/pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def on_init_start(self, trainer: Trainer) -> None:
trainer._accelerator_connector.distributed_backend is not None # type: ignore
)
if self.is_ddp_backend:
if version.parse(pl.__version__) < version.parse("1.5.0"):
if version.parse(pl.__version__) < version.parse("1.5.0"): # type: ignore
raise ValueError("PyTorch Lightning>=1.5.0 is required in DDP.")
if not (
isinstance(self._trial.study._storage, _CachedStorage)
Expand Down
24 changes: 17 additions & 7 deletions optuna/logging.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
import logging
from logging import CRITICAL # NOQA
from logging import DEBUG # NOQA
from logging import ERROR # NOQA
from logging import FATAL # NOQA
from logging import INFO # NOQA
from logging import WARN # NOQA
from logging import WARNING # NOQA
from logging import CRITICAL
from logging import DEBUG
from logging import ERROR
from logging import FATAL
from logging import INFO
from logging import WARN
from logging import WARNING
import threading
from typing import Optional

import colorlog


__all__ = [
"CRITICAL",
"DEBUG",
"ERROR",
"FATAL",
"INFO",
"WARN",
"WARNING",
]

_lock: threading.Lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None

Expand Down
18 changes: 13 additions & 5 deletions optuna/multi_objective/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from optuna._imports import _LazyImport
from optuna.multi_objective import samplers # NOQA
from optuna.multi_objective import study # NOQA
from optuna.multi_objective import trial # NOQA
from optuna.multi_objective.study import create_study # NOQA
from optuna.multi_objective.study import load_study # NOQA
from optuna.multi_objective import samplers
from optuna.multi_objective import study
from optuna.multi_objective import trial
from optuna.multi_objective.study import create_study
from optuna.multi_objective.study import load_study


visualization = _LazyImport("optuna.multi_objective.visualization")

__all__ = [
"samplers",
"study",
"trial",
"create_study",
"load_study",
]
19 changes: 14 additions & 5 deletions optuna/multi_objective/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from optuna.multi_objective.samplers._adapter import _MultiObjectiveSamplerAdapter # NOQA
from optuna.multi_objective.samplers._base import BaseMultiObjectiveSampler # NOQA
from optuna.multi_objective.samplers._motpe import MOTPEMultiObjectiveSampler # NOQA
from optuna.multi_objective.samplers._nsga2 import NSGAIIMultiObjectiveSampler # NOQA
from optuna.multi_objective.samplers._random import RandomMultiObjectiveSampler # NOQA
from optuna.multi_objective.samplers._adapter import _MultiObjectiveSamplerAdapter
from optuna.multi_objective.samplers._base import BaseMultiObjectiveSampler
from optuna.multi_objective.samplers._motpe import MOTPEMultiObjectiveSampler
from optuna.multi_objective.samplers._nsga2 import NSGAIIMultiObjectiveSampler
from optuna.multi_objective.samplers._random import RandomMultiObjectiveSampler


__all__ = [
"_MultiObjectiveSamplerAdapter",
"BaseMultiObjectiveSampler",
"MOTPEMultiObjectiveSampler",
"NSGAIIMultiObjectiveSampler",
"RandomMultiObjectiveSampler",
]
7 changes: 5 additions & 2 deletions optuna/multi_objective/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from optuna.multi_objective.visualization._pareto_front import plot_pareto_front # NOQA
from optuna.visualization import is_available # NOQA
from optuna.multi_objective.visualization._pareto_front import plot_pareto_front
from optuna.visualization import is_available


__all__ = ["plot_pareto_front", "is_available"]
2 changes: 1 addition & 1 deletion optuna/samplers/nsgaii/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import numpy as np

import optuna
from optuna._experimental import ExperimentalWarning
from optuna.distributions import BaseDistribution
from optuna.exceptions import ExperimentalWarning
from optuna.samplers._base import _CONSTRAINTS_KEY
from optuna.samplers._base import _process_constraints_after_trial
from optuna.samplers._base import BaseSampler
Expand Down
2 changes: 2 additions & 0 deletions optuna/study/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
if not _imports.is_successful():
pd = object # NOQA

__all__ = ["pd"]


def _create_records_and_aggregate_column(
study: "optuna.Study", attrs: Tuple[str, ...]
Expand Down
14 changes: 8 additions & 6 deletions optuna/visualization/_plotly_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from optuna._imports import try_import


with try_import() as _imports: # NOQA
import plotly # NOQA
with try_import() as _imports:
import plotly
from plotly import __version__ as plotly_version
import plotly.graph_objs as go # NOQA
from plotly.graph_objs import Contour # NOQA
from plotly.graph_objs import Scatter # NOQA
from plotly.subplots import make_subplots # NOQA
import plotly.graph_objs as go
from plotly.graph_objs import Contour
from plotly.graph_objs import Scatter
from plotly.subplots import make_subplots

if version.parse(plotly_version) < version.parse("4.0.0"):
raise ImportError(
Expand All @@ -19,3 +19,5 @@
"For further information, please refer to the installation guide of plotly. ",
name="plotly",
)

__all__ = ["_imports", "plotly", "go", "Contour", "Scatter", "make_subplots"]
Loading

0 comments on commit ff9f5b6

Please sign in to comment.