-
Notifications
You must be signed in to change notification settings - Fork 409
/
guidance.py
693 lines (573 loc) · 23.2 KB
/
guidance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
import torch
from typing import Literal, Optional
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
# Guidance modes dispatched by get_guidance_loss(); the active mode is read from
# ``dataset_config.guidance_type`` of the first file item in a batch.
GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct", "tnt"]

# Fraction of the latent difference between the two images of a pair that
# get_targeted_polarity_loss() mixes into each image's noise target.
DIFFERENTIAL_SCALER = 0.2
# DIFFERENTIAL_SCALER = 0.25
def get_differential_mask(
        conditional_latents: torch.Tensor,
        unconditional_latents: torch.Tensor,
        threshold: float = 0.2,
        gradient: bool = False,
):
    """Build a per-element mask of where two latent batches differ.

    ``|conditional - unconditional|`` is normalised per sample by its maximum
    over dims 1, 2 and 3, so the largest difference in each sample becomes 1.0.

    Args:
        conditional_latents: Conditional latents. Reductions run over dims
            1-3, so a 4-D ``(batch, ...)`` tensor is expected.
        unconditional_latents: Unconditional latents with the same shape.
        threshold: Hard mode: normalised values below this become 0.0 and all
            others 1.0. Gradient mode: the values are remapped with
            ``value_map`` from the tensor's ``[min, max]`` onto
            ``[-threshold, 1 + threshold]`` and then clipped to ``[0, 1]``.
        gradient: Produce a soft (clipped) mask instead of a binary one.

    Returns:
        A tensor shaped like the inputs. A sample whose two latents are
        identical produces zeros instead of the NaNs a plain division by a
        zero maximum would create.
    """
    differential_mask = torch.abs(conditional_latents - unconditional_latents)
    # per-sample maximum, kept broadcastable against the mask
    max_differential = differential_mask.amax(dim=(1, 2, 3), keepdim=True)
    # A zero maximum means the sample has no difference at all; dividing by 1
    # there keeps the mask at 0 instead of 0 * inf = NaN.
    safe_max = torch.where(
        max_differential > 0,
        max_differential,
        torch.ones_like(max_differential),
    )
    differential_mask = differential_mask * (1.0 / safe_max)

    if gradient:
        # widen the range by `threshold` on both sides, then clip to [0, 1]
        # NOTE(review): min()/max() here span the whole batch, not each sample,
        # and if every element is equal value_map gets an empty input range —
        # confirm how value_map handles that.
        differential_mask = value_map(
            differential_mask,
            differential_mask.min(),
            differential_mask.max(),
            0 - threshold,
            1 + threshold
        )
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
    else:
        # binary mask: below threshold -> 0.0, everything else -> 1.0
        differential_mask = torch.where(
            differential_mask < threshold,
            torch.zeros_like(differential_mask),
            torch.ones_like(differential_mask)
        )
    return differential_mask
def get_targeted_polarity_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Targeted-polarity guidance loss; runs ``backward()`` itself.

    Each image's noise target is shifted by ``DIFFERENTIAL_SCALER`` (``s``)
    times the latent difference towards the other image of the pair:

    * conditional target:   ``noise + s * (unconditional - conditional)``
    * unconditional target: ``noise + s * (conditional - unconditional)``

    Each image is noised with its own shifted target. Both are predicted in a
    single doubled batch (conditional first), using ``conditional_embeds`` for
    both halves, and each half is regressed (MSE) onto its shifted target.

    ``noisy_latents``, ``match_adapter_assist``, ``network_weight_list`` and
    ``kwargs`` are accepted for the shared guidance-loss signature and are not
    used. The LoRA network state is not changed here.

    Returns:
        The scalar loss, detached and with ``requires_grad`` re-enabled so a
        caller's own ``backward()`` does not raise. Gradients have already been
        accumulated by the ``backward()`` call in this function.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    with torch.no_grad():
        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # scaled latent difference in each direction (built under no_grad, so
        # these never require grad)
        unconditional_diff_noise = (unconditional_latents - conditional_latents) * DIFFERENTIAL_SCALER
        conditional_diff_noise = (conditional_latents - unconditional_latents) * DIFFERENTIAL_SCALER

        # shifted targets: each image's noise leans towards the other image
        conditional_noise = noise + unconditional_diff_noise
        unconditional_noise = noise + conditional_diff_noise

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            conditional_noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            unconditional_noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

    sd.unet.train()

    # the only prediction in this loss; it is the one that receives gradients
    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )
    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        conditional_noise.float(),
        reduction="none"
    ).mean([1, 2, 3])
    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        unconditional_noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    loss = (pred_loss + pred_neg_loss).mean()
    loss.backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
def get_direct_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional[PromptEmbeds] = None,
        mask_multiplier=None,
        prior_pred=None,
        **kwargs
):
    """Direct guidance loss; runs ``backward()`` itself.

    Both images of the pair are noised with the same ``noise``. One doubled
    batch is predicted (unconditional image first, conditional image second).
    The two halves are combined CFG-style with a fixed scale of 1.1, and the
    combined prediction is regressed (MSE) onto ``noise``.

    Args:
        unconditional_embeds: Optional embeddings. When given, they are moved
            to the model device/dtype, doubled and forwarded to
            ``sd.predict_noise`` as ``unconditional_embeddings``.
        mask_multiplier: Optional tensor multiplied element-wise into the
            unreduced loss, so it must broadcast against the prediction.
        noisy_latents, match_adapter_assist, network_weight_list, prior_pred,
        kwargs: Accepted for the shared guidance-loss signature; not used.

    Returns:
        The scalar loss, detached and with ``requires_grad`` re-enabled so a
        caller's own ``backward()`` does not raise. Gradients have already been
        accumulated here.
    """
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        cond_clean = batch.latents.to(device, dtype=dtype).detach()
        uncond_clean = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # both images share the same noise sample
        cond_noisy = sd.add_noise(cond_clean, noise, timesteps).detach()
        uncond_noisy = sd.add_noise(uncond_clean, noise, timesteps).detach()

    sd.unet.train()

    if unconditional_embeds is not None:
        single_uncond = unconditional_embeds.to(device, dtype=dtype).detach()
        unconditional_embeds = concat_prompt_embeds([single_uncond, single_uncond])

    # order matters: the unconditional image is the first half of the batch
    doubled_latents = torch.cat([uncond_noisy, cond_noisy]).to(device, dtype=dtype).detach()
    doubled_cond_embeds = concat_prompt_embeds(
        [conditional_embeds, conditional_embeds]
    ).to(device, dtype=dtype).detach()
    doubled_timesteps = torch.cat([timesteps, timesteps])

    prediction = sd.predict_noise(
        latents=doubled_latents,
        conditional_embeddings=doubled_cond_embeds,
        unconditional_embeddings=unconditional_embeds,
        timestep=doubled_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )
    pred_uncond, pred_cond = torch.chunk(prediction, 2, dim=0)

    # classifier-free-guidance style combination of the two halves
    cfg_scale = 1.1
    guided = pred_uncond + cfg_scale * (pred_cond - pred_uncond)

    per_element = torch.nn.functional.mse_loss(
        guided.float(),
        noise.detach().float(),
        reduction="none"
    )
    if mask_multiplier is not None:
        per_element = per_element * mask_multiplier
    loss = per_element.mean([1, 2, 3]).mean()
    loss.backward()

    # hand back a detached value that still "requires grad" so the caller's
    # own backward() call does not raise
    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
# targeted
def get_targeted_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Targeted guidance loss; runs ``backward()`` itself.

    Steps:

    1. With ``sd.network`` disabled and the UNet in eval mode, predict the
       noise for the noised unconditional image. Keep its per-element MSE
       against ``noise`` as a prior loss.
    2. Map ``|unconditional - conditional|`` with ``value_map`` from its
       per-sample ``[min, max]`` onto ``[1.0, 2.0]`` to get a per-element weight.
    3. Re-enable the network and predict a doubled batch (conditional image
       first, unconditional second). Then minimise
       ``mean(|cond_loss - prior_loss| * weight) + mean(|cond_loss - uncond_loss|)``.

    ``noisy_latents``, ``match_adapter_assist`` and ``kwargs`` are not used.

    Side effects:
        Leaves the UNet in train mode and ``sd.network`` active, with its
        ``multiplier`` set back to ``network_weight_list``.

    Returns:
        The scalar loss, detached and with ``requires_grad`` re-enabled so a
        caller's own ``backward()`` does not raise. Gradients have already been
        accumulated here.
    """
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # NOTE(review): the other guidance losses noise through sd.add_noise();
        # this one calls the scheduler directly — confirm that is intended.
        unconditional_noisy_latents = sd.noise_scheduler.add_noise(
            unconditional_latents,
            noise,
            timesteps
        )
        conditional_noisy_latents = sd.noise_scheduler.add_noise(
            conditional_latents,
            noise,
            timesteps
        )

        # bypass the LoRA network to get the parent model's prediction
        sd.network.is_active = False
        sd.unet.eval()

        target_differential_abs = (unconditional_latents - conditional_latents).abs()
        # Per-sample range over dims 1-3. The minimum has to be a true minimum
        # (min over every reduced dim). Mixing min() on dim 1 with max() on
        # dims 2-3 over-estimates it, which makes value_map send the smallest
        # differences below min_guidance.
        target_differential_abs_min = target_differential_abs.amin(dim=(1, 2, 3), keepdim=True)
        target_differential_abs_max = target_differential_abs.amax(dim=(1, 2, 3), keepdim=True)

        min_guidance = 1.0
        max_guidance = 2.0

        # per-element loss weight in [min_guidance, max_guidance]
        # NOTE(review): a sample whose differential is constant gives value_map
        # an empty input range (min == max); confirm how value_map handles it.
        differential_scaler = value_map(
            target_differential_abs,
            target_differential_abs_min,
            target_differential_abs_max,
            min_guidance,
            max_guidance
        ).detach()

        # With LoRA network bypassed, predict noise to get a baseline of what the network
        # wants to do with the latents + noise. Pass our target latents here for the input.
        target_unconditional = sd.predict_noise(
            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs  # adapter residuals in here
        ).detach()
        prior_prediction_loss = torch.nn.functional.mse_loss(
            target_unconditional.float(),
            noise.float(),
            reduction="none"
        ).detach().clone()

    # turn the LoRA network back on.
    sd.unet.train()
    sd.network.is_active = True
    # NOTE(review): `x + -1.0` gives the second (unconditional) half a multiplier
    # of weight - 1.0, i.e. 0.0 for a weight of 1.0, whereas
    # get_guided_loss_polarity negates with `weight * -1.0`. Confirm which is meant.
    sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list]

    # with LoRA active, predict the noise for both images in one doubled batch
    prediction = sd.predict_noise(
        latents=torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0).to(device, dtype=dtype).detach(),
        conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
        timestep=torch.cat([timesteps, timesteps], dim=0),
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )
    prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0)

    conditional_loss = torch.nn.functional.mse_loss(
        prediction_conditional.float(),
        noise.float(),
        reduction="none"
    )
    unconditional_loss = torch.nn.functional.mse_loss(
        prediction_unconditional.float(),
        noise.float(),
        reduction="none"
    )

    # |cond_loss - prior_loss|, weighted per element by differential_scaler
    positive_loss = torch.abs(conditional_loss.float() - prior_prediction_loss.float())
    positive_loss = (positive_loss * differential_scaler).mean([1, 2, 3])

    polar_loss = torch.abs(
        conditional_loss.float() - unconditional_loss.float()
    ).mean([1, 2, 3])

    positive_loss = positive_loss.mean() + polar_loss.mean()
    positive_loss.backward()

    loss = positive_loss.detach()
    # add a grad so other backward does not fail
    loss.requires_grad_(True)

    # restore network
    sd.network.multiplier = network_weight_list
    return loss
def get_guided_loss_polarity(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        scaler=None,
        **kwargs
):
    """Polarity guidance loss; runs ``backward()`` itself.

    The conditional and unconditional images are noised with the same noise
    and predicted in one doubled batch, with ``conditional_embeds`` for both
    halves. The first half (conditional image) runs with the LoRA multipliers
    in ``network_weight_list``. The second half (unconditional image) runs
    with those multipliers negated. Each half is regressed (MSE) onto its
    target:

    * non-flow-matching models: the noise itself;
    * flow-matching models: ``noise`` minus that image's un-noised latents.

    Args:
        scaler: Optional object with a ``scale(loss)`` method (presumably an
            AMP grad scaler — confirm). When given, ``scaler.scale(loss)`` is
            back-propagated instead of ``loss``.
        noisy_latents, match_adapter_assist, kwargs: Not used.

    Side effects:
        Puts ``sd.unet`` in train mode, sets ``sd.network.is_active = True`` and
        leaves ``sd.network.multiplier`` set to the doubled (+/-) list. For
        flow-matching models it also calls
        ``sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)``.

    Returns:
        The scalar loss, detached and with ``requires_grad`` re-enabled so a
        caller's own ``backward()`` does not raise.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    with torch.no_grad():
        # NOTE(review): dtype already came from get_torch_dtype above; this
        # second conversion looks redundant — confirm it is idempotent.
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # default targets: regress both halves onto the added noise
        target_pos = noise
        target_neg = noise

        if sd.is_flow_matching:
            # set the timesteps for flow matching as linear since we will do weighing
            # NOTE(review): this mutates the shared scheduler on every call, after
            # the caller already sampled `timesteps` — confirm that is intended.
            sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
            # flow-matching targets: noise minus each image's un-noised latents
            target_pos = (noise - conditional_latents).detach()
            target_neg = (noise - unconditional_latents).detach()

        # both images are noised with the same noise sample
        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()

        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

        # first half of the batch uses the LoRA as-is, second half negated
        negative_network_weights = [weight * -1.0 for weight in network_weight_list]
        positive_network_weights = [weight * 1.0 for weight in network_weight_list]
        cat_network_weight_list = positive_network_weights + negative_network_weights

    # turn the LoRA network back on.
    sd.unet.train()
    sd.network.is_active = True

    # NOTE(review): unlike get_targeted_guidance_loss, the multiplier is not
    # restored to network_weight_list afterwards.
    sd.network.multiplier = cat_network_weight_list

    # do our prediction with LoRA active on the scaled guidance latents
    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        target_pos.float(),
        reduction="none"
    )
    # pred_loss = pred_loss.mean([1, 2, 3])

    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        target_neg.float(),
        reduction="none"
    )

    # per-element sum of both halves, then reduced per sample and over the batch
    loss = pred_loss + pred_neg_loss

    # if sd.is_flow_matching:
    #     timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
    #     loss = loss * timestep_weight

    loss = loss.mean([1, 2, 3])
    loss = loss.mean()
    if scaler is not None:
        scaler.scale(loss).backward()
    else:
        loss.backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
def get_guided_tnt(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        prior_pred: torch.Tensor = None,
        **kwargs
):
    """"This-and-that" (tnt) guidance loss; runs ``backward()`` itself.

    The conditional ("this") and unconditional ("that") images are noised with
    the same ``noise`` and predicted in one doubled batch, with
    ``conditional_embeds`` for both halves. The MSE against ``noise`` is
    minimised for "this" and negated (maximised) for "that". The "that" term
    is rescaled per sample by a no-grad factor, so its value equals
    ``-0.01 * |this_loss|`` while the direction of its gradient is kept.

    ``noisy_latents``, ``match_adapter_assist``, ``prior_pred`` and ``kwargs``
    are accepted for the shared guidance-loss signature and are not used.

    Side effects:
        Puts ``sd.unet`` in train mode. If ``sd.network`` exists, sets its
        ``multiplier`` to ``network_weight_list`` repeated for both halves and
        activates it.

    Returns:
        The scalar loss, detached and with ``requires_grad`` re-enabled so a
        caller's own ``backward()`` does not raise.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()

        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

    # turn the LoRA network back on.
    sd.unet.train()
    if sd.network is not None:
        # same per-sample multipliers for both halves of the doubled batch
        sd.network.multiplier = list(network_weight_list) * 2
        sd.network.is_active = True

    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)

    this_loss = torch.nn.functional.mse_loss(
        this_prediction.float(),
        noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    # negated: minimising it pushes the "that" prediction away from the noise
    that_loss = -torch.nn.functional.mse_loss(
        that_prediction.float(),
        noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    with torch.no_grad():
        # Per-sample factor making |that_loss| equal |this_loss|. It is built
        # without grad, so it only rescales that_loss's gradient. A zero
        # denominator is replaced by 1 to avoid inf/NaN; that_loss is 0 there,
        # so the scaled term stays 0.
        that_magnitude = torch.abs(that_loss)
        safe_magnitude = torch.where(
            that_magnitude > 0,
            that_magnitude,
            torch.ones_like(that_magnitude),
        )
        that_loss_scaler = torch.abs(this_loss) / safe_magnitude

    # still negative after scaling: its value is -0.01 * |this_loss| per sample
    that_loss = that_loss * that_loss_scaler * 0.01

    loss = (this_loss + that_loss).mean()
    loss.backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
# this processes all guidance losses based on the batch information
def get_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional['PromptEmbeds'] = None,
        mask_multiplier=None,
        prior_pred=None,
        scaler=None,
        **kwargs
):
    """Compute the guidance loss selected by the batch's dataset config.

    The guidance type is read from
    ``batch.file_items[0].dataset_config.guidance_type``, so every item in the
    batch uses the first item's setting. The extra options are forwarded only
    to the losses that accept them:

    * "polarity" -> ``scaler``
    * "tnt" -> ``prior_pred``
    * "direct" -> ``unconditional_embeds``, ``mask_multiplier`` and ``prior_pred``

    ``kwargs`` go to every loss.

    Returns:
        Whatever the selected loss function returns.

    Raises:
        AssertionError: ``unconditional_embeds`` was given for a guidance type
            other than "direct".
        NotImplementedError: the guidance type is not recognised.
    """
    # TODO add others and process individual batch items separately
    guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type

    # positional arguments shared by every guidance loss
    common_args = (
        noisy_latents,
        conditional_embeds,
        match_adapter_assist,
        network_weight_list,
        timesteps,
        pred_kwargs,
        batch,
        noise,
        sd,
    )

    if guidance_type == "targeted":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
        return get_targeted_guidance_loss(*common_args, **kwargs)
    elif guidance_type == "polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
        return get_guided_loss_polarity(*common_args, scaler=scaler, **kwargs)
    elif guidance_type == "tnt":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
        return get_guided_tnt(*common_args, prior_pred=prior_pred, **kwargs)
    elif guidance_type == "targeted_polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
        return get_targeted_polarity_loss(*common_args, **kwargs)
    elif guidance_type == "direct":
        return get_direct_guidance_loss(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")