Commit 9acfb6d

feat: add clip_vision annotator, support non-image input
Mikubill committed Mar 5, 2023
1 parent bff62ed commit 9acfb6d
Showing 11 changed files with 172 additions and 59 deletions.

README.md (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 ## sd-webui-controlnet
-(WIP) WebUI extension for ControlNet
+(WIP) WebUI extension for ControlNet and T2I-Adapter
 
 This extension is for AUTOMATIC1111's [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui); it allows the Web UI to add [ControlNet](https://github.com/lllyasviel/ControlNet) to the original Stable Diffusion model when generating images. The addition is on-the-fly; no model merging is required.

annotator/clip/__init__.py (23 additions & 0 deletions, new file)

@@ -0,0 +1,23 @@
+from transformers import CLIPProcessor, CLIPVisionModel
+from modules import devices
+
+version = 'openai/clip-vit-large-patch14'
+clip_proc = None
+clip_vision_model = None
+
+def apply_clip(img):
+    global clip_proc, clip_vision_model
+
+    if clip_vision_model is None:
+        clip_proc = CLIPProcessor.from_pretrained(version)
+        clip_vision_model = CLIPVisionModel.from_pretrained(version)
+
+    clip_vision_model = clip_vision_model.to(devices.get_device_for("controlnet"))
+    style_for_clip = clip_proc(images=img, return_tensors="pt")['pixel_values']
+    style_feat = clip_vision_model(style_for_clip.to(devices.get_device_for("controlnet")))['last_hidden_state']
+    return style_feat
+
+def unload_clip_model():
+    global clip_proc, clip_vision_model
+    if clip_vision_model is not None:
+        clip_vision_model.cpu()
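
For orientation, here is a minimal usage sketch of the new annotator (illustrative only: the file name and printed shape are assumptions, and it must run inside the webui environment where `modules.devices` is importable):

```python
# Hypothetical usage of the new clip_vision annotator (not part of this commit).
from PIL import Image
from annotator.clip import apply_clip, unload_clip_model

img = Image.open("style_reference.png").convert("RGB")  # any RGB style image
feat = apply_clip(img)   # CLIP ViT-L/14 last_hidden_state for the image
print(feat.shape)        # expected: torch.Size([1, 257, 1024]) -- 256 patches + CLS
unload_clip_model()      # moves the vision model back to CPU to free VRAM
```

Note the returned feature is a token sequence, not an image; the `is_image` flag introduced in scripts/controlnet.py below exists precisely so this non-image output can skip the detect-map resizing path.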

models/color_adapter_v14.yaml (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 model:
-  target: tencentarc.t21_adapter
+  target: scripts.adapter.Adapter_light
   params:
     channels: [320, 640, 1280, 1280]
     nums_rb: 4

models/style_adapter_v14.yaml (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 model:
-  target: tencentarc.t21_adapter
+  target: scripts.adapter.StyleAdapter
   params:
     width: 1024
     context_dim: 768

models/t2iadapter_color_sd14v1.yaml (6 additions & 0 deletions, new file)

@@ -0,0 +1,6 @@
+model:
+  target: scripts.adapter.Adapter_light
+  params:
+    channels: [320, 640, 1280, 1280]
+    nums_rb: 4
+    cin: 192

models/t2iadapter_keypose_sd14v1.yaml (9 additions & 0 deletions, new file)

@@ -0,0 +1,9 @@
+model:
+  target: tencentarc.t21_adapter
+  params:
+    channels: [320, 640, 1280, 1280]
+    nums_rb: 2
+    ksize: 1
+    sk: true
+    cin: 192
+    use_conv: false
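
[Note: `tencentarc.t21_adapter` is not an importable module in this repository. With the loader change in scripts/adapter.py below, `get_obj_from_str` raises ImportError for this target and `PlugableAdapter` falls back to the default `Adapter` class, whose constructor accepts these params (`ksize`, `sk`, `use_conv`).]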

models/t2iadapter_style_sd14v1.yaml (8 additions & 0 deletions, new file)

@@ -0,0 +1,8 @@
+model:
+  target: scripts.adapter.StyleAdapter
+  params:
+    width: 1024
+    context_dim: 768
+    num_head: 8
+    n_layes: 3
+    num_token: 8
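
[Note: `n_layes` is left as spelled deliberately; these params are splatted into the `StyleAdapter` constructor via `model(**config.model.params)`, and that constructor uses this spelling for its layer-count argument, so "correcting" the key would break loading.]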

scripts/adapter.py (25 changes: 17 additions & 8 deletions)

@@ -2,6 +2,7 @@

 import torch
 import torch.nn as nn
+import importlib
 from collections import OrderedDict
 
 from omegaconf import OmegaConf
@@ -55,21 +56,28 @@ def get_node_name(name, parent_name):
     if p != parent_name:
         return False, ''
     return True, name[len(parent_name):]
 
 
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
 class PlugableAdapter(nn.Module):
     def __init__(self, state_dict, config_path, lowvram=False, base_model=None) -> None:
         super().__init__()
         config = OmegaConf.load(config_path)
+        model = Adapter
+        try:
+            self.target = config.model.target
+            model = get_obj_from_str(config.model.target)
+        except ImportError:
+            pass
 
-        if (config.model.params.cin == 64 * 6):
-            config.model.params.cin = 192
-            self.control_model = Adapter_light(**config.model.params)
-        elif (config.model.params.cin == 64 * 7):
-            del config.model.params.cin
-            self.control_model = StyleAdapter(**config.model.params)
-        else:
-            self.control_model = Adapter(**config.model.params)
+        self.control_model = model(**config.model.params)
         self.control_model.load_state_dict(state_dict)
         self.lowvram = lowvram
         self.control = None
@@ -312,6 +320,7 @@ def forward(self, x):
         # x shape [N, HW+1, C]
         style_embedding = self.style_embedding + torch.zeros(
             (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
+
         x = torch.cat([x, style_embedding], dim=1)
         x = self.ln_pre(x)
         x = x.permute(1, 0, 2)  # NLD -> LND
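
To make the new dynamic loading concrete, a small sketch of what `get_obj_from_str` does with the yaml targets above (runnable only from this repository's root, where `scripts.adapter` is importable):

```python
import importlib

# "scripts.adapter.StyleAdapter" -> module "scripts.adapter", class "StyleAdapter"
module, cls = "scripts.adapter.StyleAdapter".rsplit(".", 1)
adapter_cls = getattr(importlib.import_module(module), cls)

# PlugableAdapter then instantiates it straight from the yaml params:
#   adapter_cls(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
# An unimportable target such as "tencentarc.t21_adapter" raises
# ModuleNotFoundError (a subclass of ImportError), which the try/except
# above converts into the `model = Adapter` fallback.
```

This replaces the old heuristic that guessed the adapter class from `cin`, so a new adapter type only needs a yaml entry rather than another elif branch.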

scripts/controlnet.py (87 changes: 53 additions & 34 deletions)

@@ -181,7 +181,8 @@ def __init__(self) -> None:
"mlsd": mlsd,
"normal_map": midas_normal,
"openpose": openpose,
# "openpose_hand": openpose_hand,
"openpose_hand": openpose_hand,
"clip_vision": clip,
"pidinet": pidinet,
"scribble": simple_scribble,
"fake_scribble": fake_scribble,
@@ -191,6 +192,7 @@ def __init__(self) -> None:
"hed": unload_hed,
"fake_scribble": unload_hed,
"mlsd": unload_mlsd,
"clip": unload_clip,
"depth": unload_midas,
"depth_leres": unload_leres,
"normal_map": unload_midas,
@@ -532,6 +534,38 @@ def parse_remote_call(self, p, params, idx):

         return (enabled, module, model, weight, image, scribble_mode, \
             resize_mode, rgbbgr_mode, lowvram, pres, pthr_a, pthr_b, guidance_start, guidance_end, guess_mode), input_image
+
+    def detectmap_proc(self, detected_map, module, rgbbgr_mode, resize_mode, h, w):
+        detected_map = HWC3(detected_map)
+        if module == "normal_map" or rgbbgr_mode:
+            control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(devices.get_device_for("controlnet")) / 255.0
+        else:
+            control = torch.from_numpy(detected_map.copy()).float().to(devices.get_device_for("controlnet")) / 255.0
+
+        control = rearrange(control, 'h w c -> c h w')
+        detected_map = rearrange(torch.from_numpy(detected_map), 'h w c -> c h w')
+
+        if resize_mode == "Scale to Fit (Inner Fit)":
+            transform = Compose([
+                Resize(h if h < w else w, interpolation=InterpolationMode.BICUBIC),
+                CenterCrop(size=(h, w)),
+            ])
+            control = transform(control)
+            detected_map = transform(detected_map)
+        elif resize_mode == "Envelope (Outer Fit)":
+            transform = Compose([
+                Resize(h if h > w else w, interpolation=InterpolationMode.BICUBIC),
+                CenterCrop(size=(h, w)),
+            ])
+            control = transform(control)
+            detected_map = transform(detected_map)
+        else:
+            control = Resize((h, w), interpolation=InterpolationMode.BICUBIC)(control)
+            detected_map = Resize((h, w), interpolation=InterpolationMode.BICUBIC)(detected_map)
+
+        # for log use
+        detected_map = rearrange(detected_map, 'c h w -> h w c').numpy().astype(np.uint8)
+        return control, detected_map
 
     def process(self, p, is_img2img=False, *args):
         """
@@ -652,43 +686,28 @@ def process(self, p, is_img2img=False, *args):
             preprocessor = self.preprocessor[module]
             h, w, bsz = p.height, p.width, p.batch_size
             if pres > 64:
-                detected_map = preprocessor(input_image, res=pres, thr_a=pthr_a, thr_b=pthr_b)
+                detected_map, is_image = preprocessor(input_image, res=pres, thr_a=pthr_a, thr_b=pthr_b)
             else:
-                detected_map = preprocessor(input_image)
+                detected_map, is_image = preprocessor(input_image)
 
-            detected_map = HWC3(detected_map)
-            if module == "normal_map" or rgbbgr_mode:
-                control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(devices.get_device_for("controlnet")) / 255.0
-            else:
-                control = torch.from_numpy(detected_map.copy()).float().to(devices.get_device_for("controlnet")) / 255.0
-
-            control = rearrange(control, 'h w c -> c h w')
-            detected_map = rearrange(torch.from_numpy(detected_map), 'h w c -> c h w')
-
-            if resize_mode == "Scale to Fit (Inner Fit)":
-                transform = Compose([
-                    Resize(h if h < w else w, interpolation=InterpolationMode.BICUBIC),
-                    CenterCrop(size=(h, w)),
-                ])
-                control = transform(control)
-                detected_map = transform(detected_map)
-            elif resize_mode == "Envelope (Outer Fit)":
-                transform = Compose([
-                    Resize(h if h > w else w, interpolation=InterpolationMode.BICUBIC),
-                    CenterCrop(size=(h, w)),
-                ])
-                control = transform(control)
-                detected_map = transform(detected_map)
-            else:
-                control = Resize((h, w), interpolation=InterpolationMode.BICUBIC)(control)
-                detected_map = Resize((h, w), interpolation=InterpolationMode.BICUBIC)(detected_map)
-
-            # for log use
-            detected_map = rearrange(detected_map, 'c h w -> h w c').numpy().astype(np.uint8)
-            detected_maps.append((detected_map, module))
+            if is_image:
+                control, detected_map = self.detectmap_proc(detected_map, module, rgbbgr_mode, resize_mode, h, w)
+                detected_maps.append((detected_map, module))
+            else:
+                control = detected_map
 
-            # hint_cond, guess_mode, weight, guidance_stopped, stop_guidance_percent, advanced_weighting
-            forward_param = ControlParams(model_net, control, guess_mode, weight, False, guidance_start, guidance_end, None, isinstance(model_net, PlugableAdapter))
+            forward_param = ControlParams(
+                control_model=model_net,
+                hint_cond=control,
+                guess_mode=guess_mode,
+                weight=weight,
+                guidance_stopped=False,
+                start_guidance_percent=guidance_start,
+                stop_guidance_percent=guidance_end,
+                advanced_weighting=None,
+                is_adapter=isinstance(model_net, PlugableAdapter),
+                is_extra_cond=getattr(model_net, "target", "") == "scripts.adapter.StyleAdapter"
+            )
             forward_params.append(forward_param)
 
         self.latest_network = UnetHook(lowvram=hook_lowvram)
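
As a quick sanity check of the two letterbox strategies in `detectmap_proc`, a self-contained sketch with assumed sizes (512x768 target, 600x600 detected map; torchvision semantics as used above):

```python
import torch
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize

h, w = 512, 768                    # target height/width, as in detectmap_proc
control = torch.rand(3, 600, 600)  # stand-in detected map, 'c h w'

# "Scale to Fit (Inner Fit)": resize the map's shorter edge to the shorter
# target edge, then center-crop (CenterCrop zero-pads edges that come up short).
inner = Compose([Resize(min(h, w), interpolation=InterpolationMode.BICUBIC),
                 CenterCrop(size=(h, w))])

# "Envelope (Outer Fit)": resize to the longer target edge, then center-crop.
outer = Compose([Resize(max(h, w), interpolation=InterpolationMode.BICUBIC),
                 CenterCrop(size=(h, w))])

print(inner(control).shape)  # torch.Size([3, 512, 768]), width zero-padded
print(outer(control).shape)  # torch.Size([3, 512, 768]), height cropped
```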

scripts/hook.py (21 changes: 20 additions & 1 deletion)

@@ -49,7 +49,8 @@ def __init__(
         start_guidance_percent,
         stop_guidance_percent,
         advanced_weighting,
-        is_adapter
+        is_adapter,
+        is_extra_cond
     ):
         self.control_model = control_model
         self.hint_cond = hint_cond
@@ -60,6 +61,7 @@ def __init__(
         self.stop_guidance_percent = stop_guidance_percent
         self.advanced_weighting = advanced_weighting
         self.is_adapter = is_adapter
+        self.is_extra_cond = is_extra_cond
 
 
 class UnetHook(nn.Module):
@@ -108,6 +110,7 @@ def cfg_based_adder(base, x, require_autocast, is_adapter=False):
         def forward(self, x, timesteps=None, context=None, **kwargs):
             total_control = [0.0] * 13
             total_adapter = [0.0] * 4
+            total_extra_cond = torch.zeros([0, context.shape[-1]]).to(devices.get_device_for("controlnet"))
             only_mid_control = outer.only_mid_control
             require_inpaint_hijack = False
 
@@ -138,6 +141,9 @@ def forward(self, x, timesteps=None, context=None, **kwargs):

                 if outer.lowvram:
                     param.control_model.to("cpu")
+                if param.is_extra_cond:
+                    total_extra_cond = torch.cat([total_extra_cond, control.clone().squeeze(0)])  # * param.weight
+                    continue
                 if param.guess_mode:
                     if param.is_adapter:
                         # see https://github.com/Mikubill/sd-webui-controlnet/issues/269
@@ -153,6 +159,19 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
                         target[idx] += item
 
             control = total_control
+            if len(total_extra_cond) > 0 and context.shape[0] % 2 == 0:
+                total_extra_cond = torch.repeat_interleave(total_extra_cond.unsqueeze(0), context.shape[0] // 2, dim=0)
+                if outer.is_vanilla_samplers:
+                    uncond, cond = context.chunk(2)
+                    cond = torch.cat([cond, total_extra_cond], dim=1)
+                    uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
+                    context = torch.cat([uncond, cond], dim=0)
+                else:
+                    cond, uncond = context.chunk(2)
+                    cond = torch.cat([cond, total_extra_cond], dim=1)
+                    uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
+                    context = torch.cat([cond, uncond], dim=0)
+
             assert timesteps is not None, ValueError(f"insufficient timestep: {timesteps}")
             hs = []
             with th.no_grad():
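
For intuition about the context concatenation above, a standalone sketch with assumed SD 1.x shapes (batch 2, 77 text tokens, dim 768, 8 style tokens; all names illustrative):

```python
import torch

B, T, C, S = 2, 77, 768, 8
context = torch.randn(2 * B, T, C)  # [uncond; cond] halves, vanilla-sampler order
extra = torch.randn(S, C)           # accumulated StyleAdapter tokens

# one copy of the style tokens per conditioned sample
extra = torch.repeat_interleave(extra.unsqueeze(0), context.shape[0] // 2, dim=0)

uncond, cond = context.chunk(2)
cond = torch.cat([cond, extra], dim=1)                  # append style tokens to cond
uncond = torch.cat([uncond, uncond[:, -S:, :]], dim=1)  # pad uncond with its own tail
context = torch.cat([uncond, cond], dim=0)
print(context.shape)                # torch.Size([4, 85, 768])
```

The cond half gets the real style tokens while the uncond half is padded with repeats of its own last tokens, keeping both halves the same length so classifier-free guidance can still chunk the batch in two.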
(diff for 1 remaining changed file not shown)
