Skip to content

Commit

Permalink
Allow postponed annotations for config classes (microsoft#4883)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhe-lz authored Jun 29, 2022
1 parent 3d6ddb9 commit 00e4deb
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 41 deletions.
10 changes: 5 additions & 5 deletions nni/experiment/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, **kwargs):
"""
self._base_path = utils.get_base_path()
args = {utils.case_insensitive(key): value for key, value in kwargs.items()}
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = args.pop(utils.case_insensitive(field.name), field.default)
setattr(self, field.name, value)
if args: # maybe a key is misspelled
Expand All @@ -98,7 +98,7 @@ def __init__(self, **kwargs):
raise AttributeError(f'{class_name} does not have field(s) {fields}')

# try to unpack nested config
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = getattr(self, field.name)
if utils.is_instance(value, field.type):
continue # already accepted by subclass, don't touch it
Expand Down Expand Up @@ -214,7 +214,7 @@ def _canonicalize(self, parents):
For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
"""
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = getattr(self, field.name)
if isinstance(value, (Path, str)) and utils.is_path_like(field.type):
setattr(self, field.name, utils.resolve_path(value, self._base_path))
Expand All @@ -235,7 +235,7 @@ def _validate_canonical(self):
2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict
"""
utils.validate_type(self)
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = getattr(self, field.name)
_recursive_validate_child(value)

Expand All @@ -247,7 +247,7 @@ def __setattr__(self, name, value):
if hasattr(self, name) or name.startswith('_'):
super().__setattr__(name, value)
return
if name in [field.name for field in dataclasses.fields(self)]: # might happend during __init__
if name in [field.name for field in utils.fields(self)]: # might happend during __init__
super().__setattr__(name, value)
return
raise AttributeError(f'{type(self).__name__} does not have field {name}')
Expand Down
1 change: 0 additions & 1 deletion nni/experiment/config/training_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def _canonicalize(self, parents):
def _validate_canonical(self):
super()._validate_canonical()
cls = type(self)
assert self.platform == cls.platform
if not Path(self.trial_code_directory).is_dir():
raise ValueError(f'{cls.__name__}: trial_code_directory "{self.trial_code_directory}" is not a directory')
assert self.trial_gpu_number is None or self.trial_gpu_number >= 0
4 changes: 3 additions & 1 deletion nni/experiment/config/training_services/aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

from dataclasses import dataclass

from typing_extensions import Literal

from ..training_service import TrainingServiceConfig

@dataclass(init=False)
class AmlConfig(TrainingServiceConfig):
platform: str = 'aml'
platform: Literal['aml'] = 'aml'
subscription_id: str
resource_group: str
workspace_name: str
Expand Down
4 changes: 3 additions & 1 deletion nni/experiment/config/training_services/dlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from dataclasses import dataclass
from typing import Optional

from typing_extensions import Literal

from ..training_service import TrainingServiceConfig

__all__ = ['DlcConfig']

@dataclass(init=False)
class DlcConfig(TrainingServiceConfig):
platform: str = 'dlc'
platform: Literal['dlc'] = 'dlc'
type: str = 'Worker'
image: str # 'registry-vpc.{region}.aliyuncs.com/pai-dlc/tensorflow-training:1.15.0-cpu-py36-ubuntu18.04',
job_type: str = 'TFJob'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from dataclasses import dataclass
from typing import List, Optional, Union

from typing_extensions import Literal

from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .k8s_storage import K8sStorageConfig
Expand All @@ -41,7 +43,7 @@ class FrameworkControllerRoleConfig(ConfigBase):

@dataclass(init=False)
class FrameworkControllerConfig(TrainingServiceConfig):
platform: str = 'frameworkcontroller'
platform: Literal['frameworkcontroller'] = 'frameworkcontroller'
storage: K8sStorageConfig
service_account_name: Optional[str]
task_roles: List[FrameworkControllerRoleConfig]
Expand Down
6 changes: 4 additions & 2 deletions nni/experiment/config/training_services/k8s_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dataclasses import dataclass
from typing import Optional

from typing_extensions import Literal

from ..base import ConfigBase

@dataclass(init=False)
Expand All @@ -34,13 +36,13 @@ def _validate_canonical(self):

@dataclass(init=False)
class K8sNfsConfig(K8sStorageConfig):
storage: str = 'nfs'
storage: Literal['nfs'] = 'nfs'
server: str
path: str

@dataclass(init=False)
class K8sAzureStorageConfig(K8sStorageConfig):
storage: str = 'azureStorage'
storage: Literal['azureStorage'] = 'azureStorage'
azure_account: str
azure_share: str
key_vault_name: str
Expand Down
4 changes: 3 additions & 1 deletion nni/experiment/config/training_services/kubeflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from dataclasses import dataclass
from typing import Optional, Union

from typing_extensions import Literal

from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .k8s_storage import K8sStorageConfig
Expand All @@ -35,7 +37,7 @@ class KubeflowRoleConfig(ConfigBase):

@dataclass(init=False)
class KubeflowConfig(TrainingServiceConfig):
platform: str = 'kubeflow'
platform: Literal['kubeflow'] = 'kubeflow'
operator: str
api_version: str
storage: K8sStorageConfig
Expand Down
4 changes: 3 additions & 1 deletion nni/experiment/config/training_services/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from dataclasses import dataclass
from typing import List, Optional, Union

from typing_extensions import Literal

from ..training_service import TrainingServiceConfig
from .. import utils

@dataclass(init=False)
class LocalConfig(TrainingServiceConfig):
platform: str = 'local'
platform: Literal['local'] = 'local'
use_active_gpu: Optional[bool] = None
max_trial_number_per_gpu: int = 1
gpu_indices: Union[List[int], int, str, None] = None
Expand Down
4 changes: 3 additions & 1 deletion nni/experiment/config/training_services/openpai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
from pathlib import Path
from typing import Dict, Optional, Union

from typing_extensions import Literal

from ..training_service import TrainingServiceConfig
from ..utils import PathLike

@dataclass(init=False)
class OpenpaiConfig(TrainingServiceConfig):
platform: str = 'openpai'
platform: Literal['openpai'] = 'openpai'
host: str
username: str
token: str
Expand Down
4 changes: 3 additions & 1 deletion nni/experiment/config/training_services/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from typing import List, Optional, Union
import warnings

from typing_extensions import Literal

from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .. import utils
Expand Down Expand Up @@ -60,7 +62,7 @@ def _validate_canonical(self):

@dataclass(init=False)
class RemoteConfig(TrainingServiceConfig):
platform: str = 'remote'
platform: Literal['remote'] = 'remote'
machine_list: List[RemoteMachineConfig]
reuse_mode: bool = True

Expand Down
68 changes: 42 additions & 26 deletions nni/experiment/config/utils/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,56 @@
If you are implementing a config class for a training service, it's unlikely you will need these.
"""

from __future__ import annotations

__all__ = [
'get_base_path', 'set_base_path', 'unset_base_path', 'resolve_path',
'case_insensitive', 'camel_case',
'fields', 'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'get_ipv4_address'
]

import copy
import dataclasses
import importlib
import json
import os.path
from pathlib import Path
import socket
import typing

import typeguard

import nni.runtime.config

from .public import is_missing

__all__ = [
'get_base_path', 'set_base_path', 'unset_base_path', 'resolve_path',
'case_insensitive', 'camel_case',
'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'get_ipv4_address'
]
if typing.TYPE_CHECKING:
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig

## handle relative path ##

_current_base_path = None
_current_base_path: Path | None = None

def get_base_path():
def get_base_path() -> Path:
if _current_base_path is None:
return Path()
return _current_base_path

def set_base_path(path):
def set_base_path(path: Path) -> None:
global _current_base_path
assert _current_base_path is None
_current_base_path = path

def unset_base_path():
def unset_base_path() -> None:
global _current_base_path
_current_base_path = None

def resolve_path(path, base_path):
if path is None:
return None
def resolve_path(path: Path | str, base_path: Path) -> str:
assert path is not None
# Path.resolve() does not work on Windows when file not exist, so use os.path instead
path = os.path.expanduser(path)
if not os.path.isabs(path):
Expand All @@ -58,23 +65,32 @@ def resolve_path(path, base_path):

## field name case convertion ##

def case_insensitive(key):
def case_insensitive(key: str) -> str:
return key.lower().replace('_', '')

def camel_case(key):
def camel_case(key: str) -> str:
words = key.strip('_').split('_')
return words[0] + ''.join(word.title() for word in words[1:])

## type hint utils ##

def is_instance(value, type_hint):
def fields(config: ConfigBase) -> list[dataclasses.Field]:
# Similar to `dataclasses.fields()`, but use `typing.get_types_hints()` to get `field.type`.
# This is useful when postponed evaluation is enabled.
ret = [copy.copy(field) for field in dataclasses.fields(config)]
types = typing.get_type_hints(type(config))
for field in ret:
field.type = types[field.name]
return ret

def is_instance(value, type_hint) -> bool:
try:
typeguard.check_type('_', value, type_hint)
except TypeError:
return False
return True

def validate_type(config):
def validate_type(config: ConfigBase) -> None:
class_name = type(config).__name__
for field in dataclasses.fields(config):
value = getattr(config, field.name)
Expand All @@ -84,17 +100,17 @@ def validate_type(config):
if not is_instance(value, field.type):
raise ValueError(f'{class_name}: type of {field.name} ({repr(value)}) is not {field.type}')

def is_path_like(type_hint):
def is_path_like(type_hint) -> bool:
# only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
return is_instance(Path(), type_hint) and not is_instance(1, type_hint)

## type inference ##

def guess_config_type(obj, type_hint):
def guess_config_type(obj, type_hint) -> ConfigBase | None:
ret = guess_list_config_type([obj], type_hint, _hint_list_item=True)
return ret[0] if ret else None

def guess_list_config_type(objs, type_hint, _hint_list_item=False):
def guess_list_config_type(objs, type_hint, _hint_list_item=False) -> list[ConfigBase] | None:
# avoid circular import
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
Expand Down Expand Up @@ -144,20 +160,20 @@ def _all_subclasses(cls):
subclasses = set(cls.__subclasses__())
return subclasses.union(*[_all_subclasses(subclass) for subclass in subclasses])

def training_service_config_factory(platform):
def training_service_config_factory(platform: str) -> TrainingServiceConfig:
cls = _get_ts_config_class(platform)
if cls is None:
raise ValueError(f'Bad training service platform: {platform}')
return cls()

def load_training_service_config(config):
def load_training_service_config(config) -> TrainingServiceConfig:
if isinstance(config, dict) and 'platform' in config:
cls = _get_ts_config_class(config['platform'])
if cls is not None:
return cls(**config)
return config # not valid json, don't touch

def _get_ts_config_class(platform):
def _get_ts_config_class(platform: str) -> type[TrainingServiceConfig] | None:
from ..training_service import TrainingServiceConfig # avoid circular import

# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
Expand All @@ -175,7 +191,7 @@ def _get_ts_config_class(platform):

## misc ##

def get_ipv4_address():
def get_ipv4_address() -> str:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('192.0.2.0', 80))
addr = s.getsockname()[0]
Expand Down

0 comments on commit 00e4deb

Please sign in to comment.