[Web] Multistep DPM-solver for web side (mlc-ai#3)
This PR brings the DPM-solver to the web side. With this PR, the demo
page lets us select between two different schedulers.
MasterJH5574 authored Mar 12, 2023
1 parent 827f615 commit 23f4add
Showing 7 changed files with 157 additions and 37 deletions.
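
Before the per-file diffs, a quick orientation: the convention this change relies on is that the integer value of the new scheduler `<select>` doubles as the index into the `schedulerConstUrl` array, so entry 0 must point at the DPM-solver constants and entry 1 at the PNDM constants. The sketch below is illustrative only (the real wiring is in `web/stable_diffusion.js` further down); the `loadSchedulerConsts` helper name is made up for the example.

```js
// Illustrative sketch of the schedulerId -> constants-file convention.
// Index 0 = multistep DPM-solver (20 steps), index 1 = PNDM (50 steps).
const config = {
  schedulerConstUrl: [
    "dist/scheduler_dpm_solver_multistep_consts.json", // schedulerId = 0
    "dist/scheduler_pndm_consts.json",                 // schedulerId = 1
  ],
};

async function loadSchedulerConsts(schedulerId) {
  // Same fetch pattern the pipeline uses, just narrowed to one entry.
  const url = config.schedulerConstUrl[schedulerId];
  return await (await fetch(url)).json();
}

// Usage: the id comes straight from the <select id="schedulerId"> element.
// loadSchedulerConsts(0) -> consts with num_steps == 20 (DPM-solver)
// loadSchedulerConsts(1) -> consts with num_steps == 50 (PNDM)
```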
3 changes: 2 additions & 1 deletion scripts/build_site.sh
@@ -14,7 +14,8 @@ echo "Copy files..."
cp web/stable_diffusion.html site/_includes
cp web/stable_diffusion.js site/dist

cp dist/scheduler_consts.json site/dist
cp dist/scheduler_pndm_consts.json site/dist
cp dist/scheduler_dpm_solver_multistep_consts.json site/dist
cp dist/stable_diffusion_webgpu.wasm site/dist

cp dist/tvmjs_runtime.wasi.js site/dist
3 changes: 2 additions & 1 deletion scripts/rpc_debug_deploy.sh
@@ -14,7 +14,8 @@ cp web/stable_diffusion.html ${TVM_HOME}/web/dist/www/rpc_plugin.html
cp web/stable_diffusion.js ${TVM_HOME}/web/dist/www/dist/
cp web/local-config.json ${TVM_HOME}/web/dist/www/stable-diffusion-config.json

cp dist/scheduler_consts.json ${TVM_HOME}/web/dist/www/dist/
cp dist/scheduler_pndm_consts.json ${TVM_HOME}/web/dist/www/dist/
cp dist/scheduler_dpm_solver_multistep_consts.json ${TVM_HOME}/web/dist/www/dist/
cp dist/stable_diffusion_webgpu.wasm ${TVM_HOME}/web/dist/www/dist/
cp -rf dist/tokenizers-wasm ${TVM_HOME}/web/dist/www/dist/

5 changes: 4 additions & 1 deletion web/gh-page-config.json
@@ -1,5 +1,8 @@
{
"schedulerConstUrl": "dist/scheduler_consts.json",
"schedulerConstUrl": [
"dist/scheduler_dpm_solver_multistep_consts.json",
"dist/scheduler_pndm_consts.json"
],
"wasmUrl": "dist/stable_diffusion_webgpu.wasm",
"cacheUrl": "https://huggingface.co/mlc-ai/web-sd/resolve/main/web-sd-shards-v1-5/",
"tokenizer": "openai/clip-vit-large-patch14"
5 changes: 4 additions & 1 deletion web/local-config.json
@@ -1,5 +1,8 @@
{
"schedulerConstUrl": "dist/scheduler_consts.json",
"schedulerConstUrl": [
"dist/scheduler_dpm_solver_multistep_consts.json",
"dist/scheduler_pndm_consts.json"
],
"wasmUrl": "dist/stable_diffusion_webgpu.wasm",
"cacheUrl": "web-sd-shards-v1-5/",
"tokenizer": "openai/clip-vit-large-patch14"
8 changes: 8 additions & 0 deletions web/stable_diffusion.html
@@ -35,6 +35,14 @@
</div>

<div>
Select scheduler -
<select name="scheduler" id="schedulerId">
<option value="0">Multi-step DPM Solver (20 steps)</option>
<option value="1">PNDM (50 steps)</option>
</select>

<br>

Render intermediate steps (may slow down execution) -
<select name="vae-cycle" id="vaeCycle">
<option value="-1">No</option>
168 changes: 135 additions & 33 deletions web/stable_diffusion.js
@@ -14,7 +14,7 @@ class TVMPNDMScheduler {

// prebuild constants
// principle: always detach for class members
// to avoid recyling output scope.
// to avoid recycling output scope.
function loadConsts(output, dtype, input) {
for (let t = 0; t < input.length; ++t) {
output.push(
@@ -42,7 +42,7 @@ class TVMPNDMScheduler {
for (let i = 0; i < 5; ++i) {
this.schedulerFunc.push(
tvm.detachFromCurrentScope(
vm.getFunction("scheduler_step_" + i.toString())
vm.getFunction("pndm_scheduler_step_" + i.toString())
)
);
}
@@ -101,6 +101,85 @@
}
}

/**
* Wrapper to handle multistep DPM-solver scheduler
*/
class TVMDPMSolverMultistepScheduler {
constructor(schedulerConsts, latentShape, tvm, device, vm) {
this.timestep = [];
this.alpha = [];
this.sigma = [];
this.c0 = [];
this.c1 = [];
this.c2 = [];
this.lastModelOutput = undefined;
this.convertModelOutputFunc = undefined;
this.stepFunc = undefined;
this.tvm = tvm;

// prebuild constants
// principle: always detach for class members
// to avoid recycling output scope.
function loadConsts(output, dtype, input) {
for (let t = 0; t < input.length; ++t) {
output.push(
tvm.detachFromCurrentScope(
tvm.empty([], dtype, device).copyFrom([input[t]])
)
);
}
}
loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]);
loadConsts(this.alpha, "float32", schedulerConsts["alpha"]);
loadConsts(this.sigma, "float32", schedulerConsts["sigma"]);
loadConsts(this.c0, "float32", schedulerConsts["c0"]);
loadConsts(this.c1, "float32", schedulerConsts["c1"]);
loadConsts(this.c2, "float32", schedulerConsts["c2"]);

this.lastModelOutput = this.tvm.detachFromCurrentScope(
this.tvm.empty(latentShape, "float32", device)
)
this.convertModelOutputFunc = tvm.detachFromCurrentScope(
vm.getFunction("dpm_solver_multistep_scheduler_convert_model_output")
)
this.stepFunc = tvm.detachFromCurrentScope(
vm.getFunction("dpm_solver_multistep_scheduler_step")
)
}

dispose() {
for (let t = 0; t < this.timestep.length; ++t) {
this.timestep[t].dispose();
this.alpha[t].dispose();
this.sigma[t].dispose();
this.c0[t].dispose();
this.c1[t].dispose();
this.c2[t].dispose();
}

this.lastModelOutput.dispose();
this.convertModelOutputFunc.dispose();
this.stepFunc.dispose();
}

step(modelOutput, sample, counter) {
modelOutput = this.convertModelOutputFunc(sample, modelOutput, this.alpha[counter], this.sigma[counter])
const prevLatents = this.stepFunc(
sample,
modelOutput,
this.lastModelOutput,
this.c0[counter],
this.c1[counter],
this.c2[counter],
);
this.lastModelOutput = this.tvm.detachFromCurrentScope(
modelOutput
);

return prevLatents;
}
}

class StableDiffusionPipeline {
constructor(tvm, tokenizer, schedulerConsts, cacheMetadata) {
if (cacheMetadata == undefined) {
@@ -181,10 +260,20 @@ class StableDiffusionPipeline {
* @param prompt Input prompt.
* @param negPrompt Input negative prompt.
* @param progressCallback Callback to check progress.
* @param schedulerId The integer ID of the scheduler to use.
* - 0 for multi-step DPM solver,
* - 1 for PNDM solver.
* @param vaeCycle optionally draw VAE result every cycle iterations.
* @param beginRenderVae Begin rendering VAE after skipping these warmup runs.
*/
async generate(prompt, negPrompt="", progressCallback = undefined, vaeCycle = -1, beginRenderVae = 10) {
async generate(
prompt,
negPrompt = "",
progressCallback = undefined,
schedulerId = 0,
vaeCycle = -1,
beginRenderVae = 10
) {
// Principle: beginScope/endScope in synchronized blocks,
// this helps to recycle intermediate memories
// detach states that needs to go across async boundaries.
@@ -194,10 +283,21 @@
this.tvm.beginScope();
// get latents
const latentShape = [1, 4, 64, 64];
scheduler = new TVMPNDMScheduler(
this.schedulerConsts, latentShape, this.tvm, this.device, this.vm);

var unetNumSteps;
if (schedulerId == 0) {
scheduler = new TVMDPMSolverMultistepScheduler(
this.schedulerConsts[0], latentShape, this.tvm, this.device, this.vm);
unetNumSteps = this.schedulerConsts[0]["num_steps"];
} else {
scheduler = new TVMPNDMScheduler(
this.schedulerConsts[1], latentShape, this.tvm, this.device, this.vm);
unetNumSteps = this.schedulerConsts[1]["num_steps"];
}
const totalNumSteps = unetNumSteps + 2;

if (progressCallback !== undefined) {
progressCallback("clip", 0, 1);
progressCallback("clip", 0, 1, totalNumSteps);
}

const embeddings = this.tvm.withNewScope(() => {
@@ -229,13 +329,12 @@
});
await this.device.sync();
}
const numSteps = 50;
vaeCycle = vaeCycle == -1 ? numSteps: vaeCycle;
vaeCycle = vaeCycle == -1 ? unetNumSteps : vaeCycle;
let lastSync = undefined;

for (let counter = 0; counter < numSteps; ++counter) {
for (let counter = 0; counter < unetNumSteps; ++counter) {
if (progressCallback !== undefined) {
progressCallback("unet", counter, numSteps);
progressCallback("unet", counter, unetNumSteps, totalNumSteps);
}
const timestep = scheduler.timestep[counter];
// recycle noisePred, track latents manually
@@ -258,8 +357,8 @@

// Optionally, we can draw intermediate result of VAE.
if ((counter + 1) % vaeCycle == 0 &&
(counter + 1) != numSteps &&
counter >= beginRenderVae) {
(counter + 1) != unetNumSteps &&
counter >= beginRenderVae) {
this.tvm.withNewScope(() => {
const image = this.vaeToImage(latents, this.vaeParams);
this.tvm.showImage(this.imageToRGBA(image));
@@ -273,7 +372,7 @@
// Stage 2: VAE and draw image
//-----------------------------
if (progressCallback !== undefined) {
progressCallback("vae", 0, 1);
progressCallback("vae", 0, 1, totalNumSteps);
}
this.tvm.withNewScope(() => {
const image = this.vaeToImage(latents, this.vaeParams);
Expand All @@ -282,7 +381,7 @@ class StableDiffusionPipeline {
latents.dispose();
await this.device.sync();
if (progressCallback !== undefined) {
progressCallback("vae", 1, 1);
progressCallback("vae", 1, 1, totalNumSteps);
}
}

@@ -314,7 +413,7 @@ class StableDiffusionInstance {
}

if (document.getElementById("log") !== undefined) {
this.logger = function(message) {
this.logger = function (message) {
console.log(message);
const d = document.createElement("div");
d.innerHTML = message;
@@ -346,10 +445,10 @@ class StableDiffusionInstance {
} else {
document.getElementById(
"gpu-tracker-label").innerHTML = "This browser env do not support WebGPU";
this.reset();
throw Error("This browser env do not support WebGPU");
this.reset();
throw Error("This browser env do not support WebGPU");
}
} catch(err) {
} catch (err) {
document.getElementById("gpu-tracker-label").innerHTML = (
"Find an error initializing the WebGPU device " + err.toString()
);
@@ -381,46 +480,48 @@ class StableDiffusionInstance {
throw Error("asyncInitTVM is not called");
}
if (this.pipeline !== undefined) return;
const schedulerConst = await(await fetch(schedulerConstUrl)).json();
var schedulerConst = []
for (let i = 0; i < schedulerConstUrl.length; ++i) {
schedulerConst.push(await (await fetch(schedulerConstUrl[i])).json())
}
const tokenizer = await tvmjsGlobalEnv.getTokenizer(tokenizerName);
this.pipeline = this.tvm.withNewScope(() => {
return new StableDiffusionPipeline(this.tvm, tokenizer, schedulerConst, this.tvm.cacheMetadata);
});
}

/**
* Async intitialize config
* Async initialize config
*/
async #asyncInitConfig() {
if (this.config !== undefined) return;
this.config = await(await fetch("stable-diffusion-config.json")).json();
this.config = await (await fetch("stable-diffusion-config.json")).json();
}

/**
* Function to create progress callback tracker.
* @returns A progress callback tracker.
*/
#getProgressCallback() {
#getProgressCallback() {
const tstart = performance.now();
function progressCallback(stage, counter, numSteps) {
const totalSteps = 50 + 2;
function progressCallback(stage, counter, numSteps, totalNumSteps) {
const timeElapsed = (performance.now() - tstart) / 1000;
let text = "Generating ... at stage " + stage;
if (stage == "unet") {
counter += 1;
text += " step [" + counter + "/" + numSteps + "]"
}
if (stage == "vae") {
counter += 51;
counter = totalNumSteps;
}
text += ", " + Math.ceil(timeElapsed) + " secs elapsed.";
document.getElementById("progress-tracker-label").innerHTML = text;
document.getElementById("progress-tracker-progress").value = (counter / totalSteps) * 100;
document.getElementById("progress-tracker-progress").value = (counter / totalNumSteps) * 100;
}
return progressCallback;
}

/**
/**
* Async initialize instance.
*/
async asyncInit() {
@@ -442,11 +543,11 @@
this.tvm = tvmInstance;

this.tvm.beginScope();
this.tvm.registerAsyncServerFunc("generate", async (prompt, vaeCycle) => {
this.tvm.registerAsyncServerFunc("generate", async (prompt, schedulerId, vaeCycle) => {
document.getElementById("inputPrompt").value = prompt;
const negPrompt = "";
document.getElementById("negativePrompt").value = "";
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), vaeCycle);
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
});
this.tvm.registerAsyncServerFunc("clearCanvas", async () => {
this.tvm.clearCanvas();
@@ -470,8 +571,9 @@
await this.asyncInit();
const prompt = document.getElementById("inputPrompt").value;
const negPrompt = document.getElementById("negativePrompt").value;
const vaeCycle =document.getElementById("vaeCycle").value;
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), vaeCycle);
const schedulerId = document.getElementById("schedulerId").value;
const vaeCycle = document.getElementById("vaeCycle").value;
await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
} catch (err) {
this.logger("Generate error, " + err.toString());
console.log(err.stack);
@@ -494,11 +596,11 @@

localStableDiffusionInst = new StableDiffusionInstance();

tvmjsGlobalEnv.asyncOnGenerate = async function() {
tvmjsGlobalEnv.asyncOnGenerate = async function () {
await localStableDiffusionInst.generate();
};

tvmjsGlobalEnv.asyncOnRPCServerLoad = async function(tvm) {
tvmjsGlobalEnv.asyncOnRPCServerLoad = async function (tvm) {
const inst = new StableDiffusionInstance();
await inst.asyncInitOnRPCServerLoad(tvm);
};
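
The `step` method of the new scheduler hands its kernel the current sample, the freshly converted model output, the previous model output, and three per-step scalars `c0`, `c1`, `c2`. This diff does not show the TIR kernel itself (it lives in the Python tracing code), but a natural reading of that signature is a second-order multistep update formed as a linear combination of the three tensors. Below is a plain-JS reference sketch under that assumption; the function name and exact formula are assumptions, not taken from the diff.

```js
// Hypothetical CPU reference of what dpm_solver_multistep_scheduler_step is
// assumed to compute on the GPU:
//   prev[i] = c0 * sample[i] + c1 * modelOutput[i] + c2 * lastModelOutput[i]
function dpmStepReference(sample, modelOutput, lastModelOutput, c0, c1, c2) {
  const prev = new Float32Array(sample.length);
  for (let i = 0; i < sample.length; ++i) {
    prev[i] = c0 * sample[i] + c1 * modelOutput[i] + c2 * lastModelOutput[i];
  }
  return prev;
}
```

On the very first step `lastModelOutput` is only a freshly allocated buffer, so presumably the traced `c2` constant for step 0 is zero (a first-order update); that detail is decided in `scheduler_trace.py`, not in this file.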
2 changes: 2 additions & 0 deletions web_stable_diffusion/trace/scheduler_trace.py
@@ -195,6 +195,7 @@ def compute_const_dict() -> Dict[str, List[tvm.nd.NDArray]]:
list_model_output_denom_coeff.append(model_output_denom_coeff.item())

return {
"num_steps": len(timesteps),
"timesteps": timesteps,
"sample_coeff": list_sample_coeff,
"alpha_diff": list_alpha_diff,
@@ -313,6 +314,7 @@ def compute_const_dict() -> Dict[str, List[tvm.nd.NDArray]]:
list_c2.append(c2.item())

return {
"num_steps": len(timesteps),
"timesteps": timesteps,
"alpha": list_alpha,
"sigma": list_sigma,
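
The `num_steps` entry added to both constant dictionaries is what lets the JS side stop hard-coding 50 UNet iterations: the pipeline reads it as `unetNumSteps` and reports `unetNumSteps + 2` total progress steps, one extra for the CLIP stage and one for the VAE stage. A minimal sketch of that accounting (illustrative only; the helper name is made up):

```js
// How the progress-bar total is derived from the traced scheduler constants.
function totalProgressSteps(schedulerConsts, schedulerId) {
  const unetNumSteps = schedulerConsts[schedulerId]["num_steps"]; // 20 or 50
  return unetNumSteps + 2; // +1 for the CLIP stage, +1 for the VAE stage
}
```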