Skip to content

Commit

Permalink
Fix __init__.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
y0z committed Feb 14, 2024
1 parent cd00cc7 commit 596cb9b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 37 deletions.
70 changes: 36 additions & 34 deletions optuna/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,41 @@
}


__all__ = [
"AllenNLPExecutor",
"AllenNLPPruningCallback",
"BoTorchSampler",
"CatalystPruningCallback",
"CatBoostPruningCallback",
"ChainerPruningExtension",
"ChainerMNStudy",
"CmaEsSampler",
"PyCmaSampler",
"DaskStorage",
"MLflowCallback",
"WeightsAndBiasesCallback",
"KerasPruningCallback",
"LightGBMPruningCallback",
"LightGBMTuner",
"LightGBMTunerCV",
"TorchDistributedTrial",
"PyTorchIgnitePruningHandler",
"PyTorchLightningPruningCallback",
"OptunaSearchCV",
"ShapleyImportanceEvaluator",
"SkorchPruningCallback",
"MXNetPruningCallback",
"SkoptSampler",
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
"XGBoostPruningCallback",
"FastAIV1PruningCallback",
"FastAIV2PruningCallback",
"FastAIPruningCallback",
]


if TYPE_CHECKING:
from optuna.integration.allennlp import AllenNLPExecutor
from optuna.integration.allennlp import AllenNLPPruningCallback
Expand Down Expand Up @@ -77,6 +112,7 @@ class _IntegrationModule(ModuleType):
imports all submodules and their dependencies (e.g., chainer, keras, lightgbm) all at once.
"""

__all__ = __all__
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]

Expand Down Expand Up @@ -113,37 +149,3 @@ def _get_module(self, module_name: str) -> ModuleType:
)

sys.modules[__name__] = _IntegrationModule(__name__)

__all__ = [
"AllenNLPExecutor",
"AllenNLPPruningCallback",
"BoTorchSampler",
"CatalystPruningCallback",
"CatBoostPruningCallback",
"ChainerPruningExtension",
"ChainerMNStudy",
"CmaEsSampler",
"PyCmaSampler",
"DaskStorage",
"MLflowCallback",
"WeightsAndBiasesCallback",
"KerasPruningCallback",
"LightGBMPruningCallback",
"LightGBMTuner",
"LightGBMTunerCV",
"TorchDistributedTrial",
"PyTorchIgnitePruningHandler",
"PyTorchLightningPruningCallback",
"OptunaSearchCV",
"ShapleyImportanceEvaluator",
"SkorchPruningCallback",
"MXNetPruningCallback",
"SkoptSampler",
"TensorBoardCallback",
"TensorFlowPruningHook",
"TFKerasPruningCallback",
"XGBoostPruningCallback",
"FastAIV1PruningCallback",
"FastAIV2PruningCallback",
"FastAIPruningCallback",
]
23 changes: 20 additions & 3 deletions optuna/integration/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
from optuna_integration.lightgbm import LightGBMPruningCallback
from optuna_integration.lightgbm import LightGBMTuner
from optuna_integration.lightgbm import LightGBMTunerCV
import os
import sys
from types import ModuleType
from typing import Any

import optuna_integration.lightgbm as lgb


__all__ = [
"LightGBMPruningCallback",
"LightGBMTuner",
"LightGBMTunerCV",
]


class _LightGBMModule(ModuleType):
"""Module class that implements `optuna.integration.lightgbm` package."""

__all__ = __all__
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]

def __getattr__(self, name: str) -> Any:
return lgb.__dict__[name]


sys.modules[__name__] = _LightGBMModule(__name__)

0 comments on commit 596cb9b

Please sign in to comment.