Skip to content

Commit

Permalink
[Scheduler] Multistep DPM-solver for Python side (mlc-ai#2)
Browse files Browse the repository at this point in the history
This PR introduces the multistep DPM-solver tracing and the Python side
runtime. With this PR, we can now deploy the stable diffusion locally
with the multistep DPM-solver, which is set to generate an image with a
fixed number (20) of steps. By reducing the number of UNet iterations
from 50 to 20, the image generation is accelerated by 2.5x.

This PR also reorganizes the code around schedulers, so that the codebase
is more manageable and extensible as additional schedulers are added in
the future.

This PR breaks the web deployment scripts. They will be fixed soon, when
the multistep DPM-solver for the web side is introduced.
  • Loading branch information
MasterJH5574 authored Mar 12, 2023
1 parent a07f441 commit 827f615
Show file tree
Hide file tree
Showing 7 changed files with 484 additions and 215 deletions.
39 changes: 24 additions & 15 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ def _parse_args():
return parsed


def debug_dump(mod, name, args):
    """Write the TVMScript of ``mod`` to the debug directory, if enabled.

    No-op unless ``args.debug_dump`` is set. The dump is written to
    ``<args.artifact_path>/debug/<name>``.
    """
    if not args.debug_dump:
        return
    dump_path = os.path.join(args.artifact_path, "debug", name)
    # Create the debug directory on first use so the dump does not fail
    # on a freshly laid-out artifact directory.
    os.makedirs(os.path.dirname(dump_path), exist_ok=True)
    with open(dump_path, "w") as outfile:
        outfile.write(mod.script(show_meta=True))
    print(f"Dump mod to {dump_path}")


def trace_models(
device_str: str,
) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
Expand All @@ -46,10 +56,15 @@ def trace_models(
vae = trace.vae_to_image(pipe)
concat_embeddings = trace.concat_embeddings()
image_to_rgba = trace.image_to_rgba()
scheduler_steps = trace.scheduler_steps()
schedulers = [scheduler.scheduler_steps() for scheduler in trace.schedulers]

mod = utils.merge_irmodules(
clip, unet, vae, concat_embeddings, image_to_rgba, scheduler_steps
clip,
unet,
vae,
concat_embeddings,
image_to_rgba,
*schedulers,
)
return relax.frontend.detach_params(mod)

Expand All @@ -59,7 +74,11 @@ def legalize_and_lift_params(
) -> tvm.IRModule:
"""First-stage: Legalize ops and trace"""
model_names = ["clip", "unet", "vae"]
scheduler_func_names = [f"scheduler_step_{i}" for i in range(5)]
scheduler_func_names = [
name
for scheduler in trace.schedulers
for name in scheduler.scheduler_steps_func_names()
]
entry_funcs = (
model_names + scheduler_func_names + ["image_to_rgba", "concat_embeddings"]
)
Expand All @@ -68,28 +87,18 @@ def legalize_and_lift_params(
mod = relax.transform.RemoveUnusedFunctions(entry_funcs)(mod)
mod = relax.transform.LiftTransformParams()(mod)

debug_dump(mod_deploy, "mod_lift_params.py", args)

mod_transform, mod_deploy = utils.split_transform_deploy_mod(
mod, model_names, entry_funcs
)

debug_dump(mod_transform, "mod_lift_params.py", args)

trace.compute_save_scheduler_consts(args.artifact_path)
new_params = utils.transform_params(mod_transform, model_params)
utils.save_params(new_params, args.artifact_path)
return mod_deploy


def debug_dump(mod, name, args):
    """Dump the TVMScript of ``mod`` into the debug directory when enabled."""
    if args.debug_dump:
        target = os.path.join(args.artifact_path, "debug", name)
        script_text = mod.script(show_meta=True)
        with open(target, "w") as handle:
            handle.write(script_text)
        print(f"Dump mod to {target}")


def build(mod: tvm.IRModule, args: Dict) -> None:
from tvm import meta_schedule as ms

Expand Down
84 changes: 25 additions & 59 deletions deploy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Type

import argparse
import time
Expand All @@ -7,11 +7,10 @@
from transformers import CLIPTokenizer

import web_stable_diffusion.utils as utils
import web_stable_diffusion.runtime as runtime

import os
import json
import tvm
import numpy as np
from tvm import relax

from tqdm import tqdm
Expand All @@ -27,6 +26,12 @@ def _parse_args():
"--prompt", type=str, default="A photo of an astronaut riding a horse on mars."
)
args.add_argument("--negative-prompt", type=str, default="")
args.add_argument(
"--scheduler",
type=str,
choices=[scheduler.scheduler_name for scheduler in runtime.schedulers],
default=runtime.DPMSolverMultistepScheduler.scheduler_name,
)
parsed = args.parse_args()
if parsed.device_name == "auto":
if tvm.cuda().exist:
Expand All @@ -38,64 +43,12 @@ def _parse_args():
return parsed


class PNDMScheduler:
    """Runtime driver for the PNDM diffusion scheduler.

    Loads precomputed per-step constants from ``scheduler_consts.json`` and
    advances the latents one denoising step per ``step`` call via the traced
    Relax functions ``scheduler_step_0`` .. ``scheduler_step_4``.
    """

    def __init__(self, artifact_path: str, device) -> None:
        # Scheduler constants are precomputed at trace/build time and stored
        # next to the model artifacts.
        with open(f"{artifact_path}/scheduler_consts.json", "r") as file:
            jsoncontent = file.read()
        scheduler_consts = json.loads(jsoncontent)

        def f_convert(data, dtype):
            # Upload each per-step constant as a TVM NDArray on `device`.
            return [tvm.nd.array(np.array(t, dtype=dtype), device) for t in data]

        self.timesteps = f_convert(scheduler_consts["timesteps"], "int32")
        self.sample_coeff = f_convert(scheduler_consts["sample_coeff"], "float32")
        self.alpha_diff = f_convert(scheduler_consts["alpha_diff"], "float32")
        self.model_output_denom_coeff = f_convert(
            scheduler_consts["model_output_denom_coeff"], "float32"
        )

        # History of the last four UNet outputs, seeded with placeholder
        # buffers. NOTE(review): all four entries initially alias one array —
        # presumably harmless because the traced step functions only read
        # them; confirm.
        self.ets: List[tvm.nd.NDArray] = [
            tvm.nd.empty((1, 4, 64, 64), "float32", device)
        ] * 4
        # Sample cached at counter == 0 and restored at counter == 1.
        self.cur_sample: tvm.nd.NDArray

    def step(
        self,
        vm: relax.VirtualMachine,
        model_output: tvm.nd.NDArray,
        sample: tvm.nd.NDArray,
        counter: int,
    ) -> tvm.nd.NDArray:
        """Advance the latents one PNDM step and return the new latents."""
        # Keep only the three most recent outputs before appending; step 1
        # reuses the history recorded at step 0.
        if counter != 1:
            self.ets = self.ets[-3:]
            self.ets.append(model_output)

        if counter == 0:
            self.cur_sample = sample
        elif counter == 1:
            sample = self.cur_sample

        # Steps 0..3 use warm-up variants; from step 4 on, variant 4 is
        # reused for every remaining iteration.
        prev_latents = vm[f"scheduler_step_{min(counter, 4)}"](
            sample,
            model_output,
            self.sample_coeff[counter],
            self.alpha_diff[counter],
            self.model_output_denom_coeff[counter],
            self.ets[0],
            self.ets[1],
            self.ets[2],
            self.ets[3],
        )

        return prev_latents


class TVMSDPipeline:
def __init__(
self,
vm: relax.VirtualMachine,
tokenizer: CLIPTokenizer,
scheduler: PNDMScheduler,
scheduler: runtime.Scheduler,
tvm_device,
param_dict,
debug_dump_dir,
Expand All @@ -119,12 +72,13 @@ def wrapped_f(*args):
self.debug_dump_dir = debug_dump_dir

def debug_dump(self, name, arr):
import numpy as np

if self.debug_dump_dir:
np.save(f"{self.debug_dump_dir}/{name}.npy", arr.numpy())

def __call__(self, prompt: str, negative_prompt: str = ""):
# height = width = 512
num_inference_steps = 50

list_text_embeddings = []
for text in [negative_prompt, prompt]:
Expand Down Expand Up @@ -153,7 +107,7 @@ def __call__(self, prompt: str, negative_prompt: str = ""):
)
latents = tvm.nd.array(latents.numpy(), self.tvm_device)

for i in tqdm(range(num_inference_steps)):
for i in tqdm(range(len(self.scheduler.timesteps))):
t = self.scheduler.timesteps[i]
self.debug_dump(f"unet_input_{i}", latents)
self.debug_dump(f"timestep_{i}", t)
Expand All @@ -168,6 +122,18 @@ def __call__(self, prompt: str, negative_prompt: str = ""):
return Image.fromarray(image.numpy().view("uint8").reshape(512, 512, 4))


def get_scheduler_type(scheduler_name: str) -> Type[runtime.Scheduler]:
    """Return the scheduler class registered under ``scheduler_name``.

    Raises ValueError when no registered scheduler carries that name.
    """
    matches = [
        candidate
        for candidate in runtime.schedulers
        if candidate.scheduler_name == scheduler_name
    ]
    if matches:
        return matches[0]

    scheduler_names = [candidate.scheduler_name for candidate in runtime.schedulers]
    raise ValueError(
        f'"{scheduler_name}" is an unsupported scheduler name. The list of '
        f"supported scheduler names is {scheduler_names}"
    )


def deploy_to_pipeline(args) -> None:
device = tvm.device(args.device_name)
const_params_dict = utils.load_params(args.artifact_path, device)
Expand All @@ -183,7 +149,7 @@ def deploy_to_pipeline(args) -> None:
pipe = TVMSDPipeline(
vm=vm,
tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
scheduler=PNDMScheduler(args.artifact_path, device),
scheduler=get_scheduler_type(args.scheduler)(args.artifact_path, device),
tvm_device=device,
param_dict=const_params_dict,
debug_dump_dir=debug_dump_dir,
Expand Down
1 change: 1 addition & 0 deletions web_stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import runtime
from . import trace
from . import utils
1 change: 1 addition & 0 deletions web_stable_diffusion/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .scheduler_runtime import *
129 changes: 129 additions & 0 deletions web_stable_diffusion/runtime/scheduler_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import ClassVar, List, Type

import json
import numpy as np

import tvm
from tvm import relax


class Scheduler:
    """Abstract base class for runtime diffusion schedulers.

    Concrete subclasses load their traced constants in ``__init__`` and
    advance the latents one denoising step per ``step`` call.
    """

    # Identifier used to select a scheduler from the command line.
    scheduler_name: ClassVar[str]
    # Per-step timestep tensors fed to the UNet; their count determines the
    # number of denoising iterations run by the pipeline.
    timesteps: List[tvm.nd.NDArray]

    def __init__(self, artifact_path: str, device) -> None:
        """Load scheduler constants from ``artifact_path`` onto ``device``."""
        raise NotImplementedError()

    def step(
        self,
        vm: relax.VirtualMachine,
        model_output: tvm.nd.NDArray,
        sample: tvm.nd.NDArray,
        counter: int,
    ) -> tvm.nd.NDArray:
        """Compute the previous-step latents from the current UNet output."""
        raise NotImplementedError()


class PNDMScheduler(Scheduler):
    """Runtime driver for the PNDM diffusion scheduler.

    Loads precomputed per-step constants from ``scheduler_pndm_consts.json``
    and advances the latents one denoising step per ``step`` call via the
    traced Relax functions ``pndm_scheduler_step_0`` ..
    ``pndm_scheduler_step_4``.
    """

    scheduler_name = "pndm"

    def __init__(self, artifact_path: str, device) -> None:
        with open(f"{artifact_path}/scheduler_pndm_consts.json", "r") as file:
            # json.load parses the file object directly — no need to slurp
            # the contents into a string first.
            scheduler_consts = json.load(file)

        def f_convert(data, dtype):
            # Upload each per-step constant as a TVM NDArray on `device`.
            return [tvm.nd.array(np.array(t, dtype=dtype), device) for t in data]

        self.timesteps = f_convert(scheduler_consts["timesteps"], "int32")
        self.sample_coeff = f_convert(scheduler_consts["sample_coeff"], "float32")
        self.alpha_diff = f_convert(scheduler_consts["alpha_diff"], "float32")
        self.model_output_denom_coeff = f_convert(
            scheduler_consts["model_output_denom_coeff"], "float32"
        )

        # History of the last four UNet outputs, seeded with placeholder
        # buffers. NOTE(review): all four entries initially alias one array —
        # presumably harmless because the traced step functions only read
        # them; confirm.
        self.ets: List[tvm.nd.NDArray] = [
            tvm.nd.empty((1, 4, 64, 64), "float32", device)
        ] * 4
        # Sample cached at counter == 0 and restored at counter == 1.
        self.cur_sample: tvm.nd.NDArray

    def step(
        self,
        vm: relax.VirtualMachine,
        model_output: tvm.nd.NDArray,
        sample: tvm.nd.NDArray,
        counter: int,
    ) -> tvm.nd.NDArray:
        """Advance the latents one PNDM step and return the new latents."""
        # Keep only the three most recent outputs before appending; step 1
        # reuses the history recorded at step 0.
        if counter != 1:
            self.ets = self.ets[-3:]
            self.ets.append(model_output)

        if counter == 0:
            self.cur_sample = sample
        elif counter == 1:
            sample = self.cur_sample

        # Steps 0..3 use warm-up variants; from step 4 on, variant 4 is
        # reused for every remaining iteration.
        prev_latents = vm[f"pndm_scheduler_step_{min(counter, 4)}"](
            sample,
            model_output,
            self.sample_coeff[counter],
            self.alpha_diff[counter],
            self.model_output_denom_coeff[counter],
            self.ets[0],
            self.ets[1],
            self.ets[2],
            self.ets[3],
        )

        return prev_latents


class DPMSolverMultistepScheduler(Scheduler):
    """Runtime driver for the multistep DPM-solver scheduler.

    Loads precomputed per-step constants from
    ``scheduler_dpm_solver_multistep_consts.json`` and applies one update per
    ``step`` call through the traced Relax functions.
    """

    scheduler_name = "multistep-dpm-solver"

    def __init__(self, artifact_path: str, device) -> None:
        with open(
            f"{artifact_path}/scheduler_dpm_solver_multistep_consts.json", "r"
        ) as file:
            # json.load parses the file object directly — no need to slurp
            # the contents into a string first.
            scheduler_consts = json.load(file)

        def f_convert(data, dtype):
            # Upload each per-step constant as a TVM NDArray on `device`.
            return [tvm.nd.array(np.array(t, dtype=dtype), device) for t in data]

        self.timesteps = f_convert(scheduler_consts["timesteps"], "int32")
        self.alpha = f_convert(scheduler_consts["alpha"], "float32")
        self.sigma = f_convert(scheduler_consts["sigma"], "float32")
        self.c0 = f_convert(scheduler_consts["c0"], "float32")
        self.c1 = f_convert(scheduler_consts["c1"], "float32")
        self.c2 = f_convert(scheduler_consts["c2"], "float32")

        # Previous step's converted model output. NOTE(review): starts as an
        # uninitialized buffer — presumably the step-0 coefficient makes the
        # first update ignore it; confirm against the tracing code.
        self.last_model_output: tvm.nd.NDArray = tvm.nd.empty(
            (1, 4, 64, 64), "float32", device
        )

    def step(
        self,
        vm: relax.VirtualMachine,
        model_output: tvm.nd.NDArray,
        sample: tvm.nd.NDArray,
        counter: int,
    ) -> tvm.nd.NDArray:
        """Advance the latents one DPM-solver step and return the result."""
        # Convert the raw UNet output into the form the traced DPM-solver
        # update expects.
        model_output = vm["dpm_solver_multistep_scheduler_convert_model_output"](
            sample, model_output, self.alpha[counter], self.sigma[counter]
        )
        prev_latents = vm["dpm_solver_multistep_scheduler_step"](
            sample,
            model_output,
            self.last_model_output,
            self.c0[counter],
            self.c1[counter],
            self.c2[counter],
        )
        # Keep this step's converted output for the next multistep update.
        self.last_model_output = model_output
        return prev_latents


########################################################################

schedulers: List[Type[Scheduler]] = [DPMSolverMultistepScheduler, PNDMScheduler]
3 changes: 2 additions & 1 deletion web_stable_diffusion/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .model_trace import *
from .scheduler_trace import compute_save_scheduler_consts, scheduler_steps
from .scheduler_trace import Scheduler, DPMSolverMultistepScheduler, PNDMScheduler
from .scheduler_trace import compute_save_scheduler_consts, schedulers
Loading

0 comments on commit 827f615

Please sign in to comment.