Skip to content

Commit

Permalink
BaseTuner in place of SampleClassMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
ches-001 committed May 29, 2023
1 parent b0f1d74 commit 9539427
Show file tree
Hide file tree
Showing 19 changed files with 386 additions and 408 deletions.
3 changes: 2 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .tune_classifier import *
from .tune_regressor import *
from .baseline import *
from .sample import *
from typing import Iterable


__all__: Iterable[str] = [
"baseline.mixin",
"base",
"tests.test_tuners",
"tests.utils",
"ensemble_classifier",
Expand Down
5 changes: 3 additions & 2 deletions baseline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .mixin import *
from .base import *
from typing import Iterable


__all__: Iterable[str] = [
"SampleClassMixin"
"SampleClassMixin",
"BaseTuner"
]
18 changes: 12 additions & 6 deletions baseline/mixin.py → baseline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple, Iterable, Callable

@dataclass

class SampleClassMixin:

def _is_space_type(self, space: Iterable, type: Callable) -> bool:
Expand All @@ -19,10 +19,7 @@ def is_valid_float_space(self, space: Iterable) -> bool:
def is_valid_categorical_space(self, space: Iterable) -> bool:
    """Return True when *space* is neither a float space nor an int space.

    BUG fix: the original tested ``is_valid_float_space`` twice, so the
    int-space check was never applied and any all-int space was reported
    as categorical incorrectly relative to the apparent intent.
    """
    # _is_space_type(space, int) mirrors the float check; presumably a
    # dedicated is_valid_int_space helper exists — TODO confirm and prefer it.
    return (not self.is_valid_float_space(space)) and (not self._is_space_type(space, int))

def _sample_params(self, trial: Optional[Trial]=None) -> Dict[str, Any]:
if trial is None: raise ValueError("Method should be called in an optuna trial study")

def model(self, trial: Optional[Trial]) -> Any:
def _in_trial(self, trial: Optional[Trial] = None) -> None:
    """Guard helper: ensure the caller is running inside an optuna trial.

    Args:
        trial: the active optuna ``Trial``; ``None`` means the method was
            called outside a study.

    Raises:
        ValueError: if *trial* is None.

    Note: the original annotated the return type as ``Dict[str, Any]``,
    but the method only validates and returns nothing.
    """
    if trial is None:
        raise ValueError("Method should be called in an optuna trial study")

def _evaluate_params(self, model_class: Callable, params: Dict[str, Any]):
Expand Down Expand Up @@ -78,5 +75,14 @@ def _evaluate_sampled_model(
return model



@dataclass
class BaseTuner(SampleClassMixin):
    """Base class for model tuners.

    Concrete tuners override ``sample_params`` / ``sample_model``; the base
    implementations only validate that an optuna trial is active (via
    ``SampleClassMixin._in_trial``) and return nothing.
    """

    # Populated by concrete tuners with the sampled estimator instance.
    model: Any = None

    def sample_params(self, trial: Trial) -> Dict[str, Any]:
        """Validate the trial context; subclasses extend this to return sampled params."""
        super()._in_trial(trial)

    def sample_model(self, trial: Trial) -> Any:
        """Validate the trial context; subclasses extend this to return a sampled model."""
        # Removed the dead `pass` statement that preceded this call.
        super()._in_trial(trial)
26 changes: 21 additions & 5 deletions sample.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from baseline import SampleClassMixin
from optuna.trial import Trial
from tune_regressor import regressor_tuning_entities
from tune_classifier import classifier_tuning_entities
import inspect
from baseline import BaseTuner
from optuna.trial import Trial, FrozenTrial
from tune_regressor import regressor_tuning_entities, regressor_tuner_model_class_dict
from tune_classifier import classifier_tuning_entities, classifier_tuner_model_class_dict
from typing import Iterable, Dict, Optional


Expand All @@ -22,7 +23,22 @@ def sample_models_with_params(
else:
search_space: Dict[str, object] = classifier_tuning_entities

tuner_obj: SampleClassMixin = trial.suggest_categorical("model_tuner", list(search_space.values()))
tuner_obj: BaseTuner = trial.suggest_categorical("model_tuner", list(search_space.values()))
model = tuner_obj.sample_model(trial)

return model


def make_sampled_model(best_trial: FrozenTrial, **kwargs):
    """Reconstruct the model selected by *best_trial* with its sampled hyperparameters.

    Args:
        best_trial: the frozen best trial of a completed optuna study; its
            ``params`` must contain the ``"model_tuner"`` entry plus the
            tuner-prefixed hyperparameter values.
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        An instance of the model class mapped to the winning tuner,
        constructed from the de-prefixed sampled parameters.
    """
    model_tuner = best_trial.params["model_tuner"]
    tuner_model_class_dict = {**regressor_tuner_model_class_dict, **classifier_tuner_model_class_dict}
    model_class = tuner_model_class_dict[model_tuner.__class__.__name__]

    # inspect.signature(model_class.__init__) resolves an inherited __init__
    # as well; the original model_class.__dict__["__init__"] raised KeyError
    # for classes that do not define __init__ themselves.
    model_params_names = set(inspect.signature(model_class.__init__).parameters.keys())

    # Strip the tuner-name prefix once, and only when it is a true prefix —
    # str.replace removed the substring anywhere in the key.
    prefix = f"{model_tuner.__class__.__name__}_"
    best_params_dict = {}
    for key, value in best_trial.params.items():
        param_name = key[len(prefix):] if key.startswith(prefix) else key
        if param_name in model_params_names:
            best_params_dict[param_name] = value

    return model_class(**best_params_dict)
Loading

0 comments on commit 9539427

Please sign in to comment.