-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
171 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,58 @@ | ||
import onnx | ||
import numpy as np | ||
import os.path as osp | ||
import time | ||
import tvm | ||
from tvm import relay, auto_scheduler | ||
import tvm.relay.testing | ||
from tvm.contrib import graph_executor | ||
|
||
# Benchmark configuration: directory holding the model and tuning log, and the CUDA target.
prefix = "/home/v-yiningshi/learn_tvm/testing/temp/bert"
target = tvm.target.cuda(arch="sm_70")

# A commented-out TensorFlow import path (classifier.pb + inputs.npz fed through
# relay.frontend.from_tensorflow) used to sit here; the model is imported from ONNX instead.
onnx_model = onnx.load(osp.join(prefix, "model.onnx"))
mod, params = relay.frontend.from_onnx(onnx_model)

# Path of the auto-scheduler record file shared by the tuning and compile steps.
log_file = osp.join(prefix, "ansor_tune.log")
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

# Print every extracted tuning task together with its compute DAG.
for idx, task in enumerate(tasks):
    header = "========== Task %d (workload key: %s) ==========" % (idx, task.workload_key)
    print(header)
    print(task.compute_dag)
|
||
def run_tuning(device=3, trials_per_task=512):
    """Tune all extracted tasks with the auto-scheduler, recording results to ``log_file``.

    Uses the module-level ``tasks``, ``task_weights`` and ``log_file``.

    Args:
        device: Index of the GPU passed to the local RPC measurement context.
            Defaults to 3, the value that was previously hard-coded.
        trials_per_task: Measurement budget per task; the total number of
            measure trials is ``len(tasks) * trials_per_task``.
    """
    print("Begin tuning...")
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
        repeat=1, min_repeat_ms=300, timeout=10, device=device
    )
    try:
        # The same file is passed as ``load_log_file`` and to ``RecordToFile``,
        # so new measurements are written to the file the scheduler loads from.
        tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=log_file)
        tune_option = auto_scheduler.TuningOptions(
            num_measure_trials=len(tasks) * trials_per_task,
            runner=measure_ctx.runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        tuner.tune(tune_option)
    finally:
        # Release the measurement context promptly, even if tuning raised.
        # NOTE(review): TVM's tutorials delete the context to stop the RPC
        # measurement processes — confirm against the installed TVM version.
        del measure_ctx
|
||
# run_tuning() | ||
|
||
# Build the model while applying the best schedules recorded in the tuning log.
print("Compile...")
build_config = {"relay.backend.use_auto_scheduler": True}
with auto_scheduler.ApplyHistoryBest(log_file), tvm.transform.PassContext(
    opt_level=3, config=build_config
):
    lib = relay.build(mod, target=target, params=params)

# Instantiate the graph executor on device index 3 of the target.
dev = tvm.device(str(target), 3)
module = graph_executor.GraphModule(lib["default"](dev))

# Benchmark the compiled module and print the result.
print("Evaluate inference time cost...")
benchmark_result = module.benchmark(dev, min_repeat_ms=500, end_to_end=False)
print(benchmark_result)
import argparse | ||
|
||
def _count_log_records(log_file):
    """Return the number of lines in ``log_file``, or 0 when the file does not exist."""
    try:
        with open(log_file, "r") as f:
            # Stream the file instead of materialising every line in memory.
            return sum(1 for _ in f)
    except FileNotFoundError:
        return 0


def run_ansor(prefix, device, skip_tuning, arch="sm_70", trials_per_task=800):
    """Tune (optionally), compile and benchmark ``<prefix>/model.onnx`` with the TVM auto-scheduler.

    Tuning records are written to and read from ``<prefix>/ansor_tune.log``.
    The tuning budget is ``len(tasks) * trials_per_task`` minus the number of
    lines already present in that log, so a re-run only spends what is left
    (this treats each log line as one measured trial).

    Args:
        prefix: Directory containing ``model.onnx``; the tuning log is kept there too.
        device: GPU index used for both measurement and the final benchmark.
        skip_tuning: When true, never tune; compile with whatever the log holds.
        arch: CUDA architecture passed to ``tvm.target.cuda`` (default ``"sm_70"``,
            the value that was previously hard-coded).
        trials_per_task: Measurement budget per extracted task (default 800,
            the value that was previously hard-coded).
    """
    target = tvm.target.cuda(arch=arch)
    onnx_model = onnx.load(osp.join(prefix, "model.onnx"))
    mod, params = relay.frontend.from_onnx(onnx_model)
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    log_file = osp.join(prefix, "ansor_tune.log")

    for idx, task in enumerate(tasks):
        print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key))
        print(task.compute_dag)

    num_trials = len(tasks) * trials_per_task - _count_log_records(log_file)
    if num_trials > 0 and not skip_tuning:
        print("Begin tuning...")
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(
            repeat=1, min_repeat_ms=300, timeout=10, device=device
        )
        try:
            tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=log_file)
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=num_trials,
                runner=measure_ctx.runner,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            tuner.tune(tune_option)
        finally:
            # Release the measurement context before compiling and benchmarking,
            # so it does not stay alive for the rest of the function.
            # NOTE(review): TVM's tutorials delete the context to stop the RPC
            # measurement processes — confirm against the installed TVM version.
            del measure_ctx

    # Compile with the history best
    print("Compile...")
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            lib = relay.build(mod, target=target, params=params)

    # Create graph executor
    dev = tvm.device(str(target), device)
    module = graph_executor.GraphModule(lib["default"](dev))

    # Evaluate
    print("Evaluate inference time cost...")
    print(module.benchmark(dev, min_repeat_ms=500, end_to_end=False))
|
||
if __name__ == "__main__":
    # Command-line entry point for the Ansor benchmark.
    parser = argparse.ArgumentParser(
        description="Tune, compile and benchmark an ONNX model with the TVM auto-scheduler."
    )
    parser.add_argument('--prefix', type=str, default="temp",
                        help="directory containing model.onnx and the tuning log")
    parser.add_argument('--device', type=int, default=0,
                        help="GPU index used for tuning and benchmarking")
    parser.add_argument('--skip', action="store_true",
                        help="skip tuning and compile with the existing log")
    args = parser.parse_args()
    # The former `start_time = time.time()` was never read, so it is dropped.
    run_ansor(args.prefix, args.device, args.skip)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
import argparse | ||
import os | ||
import time | ||
import os.path as osp | ||
import tempfile | ||
|
||
def load_graph(onnx_file):
    """Convert an ONNX model file into a TensorFlow export held in a temporary directory.

    Args:
        onnx_file: Path of the ONNX model to load.

    Returns:
        The ``tempfile.TemporaryDirectory`` object the graph was exported into.
        Use its ``.name`` as the model path and keep the object referenced for
        as long as the files are needed.
    """
    import onnx
    from onnx_tf.backend import prepare

    # NOTE(review): onnx-tf appears to document the device values as "CPU"/"CUDA";
    # confirm that the lowercase spelling is treated the same way.
    tf_rep = prepare(onnx.load(onnx_file), device="cuda")
    export_dir = tempfile.TemporaryDirectory()
    tf_rep.export_graph(export_dir.name)
    return export_dir
|
||
def run_tf(prefix, xla=False, warmup=200, repeat=800):
    """Benchmark ``<prefix>/model.onnx`` through TensorFlow and print mean/min/max latency in ms.

    The model is converted with ``load_graph``, loaded as a SavedModel, and its
    default serving signature is called with the arrays from
    ``<prefix>/inputs.npz`` as keyword arguments (so the archive keys must
    match the signature's input names).

    Args:
        prefix: Directory containing ``model.onnx`` and ``inputs.npz``.
        xla: When true, enable TensorFlow's JIT (XLA) optimizer first.
        warmup: Number of untimed-for-reporting warmup calls (default 200).
        repeat: Number of timed calls included in the statistics (default 800).
    """
    if xla:
        tf.config.optimizer.set_jit(True)

    # Using the TemporaryDirectory as a context manager removes the export
    # deterministically once benchmarking is done (or fails).
    with load_graph(osp.join(prefix, "model.onnx")) as export_dir:
        # NOTE: allow_pickle=True can execute arbitrary code embedded in the
        # archive; only use inputs.npz files from a trusted source.
        with np.load(osp.join(prefix, "inputs.npz"), allow_pickle=True) as archive:
            feed_dict = dict(archive)

        saved_model_loaded = tf.saved_model.load(
            export_dir, tags=[tf.saved_model.SERVING])
        graph_func = saved_model_loaded.signatures[
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

        def get_runtime():
            # perf_counter is monotonic and high-resolution, unlike time.time().
            tic = time.perf_counter()
            graph_func(**feed_dict)
            return (time.perf_counter() - tic) * 1000

        for _ in range(warmup):  # warmup
            get_runtime()
        times = [get_runtime() for _ in range(repeat)]
        print(np.mean(times), np.min(times), np.max(times))
|
||
if __name__ == "__main__":
    # Command-line entry point for the TensorFlow benchmark.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--xla', action="store_true")
    arg_parser.add_argument('--prefix', type=str, default="temp")
    arg_parser.add_argument('--device', type=int, default=0)
    cli_args = arg_parser.parse_args()

    # Select the GPU through the environment before the benchmark runs.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{cli_args.device}"
    run_tf(cli_args.prefix, xla=cli_args.xla)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import tensorflow as tf | ||
from tensorflow.python.compiler.tensorrt import trt_convert as trt | ||
import numpy as np | ||
import argparse | ||
import os | ||
import time | ||
import os.path as osp | ||
import tempfile | ||
|
||
def load_graph(onnx_file):
    """Convert an ONNX model file into a TensorFlow export held in a temporary directory.

    Args:
        onnx_file: Path of the ONNX model to load.

    Returns:
        The ``tempfile.TemporaryDirectory`` object the graph was exported into.
        Use its ``.name`` as the model path and keep the object referenced for
        as long as the files are needed.
    """
    import onnx
    from onnx_tf.backend import prepare

    # NOTE(review): onnx-tf appears to document the device values as "CPU"/"CUDA";
    # confirm that the lowercase spelling is treated the same way.
    tf_rep = prepare(onnx.load(onnx_file), device="cuda")
    export_dir = tempfile.TemporaryDirectory()
    tf_rep.export_graph(export_dir.name)
    return export_dir
|
||
def run_tf(prefix, warmup=200, repeat=800):
    """Benchmark ``<prefix>/model.onnx`` through TF-TRT and print mean/min/max latency in ms.

    The model is converted with ``load_graph``, run through
    ``TrtGraphConverterV2`` (convert + build using the arrays from
    ``<prefix>/inputs.npz``), saved, reloaded as a SavedModel, and its default
    serving signature is called with those arrays as keyword arguments.

    Args:
        prefix: Directory containing ``model.onnx`` and ``inputs.npz``.
        warmup: Number of untimed-for-reporting warmup calls (default 200).
        repeat: Number of timed calls included in the statistics (default 800).
    """
    # Both temporary directories are context-managed so they are removed
    # deterministically once benchmarking is done (or fails).
    with load_graph(osp.join(prefix, "model.onnx")) as export_dir, \
            tempfile.TemporaryDirectory() as trt_dir:
        # NOTE: allow_pickle=True can execute arbitrary code embedded in the
        # archive; only use inputs.npz files from a trusted source.
        with np.load(osp.join(prefix, "inputs.npz"), allow_pickle=True) as archive:
            feed_dict = dict(archive)

        conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
        converter = trt.TrtGraphConverterV2(
            input_saved_model_dir=export_dir,
            conversion_params=conversion_params)
        converter.convert()

        def my_input_fn():
            # A single sample: the input arrays in the archive's iteration order.
            yield tuple(feed_dict.values())

        converter.build(input_fn=my_input_fn)
        converter.save(trt_dir)

        saved_model_loaded = tf.saved_model.load(
            trt_dir, tags=[tf.saved_model.SERVING])
        graph_func = saved_model_loaded.signatures[
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

        def get_runtime():
            # perf_counter is monotonic and high-resolution, unlike time.time().
            tic = time.perf_counter()
            graph_func(**feed_dict)
            return (time.perf_counter() - tic) * 1000

        for _ in range(warmup):  # warmup
            get_runtime()
        times = [get_runtime() for _ in range(repeat)]
        print(np.mean(times), np.min(times), np.max(times))
|
||
if __name__ == "__main__":
    # Command-line entry point for the TF-TRT benchmark.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--prefix', type=str, default="temp")
    arg_parser.add_argument('--device', type=int, default=0)
    cli_args = arg_parser.parse_args()

    # Select the GPU through the environment before the benchmark runs.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{cli_args.device}"
    run_tf(cli_args.prefix)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters