Skip to content

Commit

Permalink
[Debug] Add RPC lib compare tool (mlc-ai#6)
Browse files Browse the repository at this point in the history
This PR adds a few debug comparison tools for pair-testing libs.
  • Loading branch information
tqchen authored Mar 15, 2023
1 parent 09b71a2 commit d5f1708
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 13 deletions.
27 changes: 23 additions & 4 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _parse_args():
return parsed


def debug_dump(mod, name, args):
def debug_dump_script(mod, name, args):
"""Debug dump mode"""
if not args.debug_dump:
return
Expand All @@ -54,6 +54,24 @@ def debug_dump(mod, name, args):
print(f"Dump mod to {dump_path}")


def debug_dump_shader(ex, name, args):
    """Dump the source of a built executable's nested module for debugging.

    Does nothing unless ``args.debug_dump`` is truthy. Otherwise the source
    returned by ``ex.mod.imported_modules[0].imported_modules[0].get_source()``
    is written to ``<args.artifact_path>/debug/<name><suffix>``.

    Parameters
    ----------
    ex:
        The built executable; only ``ex.mod.imported_modules`` is accessed.
    name: str
        Base file name (without extension) of the dump file.
    args:
        Parsed build arguments; ``debug_dump``, ``target`` and
        ``artifact_path`` are read.
    """
    if not args.debug_dump:
        return
    target_kind = args.target.kind.default_keys[0]
    # File extension per target kind; unknown targets fall back to ".txt".
    suffix_map = {
        "webgpu": ".wgsl",
        "cuda": ".cu",
        "metal": ".mtl",
    }
    suffix = suffix_map.get(target_kind, ".txt")
    dump_dir = os.path.join(args.artifact_path, "debug")
    dump_path = os.path.join(dump_dir, name + suffix)
    # NOTE(review): presumably the second-level imported module holds the
    # device (shader/kernel) code -- confirm against the TVM build layout.
    source = ex.mod.imported_modules[0].imported_modules[0].get_source()
    # Make sure the dump directory exists so open() cannot fail with
    # FileNotFoundError on a fresh artifact path.
    os.makedirs(dump_dir, exist_ok=True)
    with open(dump_path, "w") as outfile:
        outfile.write(source)
    print(f"Dump shader to {dump_path}")


def trace_models(
device_str: str,
) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
Expand Down Expand Up @@ -95,12 +113,11 @@ def legalize_and_lift_params(
mod = relax.pipeline.get_pipeline()(mod)
mod = relax.transform.RemoveUnusedFunctions(entry_funcs)(mod)
mod = relax.transform.LiftTransformParams()(mod)

mod_transform, mod_deploy = utils.split_transform_deploy_mod(
mod, model_names, entry_funcs
)

debug_dump(mod_transform, "mod_lift_params.py", args)
debug_dump_script(mod_transform, "mod_lift_params.py", args)

trace.compute_save_scheduler_consts(args.artifact_path)
new_params = utils.transform_params(mod_transform, model_params)
Expand All @@ -115,7 +132,7 @@ def build(mod: tvm.IRModule, args: Dict) -> None:
with args.target, db, tvm.transform.PassContext(opt_level=3):
mod_deploy = relax.transform.MetaScheduleApplyDatabase()(mod)

debug_dump(mod_deploy, "mod_build_stage.py", args)
debug_dump_script(mod_deploy, "mod_build_stage.py", args)

ex = relax.build(mod_deploy, args.target)

Expand All @@ -125,6 +142,8 @@ def build(mod: tvm.IRModule, args: Dict) -> None:
output_filename = f"stable_diffusion_{target_kind}.wasm"
else:
output_filename = f"stable_diffusion_{target_kind}.so"

debug_dump_shader(ex, f"stable_diffusion_{target_kind}", args)
ex.export_library(os.path.join(args.artifact_path, output_filename))


Expand Down
106 changes: 106 additions & 0 deletions tests/debug/compare_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import List

import argparse
import os

import numpy as np
import tvm
from tvm import relax
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
from web_stable_diffusion import utils
from web_stable_diffusion.rpc_testing import connect_to_proxy


def load_checkpt(args, name):
    """Load the array saved as ``<args.artifact_path>/debug/<name>.npy``."""
    debug_dir = os.path.join(args.artifact_path, "debug")
    checkpt_file = os.path.join(debug_dir, f"{name}.npy")
    return np.load(checkpt_file)


class LibCompare(LibCompareVMInstrument):
    """Comparison instrument with optional skipping and one-off timing.

    Parameters
    ----------
    mod:
        Module forwarded to the base instrument; also used for
        ``time_evaluator``.
    device:
        Device forwarded to the base instrument.
    time_eval: bool
        When truthy, each function name is timed once after it is compared.
    skip_rounds: int
        Number of leading instrument calls for which validation is skipped.
    """

    def __init__(self, mod, device, time_eval, skip_rounds):
        super().__init__(mod, device, True)
        self.skip_rounds = skip_rounds
        self.time_eval = time_eval
        # Maps function name -> timing result, filled in lazily by compare().
        self.time_eval_results = {}

    def skip_instrument(self, func, name, before_run, ret_val, *args):
        """Return True while fewer than ``skip_rounds`` calls were counted."""
        # NOTE(review): ``self.counter`` is not assigned in this class;
        # presumably the base class initializes it -- confirm.
        if self.counter >= self.skip_rounds:
            return False
        self.counter += 1
        print(f"[{self.counter}] Skip validating {name}..")
        return True

    def compare(
        self,
        name: str,
        ref_args: List[tvm.nd.NDArray],
        new_args: List[tvm.nd.NDArray],
        ret_indices: List[int],
    ):
        """Run the base comparison, then time ``name`` once if enabled."""
        super().compare(name, ref_args, new_args, ret_indices)

        already_timed = name in self.time_eval_results
        if not self.time_eval or already_timed:
            return

        timer = self.mod.time_evaluator(name)
        res = timer(*new_args)
        self.time_eval_results[name] = res
        print(f"Time-eval result {name} on {self.device}: {res}")


class TestState:
    """Shared setup for a comparison run between two built libraries.

    Loads the primary library into a relax virtual machine, loads the
    library to compare against (through the RPC proxy when the comparison
    device is ``"webgpu"``, otherwise from a local ``.so``), loads the
    parameters, and installs a ``LibCompare`` instrument on the VM.
    """

    def __init__(self, args):
        # The primary side is always loaded from a local .so below.
        assert args.primary_device != "webgpu"

        self.primary_device = tvm.device(args.primary_device)
        primary_lib_path = (
            f"{args.artifact_path}/stable_diffusion_{args.primary_device}.so"
        )
        primary_ex = tvm.runtime.load_module(primary_lib_path)
        self.vm = relax.VirtualMachine(primary_ex, self.primary_device)

        if args.cmp_device != "webgpu":
            # Local comparison target: no RPC session is involved.
            self.sess = None
            cmp_lib_path = (
                f"{args.artifact_path}/stable_diffusion_{args.cmp_device}.so"
            )
            self.lib = tvm.runtime.load_module(cmp_lib_path)
            self.cmp_device = tvm.device(args.cmp_device)
        else:
            # Remote comparison target reached through the RPC proxy.
            wasm_path = f"{args.artifact_path}/stable_diffusion_webgpu.wasm"
            self.sess = connect_to_proxy(wasm_path)
            self.lib = self.sess.system_lib()
            self.cmp_device = self.sess.webgpu()

        self.const_params_dict = utils.load_params(
            args.artifact_path, self.primary_device
        )
        self.cmp_instrument = LibCompare(
            self.lib,
            self.cmp_device,
            time_eval=args.time_eval,
            skip_rounds=args.skip_rounds,
        )
        self.vm.set_instrument(self.cmp_instrument)


def main_vae(args):
    """Run the ``vae`` VM function on the saved ``vae_input`` checkpoint."""
    state = TestState(args)
    vae_func = state.vm["vae"]
    checkpt = load_checkpt(args, "vae_input")
    vae_input = tvm.nd.array(checkpt, state.primary_device)
    vae_params = state.const_params_dict["vae"]
    vae_func(vae_input, vae_params)


def _parse_args():
args = argparse.ArgumentParser()
args.add_argument("--artifact-path", type=str, default="dist")
args.add_argument("--primary-device", type=str, default="auto")
args.add_argument("--cmp-device", type=str, required=True)
args.add_argument("--stage", type=str, choices=["unet", "vae"], required=True)
args.add_argument("--counter", type=int, default=0)
args.add_argument("--time-eval", default=False, action="store_true")
args.add_argument("--skip-rounds", type=int, default=0)
parsed = args.parse_args()

if parsed.primary_device == "auto":
if tvm.cuda().exist:
parsed.primary_device = "cuda"
elif tvm.metal().exist:
parsed.primary_device = "metal"
else:
raise ValueError("Cannot auto deduce device-name, please set it")
return parsed


if __name__ == "__main__":
    args = _parse_args()
    # Only the "vae" stage has a driver here; any other stage value is
    # rejected even though the parser also accepts "unet".
    if args.stage != "vae":
        raise ValueError(f"Unknown stage {args.stage}")
    main_vae(args)
    print(f"All pass running stage {args.stage}, counter={args.counter}")
30 changes: 21 additions & 9 deletions web_stable_diffusion/rpc_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,26 @@ def _convert_return(data):
return wrapped_f


def connect_to_proxy(wasm_path):
    """Connect to the default RPC proxy with a wasm session.

    The proxy address is read from the ``TVM_RPC_PROXY_HOST`` and
    ``TVM_RPC_PROXY_PORT`` environment variables, defaulting to
    ``127.0.0.1:9090``.

    Parameters
    ----------
    wasm_path: str
        The path to wasm

    Returns
    -------
    remote
        The session object returned by ``rpc.connect``.
    """
    proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1")
    proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090"))
    # Read the binary inside a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(wasm_path, "rb") as wasm_file:
        wasm_binary = wasm_file.read()
    remote = rpc.connect(
        proxy_host,
        proxy_port,
        key="wasm",
        session_constructor_args=["rpc.WasmSession", wasm_binary],
    )
    return remote


class WebGPUDebugSession(RPCBaseDebugSession):
"""Remote debug session to handle webgpu.
Expand All @@ -129,15 +149,7 @@ class WebGPUDebugSession(RPCBaseDebugSession):
"""

def __init__(self, wasm_path):
proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1")
proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090"))
wasm_binary = open(wasm_path, "rb").read()
remote = rpc.connect(
proxy_host,
proxy_port,
key="wasm",
session_constructor_args=["rpc.WasmSession", wasm_binary],
)
remote = connect_to_proxy(wasm_path)
super(WebGPUDebugSession, self).__init__(
remote, remote.system_lib(), remote.webgpu()
)

0 comments on commit d5f1708

Please sign in to comment.