Skip to content

Commit

Permalink
Leverage async shader loading (mlc-ai#14)
Browse files Browse the repository at this point in the history
Speeds up loading time in windows.
  • Loading branch information
tqchen authored Mar 22, 2023
1 parent 51bc67c commit 2cbd64d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
3 changes: 1 addition & 2 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import web_stable_diffusion.utils as utils
from platform import system

import GPUtil
import tvm
from tvm import relax

Expand All @@ -31,7 +30,7 @@ def _parse_args():
if system() == "Darwin":
target = tvm.target.Target("apple/m1-gpu")
else:
has_gpu = len(GPUtil.getGPUs()) > 0
has_gpu = tvm.cuda().exist
target = tvm.target.Target("cuda" if has_gpu else "llvm")
print(f"Automatically configuring target: {target}")
parsed.target = tvm.target.Target(target, host="llvm")
Expand Down
15 changes: 12 additions & 3 deletions web/stable_diffusion.js
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,14 @@ class StableDiffusionPipeline {
}
return this.tvm.empty([1, this.maxTokenLength], "int32", this.device).copyFrom(inputIDs);
}

/**
* async preload webgpu pipelines when possible.
*/
async asyncLoadWebGPUPiplines() {
await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
}

/**
* Run generation pipeline.
*
Expand Down Expand Up @@ -458,11 +466,11 @@ class StableDiffusionInstance {
}

this.tvm = tvm;
function fetchProgressCallback(report) {
function initProgressCallback(report) {
document.getElementById("progress-tracker-label").innerHTML = report.text;
document.getElementById("progress-tracker-progress").value = (report.fetchedBytes / report.totalBytes) * 100;
document.getElementById("progress-tracker-progress").value = report.progress * 100;
}
tvm.registerFetchProgressCallback(fetchProgressCallback);
tvm.registerInitProgressCallback(initProgressCallback);
if (!cacheUrl.startsWith("http")) {
cacheUrl = new URL(cacheUrl, document.URL).href;
}
Expand All @@ -488,6 +496,7 @@ class StableDiffusionInstance {
this.pipeline = this.tvm.withNewScope(() => {
return new StableDiffusionPipeline(this.tvm, tokenizer, schedulerConst, this.tvm.cacheMetadata);
});
await this.pipeline.asyncLoadWebGPUPiplines();
}

/**
Expand Down

0 comments on commit 2cbd64d

Please sign in to comment.