Set up codebase

MasterJH5574 committed Mar 7, 2023
1 parent b4ad9c1 commit 05aa87d
Showing 25 changed files with 3,964 additions and 7 deletions.
202 changes: 195 additions & 7 deletions .gitignore
@@ -1,14 +1,23 @@
dist/
params/
**/*.png

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

*.S
# C extensions
*.so

*.ll
.npm
# Distribution / packaging
.Python
env/
build/
build-*/
develop-eggs/
dist/
downloads/
@@ -27,12 +36,16 @@ share/python-wheels/
*.egg
MANIFEST

.conda/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Generated by python/gen_requirements.py
python/requirements/*.txt

# Installer logs
pip-log.txt
pip-delete-this-directory.txt
@@ -70,9 +83,11 @@ instance/

# Sphinx documentation
docs/_build/
docs/_staging/

# PyBuilder
target/
/target/

# Jupyter Notebook
.ipynb_checkpoints
@@ -116,14 +131,187 @@ venv.bak/

# Rope project settings
.ropeproject
*~
*.pyc
*~
config.mk
config.cmake
Win32
*.dir
perf
*.wasm
.emscripten

# mkdocs documentation
/site
## IOS
DerivedData/

# mypy
.mypy_cache/
.dmypy.json
dmypy.json
## Java
*.class
jvm/*/target/
jvm/*/*/target/
jvm/native/*/generated
jvm/native/src/main/native/org_apache_tvm_native_c_api.h
*.worksheet
*.idea
*.iml
*.classpath
*.project
*.settings
*/node_modules/

## Various settings
*.pbxuser
!default.pbxuser
*.mode1v3
!default.mode1v3
*.mode2v3
!default.mode2v3
*.perspectivev3
!default.perspectivev3
xcuserdata/
.pkl_memoize_*

.emscripten*
.m2

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

## Other
*.moved-aside
*.xccheckout
*.xcscmblueprint
.DS_Store
tags
cscope*
*.lock

# vim temporary files
*.swp
*.swo

# TVM generated code
perf
.bash_history
# *.json
*.params
*.ro
*.onnx
*.h5
synset.txt
cat.jpg
cat.png
docs.tgz
cat.png
*.mlmodel
tvm_u.*
tvm_t.*
# Mac OS X
.DS_Store

# Jetbrain
.idea
.ipython
.jupyter
.nv
.pylint.d
.python_history
.pytest_cache
.local
cmake-build-debug

# Visual Studio
.vs

# Pyre type checker
# Visual Studio Code
.vscode

# tmp file
.nfs*

# keys
*.pem
*.p12
*.pfx
*.cer
*.crt
*.der

# patch sentinel
patched.txt

# Python type checking
.mypy_cache/
.pyre/

# pipenv files
Pipfile
Pipfile.lock

# conda package artifacts
conda/Dockerfile.cuda*
conda/pkg
.node_repl_history
# nix files
.envrc
*.nix

# Docker files
.sudo_as_admin_successful

# Downloaded models/datasets
.tvm_test_data
.dgl
.caffe2

# Local docs build
_docs/
jvm/target
.config/configstore/
.ci-py-scripts/

# Generated Hexagon files
src/runtime/hexagon/rpc/hexagon_rpc.h
src/runtime/hexagon/rpc/hexagon_rpc_skel.c
src/runtime/hexagon/rpc/hexagon_rpc_stub.c

# Local tvm-site checkout
tvm-site/

# Generated docs files
gallery/how_to/work_with_microtvm/micro_tvmc.py

# Test sample data files
!tests/python/ci/sample_prs/*.json

# Used in CI to communicate between Python and Jenkins
.docker-image-names/

# Printed TIR code on disk
*.tir

# GDB history file
.gdb_history
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "3rdparty/tokenizers-wasm"]
	path = 3rdparty/tokenizers-wasm
	url = https://github.com/mithril-security/tokenizers-wasm
1 change: 1 addition & 0 deletions 3rdparty/tokenizers-wasm
Submodule tokenizers-wasm added at a2602d
119 changes: 119 additions & 0 deletions build.py
@@ -0,0 +1,119 @@
from typing import Dict, List, Tuple

import os
import argparse
import pickle
import web_stable_diffusion.trace as trace
import web_stable_diffusion.utils as utils

import tvm
from tvm import relax
from tvm.contrib import tvmjs


def _parse_args():
    args = argparse.ArgumentParser()
    args.add_argument("--target", type=str, default="apple/m2-gpu")
    args.add_argument("--db-path", type=str, default="log_db/")
    args.add_argument("--from-checkpt", type=str, choices=["deploy"], default="")
    args.add_argument("--artifact-path", type=str, default="dist")
    args.add_argument(
        "--use-cache",
        type=int,
        default=1,
        help="Whether to use previously pickled IRModule and skip trace.",
    )
    args.add_argument("--show-build-stage", action="store_true", default=False)
    parsed = args.parse_args()

    if parsed.target == "webgpu":
        parsed.target = tvm.target.Target(
            "webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm"
        )
    else:
        parsed.target = tvm.target.Target(parsed.target, host="llvm")
    return parsed


def trace_models(
    device_str: str,
) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Trace each stage of the diffusion pipeline into a Relax function.
    clip = trace.clip_to_text_embeddings(pipe)
    unet = trace.unet_latents_to_noise_pred(pipe, device_str)
    vae = trace.vae_to_image(pipe)
    concat_embeddings = trace.concat_embeddings()
    image_to_rgba = trace.image_to_rgba()
    scheduler_steps = trace.scheduler_steps()

    # Merge the traced components into a single IRModule and detach the weights.
    mod = utils.merge_irmodules(
        clip, unet, vae, concat_embeddings, image_to_rgba, scheduler_steps
    )
    return relax.frontend.detach_params(mod)


def legalize_and_lift_params(
    mod: tvm.IRModule, model_params: Dict[str, List[tvm.nd.NDArray]], args: Dict
) -> tvm.IRModule:
    """First stage: legalize ops, lift the parameter transformations out of the
    traced module, and split it into a transform module and a deploy module."""
    model_names = ["clip", "unet", "vae"]
    scheduler_func_names = [f"scheduler_step_{i}" for i in range(5)]
    entry_funcs = (
        model_names + scheduler_func_names + ["image_to_rgba", "concat_embeddings"]
    )

    mod = relax.pipeline.get_pipeline()(mod)
    mod = relax.transform.RemoveUnusedFunctions(entry_funcs)(mod)
    mod = relax.transform.LiftTransformParams()(mod)
    if args.show_build_stage:
        mod.show()
    mod_transform, mod_deploy = utils.split_transform_deploy_mod(
        mod, model_names, entry_funcs
    )

    # Precompute the transformed parameters and scheduler constants, and save
    # them under the artifact path; only the deploy module goes to the build stage.
    trace.compute_save_scheduler_consts(args.artifact_path)
    new_params = utils.transform_params(mod_transform, model_params)
    utils.save_params(new_params, args.artifact_path)
    return mod_deploy


def build(mod: tvm.IRModule, args: Dict) -> None:
    from tvm import meta_schedule as ms

    # Apply tuning records from the MetaSchedule database and build for the target.
    db = ms.database.create(work_dir=args.db_path)
    with args.target, db, tvm.transform.PassContext(opt_level=3):
        mod_deploy = relax.transform.MetaScheduleApplyDatabase()(mod)
        ex = relax.build(mod_deploy, args.target)

    target_kind = args.target.kind.default_keys[0]

    if target_kind == "webgpu":
        output_filename = f"stable_diffusion_{target_kind}.wasm"
        tvmjs.export_runtime(args.artifact_path)
    else:
        output_filename = f"stable_diffusion_{target_kind}.so"
    ex.export_library(os.path.join(args.artifact_path, output_filename))


if __name__ == "__main__":
    ARGS = _parse_args()
    os.makedirs(ARGS.artifact_path, exist_ok=True)
    torch_dev_key = utils.detect_available_torch_device()
    cache_path = os.path.join(ARGS.artifact_path, "mod_cache_before_build.pkl")
    use_cache = ARGS.use_cache and os.path.isfile(cache_path)
    if not use_cache:
        mod, params = trace_models(torch_dev_key)
        mod = legalize_and_lift_params(mod, params, ARGS)
        with open(cache_path, "wb") as outfile:
            pickle.dump(mod, outfile)
        print(f"Saved cached module to {cache_path}.")
    else:
        print(
            f"Load cached module from {cache_path} and skip tracing. "
            "You can use --use-cache=0 to retrace."
        )
        with open(cache_path, "rb") as infile:
            mod = pickle.load(infile)
    build(mod, ARGS)
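
For context, the native (non-webgpu) artifact produced by this script can be loaded back through the Relax virtual machine as a quick sanity check. The sketch below is not part of this commit: it assumes the TVM Unity Relax runtime API this project builds against, and the device and file name (a Metal build from the default `apple/m2-gpu` target) are illustrative.

# Illustrative only: load the exported library from build.py on a native target.
# Assumes the TVM Unity Relax VM API; the file name and device are examples.
import tvm
from tvm import relax

dev = tvm.metal(0)  # device matching the default "apple/m2-gpu" build target
lib = tvm.runtime.load_module("dist/stable_diffusion_metal.so")
vm = relax.VirtualMachine(lib, dev)
# Entry functions such as vm["clip"], vm["unet"], and vm["vae"] become callable
# once the parameters saved under the artifact path by utils.save_params are loaded.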