Skip to content

Commit

Permalink
[Retiarii] end2end (microsoft#3122)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored Dec 11, 2020
1 parent 7d1acfb commit d165905
Show file tree
Hide file tree
Showing 54 changed files with 1,391 additions and 610 deletions.
26 changes: 26 additions & 0 deletions examples/trials/mnist-tfv2/launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# FIXME: For demonstration only. It should not be here

from pathlib import Path

from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner

tuner = HyperoptTuner('tpe')

search_space = {
"dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] },
"conv_size": { "_type": "choice", "_value": [2, 3, 5, 7] },
"hidden_size": { "_type": "choice", "_value": [124, 512, 1024] },
"batch_size": { "_type": "choice", "_value": [16, 32] },
"learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] }
}

experiment = Experiment(tuner, 'local')
experiment.config.experiment_name = 'test'
experiment.config.trial_concurrency = 2
experiment.config.max_trial_number = 5
experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py'
experiment.config.trial_code_directory = Path(__file__).parent

experiment.run(8081, debug=True)
3 changes: 3 additions & 0 deletions nni/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

__version__ = '999.0.0-developing'

from .runtime.log import init_logger
init_logger()

from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator

Expand Down
3 changes: 3 additions & 0 deletions nni/experiment/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .config import *
from .experiment import Experiment, RetiariiExperiment

from .nni_client import *
3 changes: 3 additions & 0 deletions nni/experiment/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import ExperimentConfig, RetiariiExpConfig

from .local import LocalExperimentConfig
115 changes: 115 additions & 0 deletions nni/experiment/config/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import dataclasses
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union


@dataclasses.dataclass(init=False)
class ExperimentConfig:
experiment_name: str
search_space: Any
max_execution_seconds: Optional[int] = None
max_trial_number: Optional[int] = None
trial_concurrency: int
trial_command: str
trial_code_directory: Union[Path, str]
trial_gpu_number: int = 0
extra_config: Optional[Dict[str, str]] = None

_training_service: str


# these values will be used to create template object,
# and the user should overwrite them later.
_placeholder = {
'experiment_name': '_unset_',
'search_space': '_unset_',
'trial_concurrency': -1,
'trial_command': '_unset_',
'trial_code_directory': '_unset_'
}

# simple validation functions
# complex validation logic with special error message should go to `validate()` method instead
_value_range = {
'max_execution_seconds': lambda x: x is None or x > 0,
'max_trial_number': lambda x: x is None or x > 0,
'trial_concurrency': lambda x: x > 0,
'trial_gpu_number': lambda x: x >= 0
}


def __init__(self, **kwargs):
for field in dataclasses.fields(self):
if field.name in kwargs:
setattr(self, field.name, kwargs[field.name])
elif field.default != dataclasses.MISSING:
setattr(self, field.name, field.default)
else:
setattr(self, field.name, type(self)._placeholder[field.name])


def validate(self) -> None:
# check existence
for key, placeholder_value in type(self)._placeholder.items():
if getattr(self, key) == placeholder_value:
raise ValueError(f'Field "{key}" is not set')

# TODO: check type

# check value
for key, condition in type(self)._value_range.items():
value = getattr(self, key)
if not condition(value):
raise ValueError(f'Field "{key}" ({repr(value)}) out of range')

# check special fields
if not Path(self.trial_code_directory).is_dir():
raise ValueError(f'Trial code directory "{self.trial_code_directory}" does not exist or is not directory')


def experiment_config_json(self) -> Dict[str, Any]:
# this only contains the common part for most (if not all) training services
# subclasses should override it to provide exclusive fields
return {
'authorName': '_',
'experimentName': self.experiment_name,
'trialConcurrency': self.trial_concurrency,
'maxExecDuration': self.max_execution_seconds or (999 * 24 * 3600),
'maxTrialNum': self.max_trial_number or 99999,
'searchSpace': json.dumps(self.search_space),
'trainingServicePlatform': self._training_service,
'tuner': {'builtinTunerName': '_user_created_'},
**(self.extra_config or {})
}

def cluster_metadata_json(self) -> Any:
# the cluster metadata format is a total mess
# leave it to each subclass before we refactoring nni manager
raise NotImplementedError()


@staticmethod
def create_template(training_service: str) -> 'ExperimentConfig':
for cls in ExperimentConfig.__subclasses__():
for field in dataclasses.fields(cls):
if field.name == '_training_service' and field.default == training_service:
return cls()
raise ValueError(f'Unrecognized training service {training_service}')


class RetiariiExpConfig(ExperimentConfig):
@staticmethod
def create_template(training_service: str) -> 'ExperimentConfig':
for cls in ExperimentConfig.__subclasses__():
for field in dataclasses.fields(cls):
if field.name == '_training_service' and field.default == training_service:
config_obj = cls()
config_obj.search_space = {}
config_obj.trial_command = 'python3 -m nni.retiarii.trial_entry'
# FIXME: expose this field to users
config_obj.trial_code_directory = '../..'
return config_obj
40 changes: 40 additions & 0 deletions nni/experiment/config/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict

from .base import ExperimentConfig


@dataclass(init=False)
class LocalExperimentConfig(ExperimentConfig):
use_active_gpu: bool = False

_training_service: str = 'local'

def experiment_config_json(self) -> Dict[str, Any]:
ret = super().experiment_config_json()
ret['clusterMetaData'] = [
{
'key': 'codeDir',
'value': str(Path(self.trial_code_directory).resolve())
},
{
'key': 'command',
'value': self.trial_command
}
]
#ret['local_config'] = {
# 'useActiveGpu': self.use_active_gpu
#}
return ret

def cluster_metadata_json(self) -> Any:
return {
'trial_config': {
'command': self.trial_command,
'codeDir': str(Path(self.trial_code_directory).resolve())
}
}
Loading

0 comments on commit d165905

Please sign in to comment.