[Retiarii] Base execution engine, codegen and trainer (microsoft#3059)
Showing 26 changed files with 1,090 additions and 3 deletions.
@@ -1,3 +1,4 @@
from .execution import *
from .graph import *
from .mutator import *
from .operation import *
@@ -0,0 +1 @@
from .pytorch import model_to_pytorch_script
@@ -0,0 +1,103 @@
from typing import *

from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell


def model_to_pytorch_script(model: Model) -> str:
    graphs = [graph_to_pytorch_model(name, cell) for name, cell in model.graphs.items()]
    return _PyTorchScriptTemplate.format('\n\n'.join(graphs)).strip()


def _sorted_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    if not edges:
        return []
    if all(edge.tail_slot is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_slot, int) for edge in edges):
        edges = sorted(edges, key=(lambda edge: edge.tail_slot))
        if [edge.tail_slot for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> str:
    edges = _sorted_incoming_edges(node)
    inputs = []
    for edge in edges:
        if edge.head.name == '_inputs':
            assert isinstance(edge.head_slot, int)
            if node.graph.input_names is not None:
                # when input has names, e.g., forward(self, tensor1, tensor2, another_one)
                inputs.append(node.graph.input_names[edge.head_slot])
            else:
                # when input has no name, e.g., forward(*_inputs)
                inputs.append('_inputs[{}]'.format(edge.head_slot))
        else:
            if edge.head_slot is None:
                # when the input comes from a single-output operator
                inputs.append('{}'.format(edge.head.name))
            else:
                # when the input comes from a multi-output operator: needs to know which one it comes from
                inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
    return ', '.join(inputs)


def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
    nodes = graph.nodes  # FIXME: topological sort is needed here

    # handle module node and function node differently
    # only need to generate code for module here
    node_codes = []
    for node in nodes:
        if node.operation:
            node_codes.append(node.operation.to_init_code(node.name))

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    edge_codes = []

    for node in nodes:
        if node.operation:
            inputs = _format_inputs(node)
            edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs))

    output_code = _format_inputs(graph.output_node)
    if not output_code:
        output_code = 'None'

    linebreak = '\n        '  # matches the 8-space method-body indentation of the template
    return _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


# TODO: handle imports

_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

{}
'''

_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
    def __init__(self):
        super().__init__()
        {nodes}

    def forward(self, {inputs}):
        {edges}
        return {outputs}
'''
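
To make the templates concrete: for a hypothetical graph named '_graph' with a single named input x and one node conv, whose operation's to_init_code and to_forward_code are assumed to emit an nn.Conv2d construction and call, model_to_pytorch_script would render roughly the following (every specific below is illustrative, not part of this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Graph(nn.Module):          # '_graph' is renamed to 'Graph' by graph_to_pytorch_model
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)  # assumed output of to_init_code('conv')

    def forward(self, x):
        conv = self.conv(x)      # assumed output of to_forward_code('conv', 'conv', 'x')
        return conv
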
@@ -0,0 +1,106 @@
# pylint: skip-file

"""
FIXME
This file is inherited from the last version.
I expect it to work with a few modifications to fit the latest API, but it has not
been tested and I'm not sure.
"""

from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node
from ..operations_tf import Operation
from ..type_utils import *


def graph_to_tensorflow_script(graph: Graph) -> str:
    graphs = [graph_to_tensorflow_model(name, cell) for name, cell in graph.cell_templates.items()]
    return _TensorFlowScriptTemplate.format('\n\n'.join(graphs)).strip()


def _sort_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    if not edges:
        return []
    if all(edge.tail_idx is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_idx, int) for edge in edges):
        edges = sorted(edges, key=(lambda edge: edge.tail_idx))
        if [edge.tail_idx for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> str:
    edges = _sort_incoming_edges(node)
    inputs = []
    for edge in edges:
        if edge.head.name == '_inputs':
            assert isinstance(edge.head_idx, int)
            if node.graph.input_names is not None:
                inputs.append(node.graph.input_names[edge.head_idx])
            else:
                inputs.append('_inputs[{}]'.format(edge.head_idx))
        else:
            if edge.head_idx is None:
                inputs.append('{}'.format(edge.head.name))
            else:
                inputs.append('{}[{}]'.format(edge.head.name, edge.head_idx))
    return ', '.join(inputs)


def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
    nodes = graph.topo_sort()

    # handle module node and function node differently
    # only need to generate code for module here
    node_codes = []
    for node in nodes:
        if isinstance(node, Cell):
            node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
        else:
            node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init()))

    edge_codes = []

    for node in nodes:
        inputs = _format_inputs(node)
        edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs))

    output_code = _format_inputs(graph.output_node)
    if not output_code:
        output_code = 'None'

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    linebreak = '\n        '  # matches the 8-space method-body indentation of the template
    return _TensorFlowModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


_TensorFlowScriptTemplate = '''
import tensorflow as tf
import tensorflow.keras as K
import sdk.custom_ops_tf as CUSTOM

{}
'''

_TensorFlowModelTemplate = '''
class {graph_name}(K.Model):
    def __init__(self):
        super().__init__()
        {nodes}

    def call(self, {inputs}):
        {edges}
        return {outputs}
'''
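
The rendered output mirrors the PyTorch case, except that the generated class derives from K.Model and the forward pass is named call. Assuming the same hypothetical single-node graph, with to_tensorflow_init() assumed to emit a Conv2D layer:

class Graph(K.Model):
    def __init__(self):
        super().__init__()
        self.conv = K.layers.Conv2D(16, 3)  # hypothetical to_tensorflow_init() output

    def call(self, x):
        conv = self.conv(x)
        return conv
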
@@ -0,0 +1 @@
from .api import *
@@ -0,0 +1,51 @@
import time
from typing import *

from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .interface import *
from .listener import DefaultListener

_execution_engine = None
_default_listener = None

__all__ = ['get_execution_engine', 'get_and_register_default_listener',
           'submit_models', 'wait_models', 'query_available_resources']


def get_execution_engine() -> BaseExecutionEngine:
    """
    Currently we assume the default execution engine is BaseExecutionEngine.
    """
    global _execution_engine
    if _execution_engine is None:
        _execution_engine = BaseExecutionEngine()
    return _execution_engine


def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
    global _default_listener
    if _default_listener is None:
        _default_listener = DefaultListener()
        engine.register_graph_listener(_default_listener)
    return _default_listener


def submit_models(*models: Model) -> None:
    engine = get_execution_engine()
    get_and_register_default_listener(engine)
    engine.submit_models(*models)


def wait_models(*models: Model) -> None:
    get_and_register_default_listener(get_execution_engine())
    while True:
        time.sleep(1)
        left_models = [g for g in models if g.status not in (ModelStatus.Trained, ModelStatus.Failed)]
        if not left_models:
            break


def query_available_resources() -> List[WorkerInfo]:
    listener = get_and_register_default_listener(get_execution_engine())
    return listener.resources
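
Taken together, these helpers form the client-side workflow: submit models, block until each reaches a terminal status, then read results. A minimal sketch, assuming the package is importable as nni.retiarii and that mutated Model objects with a populated training_config are already at hand (mutate and base_model below are hypothetical, not part of this diff):

from nni.retiarii.execution import submit_models, wait_models, query_available_resources

models = [mutate(base_model) for _ in range(4)]  # hypothetical mutation step

submit_models(*models)       # generates code and sends one trial per model
wait_models(*models)         # polls once per second until each model is Trained or Failed
for model in models:
    print(model.status, model.metric)  # metric is set by the engine's final-metric callback

workers = query_available_resources()  # WorkerInfo list tracked by the default listener
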
@@ -0,0 +1,103 @@
from typing import *

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor


class BaseGraphData:
    def __init__(self, model_script: str, training_module: str, training_kwargs: Dict[str, Any]) -> None:
        self.model_script = model_script
        self.training_module = training_module
        self.training_kwargs = training_kwargs

    def dump(self) -> dict:
        return {
            'model_script': self.model_script,
            'training_module': self.training_module,
            'training_kwargs': self.training_kwargs
        }

    @staticmethod
    def load(data):
        return BaseGraphData(data['model_script'], data['training_module'], data['training_kwargs'])


class BaseExecutionEngine(AbstractExecutionEngine):
    """
    The execution engine with no optimization at all.
    Resource management is yet to be implemented.
    """

    def __init__(self) -> None:
        """
        Upon initialization, advisor callbacks need to be registered.
        The advisor will call these callbacks when the corresponding events are triggered.
        The base execution engine receives those callbacks and broadcasts them to the graph listeners.
        """
        self._listeners: List[AbstractGraphListener] = []

        # register advisor callbacks
        advisor = get_advisor()
        advisor.send_trial_callback = self._send_trial_callback
        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
        advisor.trial_end_callback = self._trial_end_callback
        advisor.intermediate_metric_callback = self._intermediate_metric_callback
        advisor.final_metric_callback = self._final_metric_callback

        self._running_models: Dict[int, Model] = dict()

    def submit_models(self, *models: Model) -> None:
        for model in models:
            data = BaseGraphData(codegen.model_to_pytorch_script(model),
                                 model.training_config.module, model.training_config.kwargs)
            self._running_models[send_trial(data.dump())] = model

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    def _send_trial_callback(self, parameter: dict) -> None:
        for listener in self._listeners:
            listener.on_resource_used(0)  # FIXME: find the real resource id

    def _request_trial_jobs_callback(self, num_trials: int) -> None:
        for listener in self._listeners:
            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        for listener in self._listeners:
            listener.on_training_end(model, success)

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        model = self._running_models[trial_id]
        model.intermediate_metrics.append(metrics)
        for listener in self._listeners:
            listener.on_intermediate_metric(model, metrics)

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        model = self._running_models[trial_id]
        model.metric = metrics
        for listener in self._listeners:
            listener.on_metric(model, metrics)

    def query_available_resource(self) -> List[WorkerInfo]:
        raise NotImplementedError  # move the method from listener to here?

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model and hand it over to the trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        with open('_generated_model.py', 'w') as f:
            f.write(graph_data.model_script)
        trainer_cls = utils.import_(graph_data.training_module)
        model_cls = utils.import_('_generated_model._model')
        trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
        trainer_instance.fit()
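
The only contract trial_execute_graph imposes on a trainer is that training_module resolves to a class accepting the instantiated model as a model keyword argument together with training_kwargs, and exposing a fit() method. A minimal sketch of a compatible trainer, with the training-loop details (dataset, epochs, optimizer) entirely hypothetical:

import torch.nn as nn
import torch.optim as optim

class SimpleTrainer:
    """Hypothetical trainer satisfying the trainer_cls(model=..., **kwargs) / fit() contract."""

    def __init__(self, model: nn.Module, dataset, epochs: int = 1, lr: float = 1e-3):
        self.model = model
        self.dataset = dataset  # assumed iterable of (input, label) batches
        self.epochs = epochs
        self.optimizer = optim.Adam(model.parameters(), lr=lr)

    def fit(self) -> None:
        loss_fn = nn.CrossEntropyLoss()
        for _ in range(self.epochs):
            for x, y in self.dataset:
                self.optimizer.zero_grad()
                loss = loss_fn(self.model(x), y)
                loss.backward()
                self.optimizer.step()

With such a class, model.training_config.module would be its dotted import path and model.training_config.kwargs its constructor arguments.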