[Retiarii] Base execution engine, codegen and trainer (microsoft#3059)
ultmaster authored Nov 5, 2020
1 parent 60b2a7a commit 002af91
Showing 26 changed files with 1,090 additions and 3 deletions.
1 change: 1 addition & 0 deletions nni/retiarii/__init__.py
@@ -1,3 +1,4 @@
from .execution import *
from .graph import *
from .mutator import *
from .operation import *
1 change: 1 addition & 0 deletions nni/retiarii/codegen/__init__.py
@@ -0,0 +1 @@
from .pytorch import model_to_pytorch_script
103 changes: 103 additions & 0 deletions nni/retiarii/codegen/pytorch.py
@@ -0,0 +1,103 @@
from typing import *

from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell


def model_to_pytorch_script(model: Model) -> str:
    graphs = [graph_to_pytorch_model(name, cell) for name, cell in model.graphs.items()]
    return _PyTorchScriptTemplate.format('\n\n'.join(graphs)).strip()


def _sorted_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    if not edges:
        return []
    if all(edge.tail_slot is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_slot, int) for edge in edges):
        edges = sorted(edges, key=(lambda edge: edge.tail_slot))
        if [edge.tail_slot for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> str:
    edges = _sorted_incoming_edges(node)
    inputs = []
    for edge in edges:
        if edge.head.name == '_inputs':
            assert isinstance(edge.head_slot, int)
            if node.graph.input_names is not None:
                # when the inputs have names, e.g., forward(self, tensor1, tensor2, another_one)
                inputs.append(node.graph.input_names[edge.head_slot])
            else:
                # when the inputs have no names, e.g., forward(*_inputs)
                inputs.append('_inputs[{}]'.format(edge.head_slot))
        else:
            if edge.head_slot is None:
                # when the input comes from a single-output operator
                inputs.append('{}'.format(edge.head.name))
            else:
                # when the input comes from a multi-output operator: need to know which output it comes from
                inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
    return ', '.join(inputs)


def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
    nodes = graph.nodes  # FIXME: topological sort is needed here

    # handle module nodes and function nodes differently;
    # only module nodes need code generated here
    node_codes = []
    for node in nodes:
        if node.operation:
            node_codes.append(node.operation.to_init_code(node.name))

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    edge_codes = []

    for node in nodes:
        if node.operation:
            inputs = _format_inputs(node)
            edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs))

    output_code = _format_inputs(graph.output_node)
    if not output_code:
        output_code = 'None'

    linebreak = '\n        '
    return _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


# TODO: handle imports

_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

{}
'''

_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
    def __init__(self):
        super().__init__()
        {nodes}

    def forward(self, {inputs}):
        {edges}
        return {outputs}
'''
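For illustration, here is a hedged sketch of what `model_to_pytorch_script` might emit for a hypothetical graph with `input_names=['x']` and two single-output nodes named `conv` and `relu`; the exact strings returned by `to_init_code` and `to_forward_code` are assumptions, not taken from this diff:

# Hypothetical codegen output; node names and layer choices are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Graph(nn.Module):  # a graph named '_graph' is rendered as 'Graph'
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)  # assumed result of to_init_code('conv')
        self.relu = nn.ReLU()                       # assumed result of to_init_code('relu')

    def forward(self, x):
        conv = self.conv(x)     # inputs ordered by _sorted_incoming_edges
        relu = self.relu(conv)  # single-output node referenced by name
        return relu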
106 changes: 106 additions & 0 deletions nni/retiarii/codegen/tensorflow.py
@@ -0,0 +1,106 @@
# pylint: skip-file

"""
FIXME
This file is inherited from the previous version.
It should work with a few modifications to match the latest API, but it has not been tested yet.
"""

from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node
from ..operations_tf import Operation
from ..type_utils import *


def graph_to_tensorflow_script(graph: Graph) -> str:
    graphs = [graph_to_tensorflow_model(name, cell) for name, cell in graph.cell_templates.items()]
    return _TensorFlowScriptTemplate.format('\n\n'.join(graphs)).strip()


def _sort_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    if not edges:
        return []
    if all(edge.tail_idx is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_idx, int) for edge in edges):
        edges = sorted(edges, key=(lambda edge: edge.tail_idx))
        if [edge.tail_idx for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> str:
    edges = _sort_incoming_edges(node)
    inputs = []
    for edge in edges:
        if edge.head.name == '_inputs':
            assert isinstance(edge.head_idx, int)
            if node.graph.input_names is not None:
                inputs.append(node.graph.input_names[edge.head_idx])
            else:
                inputs.append('_inputs[{}]'.format(edge.head_idx))
        else:
            if edge.head_idx is None:
                inputs.append('{}'.format(edge.head.name))
            else:
                inputs.append('{}[{}]'.format(edge.head.name, edge.head_idx))
    return ', '.join(inputs)


def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
    nodes = graph.topo_sort()

    # handle module nodes and function nodes differently;
    # only module nodes need code generated here
    node_codes = []
    for node in nodes:
        if isinstance(node, Cell):
            node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
        else:
            node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init()))

    edge_codes = []

    for node in nodes:
        inputs = _format_inputs(node)
        edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs))

    output_code = _format_inputs(graph.output_node)
    if not output_code:
        output_code = 'None'

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    linebreak = '\n        '
    return _TensorFlowModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


_TensorFlowScriptTemplate = '''
import tensorflow as tf
import tensorflow.keras as K
import sdk.custom_ops_tf as CUSTOM

{}
'''

_TensorFlowModelTemplate = '''
class {graph_name}(K.Model):
    def __init__(self):
        super().__init__()
        {nodes}

    def call(self, {inputs}):
        {edges}
        return {outputs}
'''
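Analogously, a sketch of the Keras model this template would yield for the same hypothetical two-node graph; the layer constructors are assumed outputs of `to_tensorflow_init`, not taken from this diff:

# Hypothetical codegen output for the TensorFlow backend.
import tensorflow as tf
import tensorflow.keras as K

class Graph(K.Model):
    def __init__(self):
        super().__init__()
        self.conv = K.layers.Conv2D(8, 3)  # assumed to_tensorflow_init() output
        self.relu = K.layers.ReLU()

    def call(self, x):
        conv = self.conv(x)
        relu = self.relu(conv)
        return relu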
1 change: 1 addition & 0 deletions nni/retiarii/execution/__init__.py
@@ -0,0 +1 @@
from .api import *
51 changes: 51 additions & 0 deletions nni/retiarii/execution/api.py
@@ -0,0 +1,51 @@
import time
from typing import *

from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .interface import *
from .listener import DefaultListener

_execution_engine = None
_default_listener = None

__all__ = ['get_execution_engine', 'get_and_register_default_listener',
           'submit_models', 'wait_models', 'query_available_resources']


def get_execution_engine() -> BaseExecutionEngine:
    """
    Currently we assume the default execution engine is BaseExecutionEngine.
    """
    global _execution_engine
    if _execution_engine is None:
        _execution_engine = BaseExecutionEngine()
    return _execution_engine


def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
    global _default_listener
    if _default_listener is None:
        _default_listener = DefaultListener()
        engine.register_graph_listener(_default_listener)
    return _default_listener


def submit_models(*models: Model) -> None:
    engine = get_execution_engine()
    get_and_register_default_listener(engine)
    engine.submit_models(*models)


def wait_models(*models: Model) -> None:
    get_and_register_default_listener(get_execution_engine())
    while True:
        time.sleep(1)
        left_models = [g for g in models if g.status not in (ModelStatus.Trained, ModelStatus.Failed)]
        if not left_models:
            break


def query_available_resources() -> List[WorkerInfo]:
    listener = get_and_register_default_listener(get_execution_engine())
    return listener.resources
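As a usage sketch, a strategy could drive these helpers roughly as follows; `candidate_models` is a hypothetical iterable of mutated `Model`s, not part of this diff:

import time

def run_candidates(candidate_models):
    # Hypothetical driver loop: submit each model when a resource is reported
    # free, then block until every submitted model is Trained or Failed.
    for model in candidate_models:
        while not query_available_resources():  # resources come from DefaultListener
            time.sleep(1)
        submit_models(model)
    wait_models(*candidate_models)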
103 changes: 103 additions & 0 deletions nni/retiarii/execution/base.py
@@ -0,0 +1,103 @@
from typing import *

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor


class BaseGraphData:
    def __init__(self, model_script: str, training_module: str, training_kwargs: Dict[str, Any]) -> None:
        self.model_script = model_script
        self.training_module = training_module
        self.training_kwargs = training_kwargs

    def dump(self) -> dict:
        return {
            'model_script': self.model_script,
            'training_module': self.training_module,
            'training_kwargs': self.training_kwargs
        }

    @staticmethod
    def load(data):
        return BaseGraphData(data['model_script'], data['training_module'], data['training_kwargs'])


class BaseExecutionEngine(AbstractExecutionEngine):
    """
    The execution engine with no optimization at all.
    Resource management is yet to be implemented.
    """

    def __init__(self) -> None:
        """
        Upon initialization, advisor callbacks are registered.
        The advisor invokes a callback when the corresponding event is triggered,
        and the base execution engine broadcasts it to the registered graph listeners.
        """
        self._listeners: List[AbstractGraphListener] = []

        # register advisor callbacks
        advisor = get_advisor()
        advisor.send_trial_callback = self._send_trial_callback
        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
        advisor.trial_end_callback = self._trial_end_callback
        advisor.intermediate_metric_callback = self._intermediate_metric_callback
        advisor.final_metric_callback = self._final_metric_callback

        self._running_models: Dict[int, Model] = dict()

    def submit_models(self, *models: Model) -> None:
        for model in models:
            data = BaseGraphData(codegen.model_to_pytorch_script(model),
                                 model.training_config.module, model.training_config.kwargs)
            self._running_models[send_trial(data.dump())] = model

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    def _send_trial_callback(self, parameter: dict) -> None:
        for listener in self._listeners:
            listener.on_resource_used(0)  # FIXME: find the real resource id

    def _request_trial_jobs_callback(self, num_trials: int) -> None:
        for listener in self._listeners:
            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        for listener in self._listeners:
            listener.on_training_end(model, success)

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        model = self._running_models[trial_id]
        model.intermediate_metrics.append(metrics)
        for listener in self._listeners:
            listener.on_intermediate_metric(model, metrics)

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        model = self._running_models[trial_id]
        model.metric = metrics
        for listener in self._listeners:
            listener.on_metric(model, metrics)

    def query_available_resource(self) -> List[WorkerInfo]:
        raise NotImplementedError  # move the method from listener to here?

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model and hand it over to the trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        with open('_generated_model.py', 'w') as f:
            f.write(graph_data.model_script)
        trainer_cls = utils.import_(graph_data.training_module)
        model_cls = utils.import_('_generated_model._model')
        trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
        trainer_instance.fit()
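To make the trainer contract concrete: `trial_execute_graph` only assumes that the training module resolves to a class constructible as `trainer_cls(model=..., **training_kwargs)` with a `fit()` method. A minimal sketch of such a class follows; all names and hyperparameters here are hypothetical:

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleTrainer:
    # Hypothetical trainer; only the constructor signature and fit() are
    # implied by trial_execute_graph above.
    def __init__(self, model: nn.Module, learning_rate: float = 1e-3, max_steps: int = 100):
        self.model = model
        self.optimizer = optim.SGD(model.parameters(), lr=learning_rate)
        self.max_steps = max_steps

    def fit(self) -> None:
        self.model.train()
        for _ in range(self.max_steps):
            x = torch.randn(8, 3, 32, 32)  # placeholder batch; a real trainer would use a DataLoader
            loss = self.model(x).sum()     # placeholder objective
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()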