[SD] Add Stencil feature to SD pipeline (nod-ai#1111)
* [WIP] Add ControlNet to SD pipeline

-- This commit adds ControlNet to the SD pipeline.

Signed-off-by: Abhishek Varma <[email protected]>

* [SD] Add ControlNet to img2img + fix bug for img2img scheduler

-- This commit adds ControlNet execution to img2img.
-- It restructures the addition of ControlNet variants.
-- It also fixes a scheduler-selection bug in the img2img pipeline.

Signed-off-by: Abhishek Varma <[email protected]>

* add shark models for stencilSD

* Add Stencil controlled SD in img2img pipeline (nod-ai#1106)

* use shark stencil modules

* adjust diffusers change

* modify to use pipeline

* remove control from unet

* pump stencils through unet

* complete integration in img2img

* fix lint and comments

* [SD] Add ControlNet pipeline + integrate with WebUI + add compiled flow execution

-- This commit creates a dedicated SD pipeline for ControlNet.
-- Integrates it with the img2img WebUI.
-- Integrates the compiled execution flow for ControlNet.

Signed-off-by: Abhishek Varma <[email protected]>

* [SD] Stencil execution

* Remove integration setup

* [SD] Fix args.use_stencil overriding bug + vmfb caching issue

-- This commit fixes the args.use_stencil overriding issue which caused
   the img2img pipeline to pick the wrong set of modules.
-- It also fixes a vmfb caching issue to speed up loading time and pick
   the right set of modules based on a mask (see the sketch after the
   commit message).

Signed-off-by: Abhishek Varma <[email protected]>

---------

Signed-off-by: Abhishek Varma <[email protected]>
Co-authored-by: Abhishek Varma <[email protected]>
Co-authored-by: PhaneeshB <[email protected]>
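
The mask-based vmfb caching described in the message above does not appear in the excerpted diff below. The following is a minimal sketch of the idea under stated assumptions: MODULE_MASKS, required_mask, and cached_vmfb_path are hypothetical names used for illustration, not the repository's actual helpers.

# Hypothetical sketch only: key the compiled-artifact cache on a bitmask of
# the submodules a pipeline needs, so the plain img2img and Stencil/ControlNet
# paths never pick up each other's cached .vmfb files.
import os

MODULE_MASKS = {
    "clip": 0b0001,
    "unet": 0b0010,
    "vae": 0b0100,
    "stencil_adaptor": 0b1000,  # only needed on the Stencil/ControlNet path
}


def required_mask(use_stencil):
    """Return a bitmask of the submodules the selected pipeline needs."""
    mask = MODULE_MASKS["clip"] | MODULE_MASKS["unet"] | MODULE_MASKS["vae"]
    if use_stencil is not None:
        mask |= MODULE_MASKS["stencil_adaptor"]
    return mask


def cached_vmfb_path(cache_dir, model_id, mask):
    """Name the cached flatbuffer after the module mask so the lookup is
    unambiguous and loading can skip recompilation."""
    safe_id = model_id.replace("/", "_")
    return os.path.join(cache_dir, f"{safe_id}_{mask:04b}.vmfb")

Keying the cache on the mask rather than on the model id alone is one way a stencil run can reuse its own artifacts without ever being handed the plain img2img ones.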
3 people authored Mar 1, 2023
1 parent f095745 commit be3cdec
Showing 15 changed files with 840 additions and 66 deletions.
137 changes: 99 additions & 38 deletions apps/stable_diffusion/scripts/img2img.py
@@ -6,6 +6,7 @@
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
StencilPipeline,
get_schedulers,
set_init_device_flags,
utils,
@@ -24,6 +25,7 @@ class Config:
height: int
width: int
device: str
use_stencil: str


img2img_obj = None
@@ -50,6 +52,7 @@ def img2img_inf(
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
@@ -92,8 +95,24 @@
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png

use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
)
args.scheduler = "PNDM"
else:
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
@@ -103,10 +122,10 @@
height,
width,
device,
use_stencil,
)
if not img2img_obj or config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
@@ -123,21 +142,40 @@
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
if use_stencil is not None:
args.use_tuned = False
img2img_obj = StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)

img2img_obj.scheduler = schedulers[scheduler]

@@ -165,6 +203,7 @@ def img2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
save_output_img(out_imgs[0], img_seed, extra_info)
generated_imgs.extend(out_imgs)
@@ -195,11 +234,11 @@ def img2img_inf(
# When the models get uploaded, it should be default to False.
args.import_mlir = True

dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
if args.scheduler != "PNDM":
use_stencil = args.use_stencil
if use_stencil:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
elif args.scheduler != "PNDM":
if "Shark" in args.scheduler:
print(
f"SharkEulerDiscrete scheduler not supported. Switching to PNDM scheduler"
@@ -209,28 +248,49 @@
sys.exit(
"Img2Img works best with PNDM scheduler. Other schedulers are not supported yet."
)
cpu_scheduling = not args.scheduler.startswith("Shark")
dtype = torch.float32 if args.precision == "fp32" else torch.half
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)

scheduler_obj = schedulers[args.scheduler]
image = Image.open(args.img_path).convert("RGB")
seed = utils.sanitize_seed(args.seed)

# Adjust for height and width based on model

img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
if use_stencil:
img2img_obj = StencilPipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
)
else:
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)

start_time = time.time()
generated_imgs = img2img_obj.generate_images(
Expand All @@ -248,6 +308,7 @@ def img2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
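
Both the WebUI path and the __main__ path in the diff above choose between StencilPipeline and Image2ImagePipeline with the same branch and the same long argument list. The sketch below folds that choice into one helper; build_img2img_pipeline is a hypothetical refactor for illustration and assumes only the from_pretrained signatures visible in the diff.

# Hypothetical helper, not repository code: a single place to pick the
# pipeline class based on whether a stencil (ControlNet) is requested.
from apps.stable_diffusion.src import Image2ImagePipeline, StencilPipeline


def build_img2img_pipeline(scheduler_obj, args, use_stencil=None):
    common_args = (
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
    )
    if use_stencil is not None:
        # The WebUI path in the diff also sets args.use_tuned = False before
        # building the Stencil pipeline; a caller would do that first.
        return StencilPipeline.from_pretrained(
            *common_args,
            low_cpu_mem_usage=args.low_cpu_mem_usage,
            use_stencil=use_stencil,
        )
    return Image2ImagePipeline.from_pretrained(
        *common_args,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
    )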
1 change: 1 addition & 0 deletions apps/stable_diffusion/src/__init__.py
@@ -11,5 +11,6 @@
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,
StencilPipeline,
)
from apps.stable_diffusion.src.schedulers import get_schedulers
(Diffs for the remaining 13 changed files are not shown.)
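
With the StencilPipeline export added to __init__.py above, the new pipeline is importable the same way as the existing ones. A brief usage sketch, illustrative only (the DDIM and model choices mirror what img2img.py forces for stencil mode; arguments not visible in the diff are left out):

from apps.stable_diffusion.src import StencilPipeline, get_schedulers

schedulers = get_schedulers("runwayml/stable-diffusion-v1-5")
scheduler_obj = schedulers["DDIM"]  # img2img.py forces DDIM when a stencil is used
# StencilPipeline.from_pretrained(...) takes the same arguments as
# Image2ImagePipeline.from_pretrained(...) plus a use_stencil keyword, and
# generate_images(...) likewise accepts use_stencil=..., as shown in the
# img2img.py diff above.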
