Skip to content

Commit

Permalink
Fix LCM Stable Diffusion distillation bug related to parsing unet_tim…
Browse files Browse the repository at this point in the history
…e_cond_proj_dim (huggingface#5893)

* Fix bug related to parsing unet_time_cond_proj_dim.

* Fix analogous bug in the SD-XL LCM distillation script.
  • Loading branch information
dg845 authored Nov 27, 2023
1 parent c079cae commit 07eac4d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
11 changes: 10 additions & 1 deletion examples/consistency_distillation/train_lcm_distill_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
Expand Down Expand Up @@ -1138,7 +1147,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
Expand Down
14 changes: 12 additions & 2 deletions examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
Expand Down Expand Up @@ -1233,6 +1242,7 @@ def compute_embeddings(

# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)

Expand All @@ -1243,7 +1253,7 @@ def compute_embeddings(
noise_pred = unet(
noisy_model_input,
start_timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
Expand Down Expand Up @@ -1308,7 +1318,7 @@ def compute_embeddings(
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
Expand Down

0 comments on commit 07eac4d

Please sign in to comment.