forked from bmaltais/kohya_ss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoriginal_control_net.py
320 lines (256 loc) · 11.6 KB
/
original_control_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
from typing import List, NamedTuple, Any
import numpy as np
import cv2
import torch
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
import library.model_util as model_util
class ControlNetInfo(NamedTuple):
unet: Any
net: Any
prep: Any
weight: float
ratio: float
class ControlNet(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# make control model
self.control_model = torch.nn.Module()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
zero_convs = torch.nn.ModuleList()
for i, dim in enumerate(dims):
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
zero_convs.append(sub_list)
self.control_model.add_module("zero_convs", zero_convs)
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
dims = [16, 16, 32, 32, 96, 96, 256, 320]
strides = [1, 1, 2, 1, 2, 1, 2, 1]
prev_dim = 3
input_hint_block = torch.nn.Sequential()
for i, (dim, stride) in enumerate(zip(dims, strides)):
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
if i < len(dims) - 1:
input_hint_block.append(torch.nn.SiLU())
prev_dim = dim
self.control_model.add_module("input_hint_block", input_hint_block)
def load_control_net(v2, unet, model):
device = unet.device
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
else:
ctrl_sd_sd = torch.load(model, map_location='cpu')
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference")
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
for key in list(ctrl_unet_sd_sd.keys()):
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
zero_conv_sd = {}
for key in list(ctrl_sd_sd.keys()):
if key.startswith("control_"):
unet_key = "model.diffusion_" + key[len("control_"):]
if unet_key not in ctrl_unet_sd_sd: # zero conv
zero_conv_sd[key] = ctrl_sd_sd[key]
continue
if is_difference: # Transfer Control
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
else:
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
unet_config = model_util.create_unet_diffusers_config(v2)
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
return ctrl_unet, ctrl_net
def load_preprocess(prep_type: str):
if prep_type is None or prep_type.lower() == "none":
return None
if prep_type.startswith("canny"):
args = prep_type.split("_")
th1 = int(args[1]) if len(args) >= 2 else 63
th2 = int(args[2]) if len(args) >= 3 else 191
def canny(img):
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
return cv2.Canny(img, th1, th2)
return canny
print("Unsupported prep type:", prep_type)
return None
def preprocess_ctrl_net_hint_image(image):
image = np.array(image).astype(np.float32) / 255.0
image = image[:, :, ::-1].copy() # rgb to bgr
image = image[None].transpose(0, 3, 1, 2) # nchw
image = torch.from_numpy(image)
return image # 0 to 1
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
guided_hints = []
for i, cnet_info in enumerate(control_nets):
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
b_hints = []
if len(hints) == 1: # すべて同じ画像をhintとして使う
hint = hints[0]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints = [hint for _ in range(b_size)]
else:
for bi in range(b_size):
hint = hints[(bi * len(control_nets) + i) % len(hints)]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints.append(hint)
b_hints = torch.cat(b_hints, dim=0)
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
guided_hints.append(guided_hint)
return guided_hints
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
# ControlNet
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
cnet_cnt = len(control_nets)
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
guided_hint = guided_hints[cnet_idx]
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
outs = [o * cnet_info.weight for o in outs]
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
"""
# これはmergeのバージョン
# ControlNet
cnet_outs_list = []
for i, cnet_info in enumerate(control_nets):
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
continue
guided_hint = guided_hints[i]
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
for i in range(len(outs)):
outs[i] *= cnet_info.weight
cnet_outs_list.append(outs)
count = len(cnet_outs_list)
if count == 0:
return original_unet(sample, timestep, encoder_hidden_states)
# sum of controlnets
for i in range(1, count):
cnet_outs_list[0] += cnet_outs_list[i]
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
"""
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
# copy from UNet2DConditionModel
default_overall_up_factor = 2**unet.num_upsamplers
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 0. center input if necessary
if unet.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = unet.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=unet.dtype)
emb = unet.time_embedding(t_emb)
outs = [] # output of ControlNet
zc_idx = 0
# 2. pre-process
sample = unet.conv_in(sample)
if is_control_net:
sample += guided_hint
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
zc_idx += 1
# 3. down
down_block_res_samples = (sample,)
for downsample_block in unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_control_net:
for rs in res_samples:
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
zc_idx += 1
down_block_res_samples += res_samples
# 4. mid
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
if is_control_net:
outs.append(control_net.control_model.middle_block_out[0](sample))
return outs
if not is_control_net:
sample += ctrl_outs.pop()
# 5. up
for i, upsample_block in enumerate(unet.up_blocks):
is_final_block = i == len(unet.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if not is_control_net and len(ctrl_outs) > 0:
res_samples = list(res_samples)
apply_ctrl_outs = ctrl_outs[-len(res_samples):]
ctrl_outs = ctrl_outs[:-len(res_samples)]
for j in range(len(res_samples)):
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
res_samples = tuple(res_samples)
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = unet.conv_norm_out(sample)
sample = unet.conv_act(sample)
sample = unet.conv_out(sample)
return UNet2DConditionOutput(sample=sample)