Skip to content

Commit

Permalink
[Examples] fix: network_alpha -> network_alphas (huggingface#4572)
Browse files Browse the repository at this point in the history
network_alpha
  • Loading branch information
sayakpaul authored Aug 11, 2023
1 parent 796c015 commit d5983a6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,13 +722,13 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)

accelerator.register_save_state_pre_hook(save_model_hook)
Expand Down

0 comments on commit d5983a6

Please sign in to comment.