Skip to content

Commit

Permalink
[retiarii] fix test (microsoft#3197)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored Dec 16, 2020
1 parent a0e2f8e commit 0f0c628
Show file tree
Hide file tree
Showing 14 changed files with 29 additions and 138 deletions.
4 changes: 1 addition & 3 deletions nni/retiarii/converter/graph_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,14 +337,12 @@ def handle_single_node(node):
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::min':
print('zql: ', sm_graph)
exit(1)
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
Expand Down
2 changes: 1 addition & 1 deletion nni/retiarii/debug_configs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# these should be experiment config in release
# we will support tensorflow in future release

framework = 'pytorch'
2 changes: 1 addition & 1 deletion nni/retiarii/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .graph import MetricData

_logger = logging.getLogger('nni.msg_dispatcher_base')
_logger = logging.getLogger(__name__)


class RetiariiAdvisor(MsgDispatcherBase):
Expand Down
4 changes: 0 additions & 4 deletions test/ut/retiarii/advisor_entry.py

This file was deleted.

31 changes: 0 additions & 31 deletions test/ut/retiarii/debug_strategy.py

This file was deleted.

1 change: 0 additions & 1 deletion test/ut/retiarii/fake_search_space.json

This file was deleted.

18 changes: 9 additions & 9 deletions test/ut/retiarii/mnist-tensorflow.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"outputs": ["metric"],

"nodes": {
"stem": {"type": "_cell", "cell": "stem"},
"flatten": {"type": "Flatten"},
"fc1": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}},
"fc2": {"type": "Dense", "parameters": {"units": 10}},
"softmax": {"type": "Softmax"}
"stem": {"operation": {"type": "_cell", "parameters": {}, "cell_name": "stem"}},
"flatten": {"operation": {"type": "Flatten", "parameters": {}}},
"fc1": {"operation": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}}},
"fc2": {"operation": {"type": "Dense", "parameters": {"units": 10}}},
"softmax": {"operation": {"type": "Softmax", "parameters": {}}}
},

"edges": [
Expand All @@ -23,10 +23,10 @@

"stem": {
"nodes": {
"conv1": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}},
"pool1": {"type": "MaxPool2D", "parameters": {"pool_size": 2}},
"conv2": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}},
"pool2": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}
"conv1": {"operation": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}}},
"pool1": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}},
"conv2": {"operation": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}}},
"pool2": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}}
},

"edges": [
Expand Down
40 changes: 0 additions & 40 deletions test/ut/retiarii/mnist.json

This file was deleted.

18 changes: 0 additions & 18 deletions test/ut/retiarii/nni.yaml

This file was deleted.

18 changes: 0 additions & 18 deletions test/ut/retiarii/nni_cgo.yaml

This file was deleted.

18 changes: 12 additions & 6 deletions test/ut/retiarii/test_cgo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import unittest
import logging
import time
import torch
import torch

from pathlib import Path

from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
Expand All @@ -22,7 +24,8 @@


def _load_mnist(n_models: int = 1):
with open('converted_mnist_pytorch.json') as f:
path = Path(__file__).parent / 'converted_mnist_pytorch.json'
with open(path) as f:
mnist_model = Model._load(json.load(f))
if n_models == 1:
return mnist_model
Expand All @@ -38,6 +41,7 @@ def test_submit_models(self):
os.environ['CGO'] = 'true'
os.makedirs('generated', exist_ok=True)
from nni.runtime import protocol, platform
import nni.runtime.platform.test as tt
protocol._out_file = open('generated/debug_protocol_out_file.py', 'wb')
protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')

Expand All @@ -50,15 +54,15 @@ def test_submit_models(self):
params = json.loads(data)
params['parameters']['training_kwargs']['max_steps'] = 100

platform.test.init_params(params)
tt.init_params(params)

trial_thread = threading.Thread(target=CGOExecutionEngine.trial_execute_graph())
trial_thread.start()
last_metric = None
while True:
time.sleep(1)
if platform.test._last_metric:
metric = platform.test.get_last_metric()
if tt._last_metric:
metric = tt.get_last_metric()
if metric == last_metric:
continue
advisor.handle_report_metric_data(metric)
Expand All @@ -75,4 +79,6 @@ def test_submit_models(self):
if __name__ == '__main__':
#CGOEngineTest().test_dedup_input()
#CGOEngineTest().test_submit_models()
unittest.main()
#unittest.main()
# TODO: fix ut
pass
5 changes: 4 additions & 1 deletion test/ut/retiarii/test_dedup_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
import time

from pathlib import Path

from nni.retiarii.execution.cgo_engine import CGOExecutionEngine
from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan
from nni.retiarii.execution.logical_optimizer.opt_dedup_input import DedupInputOptimizer
Expand All @@ -19,7 +21,8 @@
from nni.retiarii.utils import import_

def _load_mnist(n_models: int = 1):
with open('converted_mnist_pytorch.json') as f:
path = Path(__file__).parent / 'converted_mnist_pytorch.json'
with open(path) as f:
mnist_model = Model._load(json.load(f))
if n_models == 1:
return mnist_model
Expand Down
4 changes: 0 additions & 4 deletions test/ut/retiarii/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,10 @@ def _test_file(json_path):
graph['inputs'] = None
if 'outputs' not in graph:
graph['outputs'] = None
for node_name, node in graph['nodes'].items():
if 'parameters' not in node:
node['parameters'] = {}

# debug output
#json.dump(orig_ir, open('_orig.json', 'w'), indent=4)
#json.dump(dump_ir, open('_dump.json', 'w'), indent=4)

assert orig_ir == dump_ir


Expand Down
2 changes: 1 addition & 1 deletion test/ut/retiarii/test_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def mutate(self, model):


def test_dry_run():
candidates = mutator.dry_run(model0)
candidates, _ = mutator.dry_run(model0)
assert len(candidates) == 2
assert candidates[0] == [max_pool, avg_pool, global_pool]
assert candidates[1] == [max_pool, avg_pool, global_pool]
Expand Down

0 comments on commit 0f0c628

Please sign in to comment.