forked from microsoft/nni
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7d1acfb
commit d165905
Showing
54 changed files
with
1,391 additions
and
610 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# FIXME: For demonstration only. It should not be here | ||
|
||
from pathlib import Path | ||
|
||
from nni.experiment import Experiment | ||
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner | ||
|
||
tuner = HyperoptTuner('tpe') | ||
|
||
search_space = { | ||
"dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] }, | ||
"conv_size": { "_type": "choice", "_value": [2, 3, 5, 7] }, | ||
"hidden_size": { "_type": "choice", "_value": [124, 512, 1024] }, | ||
"batch_size": { "_type": "choice", "_value": [16, 32] }, | ||
"learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] } | ||
} | ||
|
||
experiment = Experiment(tuner, 'local') | ||
experiment.config.experiment_name = 'test' | ||
experiment.config.trial_concurrency = 2 | ||
experiment.config.max_trial_number = 5 | ||
experiment.config.search_space = search_space | ||
experiment.config.trial_command = 'python3 mnist.py' | ||
experiment.config.trial_code_directory = Path(__file__).parent | ||
|
||
experiment.run(8081, debug=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from .config import * | ||
from .experiment import Experiment, RetiariiExperiment | ||
|
||
from .nni_client import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base import ExperimentConfig, RetiariiExpConfig | ||
|
||
from .local import LocalExperimentConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import dataclasses | ||
import json | ||
from pathlib import Path | ||
from typing import Any, Dict, Optional, Union | ||
|
||
|
||
@dataclasses.dataclass(init=False) | ||
class ExperimentConfig: | ||
experiment_name: str | ||
search_space: Any | ||
max_execution_seconds: Optional[int] = None | ||
max_trial_number: Optional[int] = None | ||
trial_concurrency: int | ||
trial_command: str | ||
trial_code_directory: Union[Path, str] | ||
trial_gpu_number: int = 0 | ||
extra_config: Optional[Dict[str, str]] = None | ||
|
||
_training_service: str | ||
|
||
|
||
# these values will be used to create template object, | ||
# and the user should overwrite them later. | ||
_placeholder = { | ||
'experiment_name': '_unset_', | ||
'search_space': '_unset_', | ||
'trial_concurrency': -1, | ||
'trial_command': '_unset_', | ||
'trial_code_directory': '_unset_' | ||
} | ||
|
||
# simple validation functions | ||
# complex validation logic with special error message should go to `validate()` method instead | ||
_value_range = { | ||
'max_execution_seconds': lambda x: x is None or x > 0, | ||
'max_trial_number': lambda x: x is None or x > 0, | ||
'trial_concurrency': lambda x: x > 0, | ||
'trial_gpu_number': lambda x: x >= 0 | ||
} | ||
|
||
|
||
def __init__(self, **kwargs): | ||
for field in dataclasses.fields(self): | ||
if field.name in kwargs: | ||
setattr(self, field.name, kwargs[field.name]) | ||
elif field.default != dataclasses.MISSING: | ||
setattr(self, field.name, field.default) | ||
else: | ||
setattr(self, field.name, type(self)._placeholder[field.name]) | ||
|
||
|
||
def validate(self) -> None: | ||
# check existence | ||
for key, placeholder_value in type(self)._placeholder.items(): | ||
if getattr(self, key) == placeholder_value: | ||
raise ValueError(f'Field "{key}" is not set') | ||
|
||
# TODO: check type | ||
|
||
# check value | ||
for key, condition in type(self)._value_range.items(): | ||
value = getattr(self, key) | ||
if not condition(value): | ||
raise ValueError(f'Field "{key}" ({repr(value)}) out of range') | ||
|
||
# check special fields | ||
if not Path(self.trial_code_directory).is_dir(): | ||
raise ValueError(f'Trial code directory "{self.trial_code_directory}" does not exist or is not directory') | ||
|
||
|
||
def experiment_config_json(self) -> Dict[str, Any]: | ||
# this only contains the common part for most (if not all) training services | ||
# subclasses should override it to provide exclusive fields | ||
return { | ||
'authorName': '_', | ||
'experimentName': self.experiment_name, | ||
'trialConcurrency': self.trial_concurrency, | ||
'maxExecDuration': self.max_execution_seconds or (999 * 24 * 3600), | ||
'maxTrialNum': self.max_trial_number or 99999, | ||
'searchSpace': json.dumps(self.search_space), | ||
'trainingServicePlatform': self._training_service, | ||
'tuner': {'builtinTunerName': '_user_created_'}, | ||
**(self.extra_config or {}) | ||
} | ||
|
||
def cluster_metadata_json(self) -> Any: | ||
# the cluster metadata format is a total mess | ||
# leave it to each subclass before we refactoring nni manager | ||
raise NotImplementedError() | ||
|
||
|
||
@staticmethod | ||
def create_template(training_service: str) -> 'ExperimentConfig': | ||
for cls in ExperimentConfig.__subclasses__(): | ||
for field in dataclasses.fields(cls): | ||
if field.name == '_training_service' and field.default == training_service: | ||
return cls() | ||
raise ValueError(f'Unrecognized training service {training_service}') | ||
|
||
|
||
class RetiariiExpConfig(ExperimentConfig): | ||
@staticmethod | ||
def create_template(training_service: str) -> 'ExperimentConfig': | ||
for cls in ExperimentConfig.__subclasses__(): | ||
for field in dataclasses.fields(cls): | ||
if field.name == '_training_service' and field.default == training_service: | ||
config_obj = cls() | ||
config_obj.search_space = {} | ||
config_obj.trial_command = 'python3 -m nni.retiarii.trial_entry' | ||
# FIXME: expose this field to users | ||
config_obj.trial_code_directory = '../..' | ||
return config_obj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Any, Dict | ||
|
||
from .base import ExperimentConfig | ||
|
||
|
||
@dataclass(init=False) | ||
class LocalExperimentConfig(ExperimentConfig): | ||
use_active_gpu: bool = False | ||
|
||
_training_service: str = 'local' | ||
|
||
def experiment_config_json(self) -> Dict[str, Any]: | ||
ret = super().experiment_config_json() | ||
ret['clusterMetaData'] = [ | ||
{ | ||
'key': 'codeDir', | ||
'value': str(Path(self.trial_code_directory).resolve()) | ||
}, | ||
{ | ||
'key': 'command', | ||
'value': self.trial_command | ||
} | ||
] | ||
#ret['local_config'] = { | ||
# 'useActiveGpu': self.use_active_gpu | ||
#} | ||
return ret | ||
|
||
def cluster_metadata_json(self) -> Any: | ||
return { | ||
'trial_config': { | ||
'command': self.trial_command, | ||
'codeDir': str(Path(self.trial_code_directory).resolve()) | ||
} | ||
} |
Oops, something went wrong.