forked from microsoft/nni
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create experiment from Python code (microsoft#3111)
- Loading branch information
Showing
34 changed files
with
1,283 additions
and
147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# FIXME: For demonstration only. It should not be here
#
# Minimal script that builds an NNI experiment from Python code:
# it creates a tuner, declares a search space, fills in the experiment
# config and finally starts the experiment.

from pathlib import Path

from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner

# The tuner object is created here and handed to the experiment below.
tuner = HyperoptTuner('tpe')

# Each entry maps a hyper-parameter name to a `_type` / `_value` description.
search_space = {
    "dropout_rate": {"_type": "uniform", "_value": [0.5, 0.9]},
    "conv_size": {"_type": "choice", "_value": [2, 3, 5, 7]},
    "hidden_size": {"_type": "choice", "_value": [124, 512, 1024]},
    "batch_size": {"_type": "choice", "_value": [16, 32]},
    "learning_rate": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]},
}

experiment = Experiment(tuner, 'local')
experiment.config.experiment_name = 'test'
experiment.config.trial_concurrency = 2
experiment.config.max_trial_number = 5
experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py'
# The trial code is expected to live next to this script.
experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.training_service.use_active_gpu = True

# NOTE(review): 8081 is presumably the port to serve on -- confirm against Experiment.run
experiment.run(8081)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from .config import * | ||
from .experiment import Experiment | ||
|
||
from .nni_client import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from .common import * | ||
from .local import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import copy | ||
import dataclasses | ||
from pathlib import Path | ||
from typing import Any, Dict, Optional, Type, TypeVar | ||
|
||
from ruamel import yaml | ||
|
||
from . import util | ||
|
||
# Public names of this module.
__all__ = ['ConfigBase', 'PathLike']

# Type variable bound to ConfigBase, used to annotate methods that return
# an instance of the calling (sub)class.
T = TypeVar('T', bound='ConfigBase')

# Alias of `util.PathLike`, exposed from this module as well.
PathLike = util.PathLike
|
||
def _is_missing(obj: Any) -> bool:
    """
    Return True if `obj` is the `dataclasses.MISSING` sentinel (or a copy of it).

    The check is done by type rather than by identity, so it still matches
    after the sentinel went through `copy.deepcopy()`, which produces a new
    object of the same type.
    """
    return isinstance(obj, type(dataclasses.MISSING))
|
||
class ConfigBase:
    """
    Base class of config classes.
    Subclass may override `_canonical_rules` and `_validation_rules`,
    and `validate()` if the logic is complex.
    """

    # Rules to convert field value to canonical format.
    # The key is field name.
    # The value is callable `value -> canonical_value`
    # It is not type-hinted so dataclass won't treat it as field
    _canonical_rules = {}  # type: ignore

    # Rules to validate field value.
    # The key is field name.
    # The value is callable `value -> valid` or `value -> (valid, error_message)`
    # The rule will be called with canonical format and is only called when `value` is not None.
    # `error_message` is used when `valid` is False.
    # It will be prepended with class name and field name in exception message.
    _validation_rules = {}  # type: ignore

    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
        """
        Initialize a config object and set some fields.
        Name of keyword arguments can either be snake_case or camelCase.
        They will be converted to snake_case automatically.
        If a field is missing and doesn't have a default value, it will be set to `dataclasses.MISSING`.

        `_base_path` is the directory that relative path fields are resolved against;
        it defaults to `Path()`.

        Raises ValueError if a keyword argument does not match any field.
        """
        kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
        if _base_path is None:
            _base_path = Path()
        for field in dataclasses.fields(self):
            value = kwargs.pop(util.case_insensitive(field.name), field.default)
            if value is not None and not _is_missing(value):
                # relative paths loaded from config file are not relative to pwd
                if 'Path' in str(field.type):
                    value = Path(value).expanduser()
                    if not value.is_absolute():
                        value = _base_path / value
                # convert nested dict to config type
                if isinstance(value, dict):
                    nested_type = util.strip_optional(field.type)
                    if isinstance(nested_type, type) and issubclass(nested_type, ConfigBase):
                        value = nested_type(**value, _base_path=_base_path)
            # always assign, so that None and MISSING are stored as well
            setattr(self, field.name, value)
        # whatever was not consumed by a field is an unknown key
        if kwargs:
            class_name = type(self).__name__
            unknown = ', '.join(kwargs.keys())
            raise ValueError(f'{class_name}: Unrecognized fields {unknown}')

    @classmethod
    def load(cls: Type[T], path: PathLike) -> T:
        """
        Load config from YAML (or JSON) file.
        Keys in YAML file can either be camelCase or snake_case.

        Relative paths inside the file are resolved against the file's directory.
        Raises ValueError if the top-level content of the file is not a dict.
        """
        # use a context manager so the file handle is closed deterministically
        with open(path) as config_file:
            data = yaml.safe_load(config_file)
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        return cls(**data, _base_path=Path(path).parent)

    def json(self) -> Dict[str, Any]:
        """
        Convert config to JSON object.
        The keys of returned object will be camelCase.
        Fields whose value is None are omitted.
        """
        return dataclasses.asdict(
            self.canonical(),
            dict_factory=lambda items: {util.camel_case(k): v for k, v in items if v is not None}
        )

    def canonical(self: T) -> T:
        """
        Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.
        Noticeably, relative path may be converted to absolute path.
        """
        ret = copy.deepcopy(self)
        for field in dataclasses.fields(ret):
            key, value = field.name, getattr(ret, field.name)
            rule = ret._canonical_rules.get(key)
            if rule is not None:
                setattr(ret, key, rule(value))
            elif isinstance(value, ConfigBase):
                # value will be copied twice, should not be a performance issue anyway
                setattr(ret, key, value.canonical())
        return ret

    def validate(self) -> None:
        """
        Validate the config object and raise Exception if it's ill-formed.

        The checks run on the canonical form of the config. For every field:
        it must be set, it may only be None when its type is optional,
        its validation rule (if any) must pass, and nested configs are
        validated recursively. Raises ValueError on the first failure.
        """
        class_name = type(self).__name__
        config = self.canonical()

        for field in dataclasses.fields(config):
            key, value = field.name, getattr(config, field.name)

            # check existence
            if _is_missing(value):
                raise ValueError(f'{class_name}: {key} is not set')

            # check type (TODO)
            type_name = str(field.type).replace('typing.', '')
            optional = any([
                type_name.startswith('Optional['),
                type_name.startswith('Union[') and 'NoneType' in type_name,
                type_name == 'Any'
            ])
            if value is None:
                if optional:
                    continue
                raise ValueError(f'{class_name}: {key} cannot be None')

            # check value
            rule = config._validation_rules.get(key)
            if rule is not None:
                try:
                    result = rule(value)
                except Exception as err:
                    # a rule that crashes means the value has an unusable shape/type
                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}') from err

                # a rule returns either a plain bool or a (valid, error_message) pair
                if isinstance(result, bool):
                    if not result:
                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
                else:
                    if not result[0]:
                        raise ValueError(f'{class_name}: {key} {result[1]}')

            # check nested config
            if isinstance(value, ConfigBase):
                value.validate()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from .base import ConfigBase, PathLike | ||
from . import util | ||
|
||
# Public names of this module.
__all__ = [
    'ExperimentConfig',
    'AlgorithmConfig',
    'CustomAlgorithmConfig',
    'TrainingServiceConfig',
]
|
||
|
||
@dataclass(init=False)
class _AlgorithmConfig(ConfigBase):
    """
    Common fields of algorithm configs.

    All fields are optional here; subclasses narrow the ones they require.
    `init=False` keeps the keyword-based `__init__` of `ConfigBase`.
    """

    name: Optional[str] = None
    class_name: Optional[str] = None
    code_directory: Optional[PathLike] = None
    class_args: Optional[Dict[str, Any]] = None

    def validate(self):
        """
        Run the generic field validation, then the algorithm specific checks.
        """
        super().validate()
        _validate_algo(self)
|
||
|
||
@dataclass(init=False)
class AlgorithmConfig(_AlgorithmConfig):
    """
    Algorithm config identified by `name`, which is mandatory here.
    """

    name: str
    class_args: Optional[Dict[str, Any]] = None
|
||
|
||
@dataclass(init=False)
class CustomAlgorithmConfig(_AlgorithmConfig):
    """
    Algorithm config identified by `class_name`, which is mandatory here.
    """

    class_name: str
    # NOTE(review): the base class and `_validate_algo` use `code_directory`;
    # `class_directory` is an additional field that nothing in this file reads -- confirm which name is intended.
    class_directory: Optional[PathLike] = None
    class_args: Optional[Dict[str, Any]] = None
|
||
|
||
class TrainingServiceConfig(ConfigBase):
    """
    Base class of training service configs.

    The `training_service` validation rule of `ExperimentConfig` rejects values
    whose exact type is this class ("cannot be abstract base class").
    """

    # NOTE(review): unlike the other config classes in this file, this class is
    # not decorated with `@dataclass(init=False)` -- confirm this is intentional.
    platform: str
|
||
|
||
@dataclass(init=False)
class ExperimentConfig(ConfigBase):
    """
    Config of an experiment.

    Fields without a default value stay `dataclasses.MISSING` until they are
    assigned; `validate()` reports them as not set.
    """

    experiment_name: Optional[str] = None
    search_space_file: Optional[PathLike] = None
    search_space: Any = None
    trial_command: str
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
    trial_gpu_number: int = 0
    max_experiment_duration: Optional[str] = None
    max_trial_number: Optional[int] = None
    nni_manager_ip: Optional[str] = None
    use_annotation: bool = False
    debug: bool = False
    log_level: Optional[str] = None
    experiment_working_directory: Optional[PathLike] = None
    tuner_gpu_indices: Optional[Union[List[int], str]] = None
    tuner: Optional[_AlgorithmConfig] = None
    # NOTE(review): possibly meant `assessor`; `_validate_for_exp` reads
    # `config.accessor`, so confirm before renaming.
    accessor: Optional[_AlgorithmConfig] = None
    advisor: Optional[_AlgorithmConfig] = None
    training_service: TrainingServiceConfig

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        """
        Initialize the config from keyword arguments (see `ConfigBase.__init__`).

        If `training_service_platform` is given, `training_service` is created by
        `util.training_service_config_factory()` for that platform, and
        `training_service` must not be passed as well.

        Raises ValueError if both `training_service_platform` and
        `training_service` are given.
        """
        super().__init__(**kwargs)
        if training_service_platform is not None:
            # explicit raise instead of assert, so the check survives `python -O`
            if 'training_service' in kwargs:
                raise ValueError(
                    'ExperimentConfig: training_service_platform and training_service cannot be set together'
                )
            self.training_service = util.training_service_config_factory(training_service_platform)

    def validate(self, initialized_tuner: bool = False) -> None:
        """
        Validate the config and raise ValueError if it is ill-formed.

        `initialized_tuner` selects the extra checks: True applies
        `_validate_for_exp`, False (the default) applies `_validate_for_nnictl`.
        """
        super().validate()
        if initialized_tuner:
            _validate_for_exp(self)
        else:
            _validate_for_nnictl(self)

    ## End of public API ##

    # The rule tables are module-level dicts; they are exposed through
    # properties so they are not treated as dataclass fields.
    @property
    def _canonical_rules(self):
        return _canonical_rules

    @property
    def _validation_rules(self):
        return _validation_rules
|
||
|
||
# Canonicalization rules of ExperimentConfig: field name -> `value -> canonical_value`.
_canonical_rules = {
    'search_space_file': util.canonical_path,
    'trial_code_directory': util.canonical_path,
    # re-rendered as '<result of util.parse_time>s'; None is passed through
    'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
    'experiment_working_directory': util.canonical_path,
    # a comma separated string becomes a list of int; other values are kept as is
    'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
}
|
||
# Validation rules of ExperimentConfig: field name -> `value -> valid` or
# `value -> (valid, error_message)`.
_validation_rules = {
    'search_space_file': lambda value: (Path(value).is_file(), f'"{value}" does not exist or is not regular file'),
    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
    'trial_concurrency': lambda value: value > 0,
    'trial_gpu_number': lambda value: value >= 0,
    'max_experiment_duration': lambda value: util.parse_time(value) > 0,
    'max_trial_number': lambda value: value > 0,
    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
    # indices must be non-negative and must not repeat
    'tuner_gpu_indices': lambda value: all(i >= 0 for i in value) and len(value) == len(set(value)),
    # exact type comparison on purpose: subclasses are accepted, the base class itself is not
    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
|
||
def _validate_for_exp(config: ExperimentConfig) -> None:
    """
    Extra checks for a config used by nni.Experiment, where the tuner is
    already initialized outside of the config.

    Raises ValueError if annotation is enabled, if the number of search space
    sources reported by `util.count()` is not exactly one, or if any of
    tuner / accessor / advisor / tuner_gpu_indices is set.
    """
    # validate experiment for nni.Experiment, where tuner is already initialized outside
    if config.use_annotation:
        raise ValueError('ExperimentConfig: annotation is not supported in this mode')
    if util.count(config.search_space, config.search_space_file) != 1:
        raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
    if util.count(config.tuner, config.accessor, config.advisor) != 0:
        raise ValueError('ExperimentConfig: tuner, accessor, and advisor must not be set in this mode')
    if config.tuner_gpu_indices is not None:
        raise ValueError('ExperimentConfig: tuner_gpu_indices is not supported in this mode')
|
||
def _validate_for_nnictl(config: ExperimentConfig) -> None:
    """
    Extra checks for a config used by the normal launching approach.

    With annotation, neither search_space nor search_space_file may be set;
    without it, `util.count()` of the two must be exactly one. In both cases
    `util.count()` of tuner and advisor must be exactly one.
    Raises ValueError otherwise.
    """
    # validate experiment for normal launching approach
    if config.use_annotation:
        if util.count(config.search_space, config.search_space_file) != 0:
            raise ValueError('ExperimentConfig: search_space and search_space_file must not be set with annotation')
    else:
        if util.count(config.search_space, config.search_space_file) != 1:
            raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
    if util.count(config.tuner, config.advisor) != 1:
        raise ValueError('ExperimentConfig: tuner and advisor must be set one')
|
||
def _validate_algo(algo: '_AlgorithmConfig') -> None:
    """
    Check that an algorithm config identifies its algorithm consistently.

    Without `name`, `class_name` is required and `code_directory` (if set)
    must be an existing directory. With `name`, neither `class_name` nor
    `code_directory` may be set.

    Raises ValueError if any of these conditions is violated.
    """
    if algo.name is None:
        if algo.class_name is None:
            raise ValueError('Missing algorithm name')
        if algo.code_directory is not None and not Path(algo.code_directory).is_dir():
            raise ValueError(f'code_directory "{algo.code_directory}" does not exist or is not directory')
    else:
        if algo.class_name is not None or algo.code_directory is not None:
            raise ValueError('When name is set for registered algorithm, class_name and code_directory cannot be used')
    # TODO: verify algorithm installation and class args
Oops, something went wrong.