Refactor LoRA (huggingface#3778)
* refactor to support patching LoRA into T5

instantiate the lora linear layer on the same device as the regular linear layer

get lora rank from state dict

tests

fmt

can create lora layer in float32 even when rest of model is float16

fix loading model hook

remove load_lora_weights_ and T5 dispatching

remove Unet#attn_processors_state_dict

docstrings

* text encoder monkeypatch class method

* fix test

---------

Co-authored-by: Patrick von Platen <[email protected]>
williamberman and patrickvonplaten authored Jul 9, 2023
1 parent 78922ed commit c2a28c3
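
Two ideas from the commit message — reading the LoRA rank off a saved state dict instead of passing it around, and creating the LoRA linear layers in float32 on the same device as the wrapped linear layer even when the rest of the model is float16 — are easy to miss inside the diff below. The following is a minimal, self-contained sketch of those two ideas only; `MiniLoRALinear` and the state-dict key names are illustrative stand-ins, not the actual diffusers classes or keys.

```python
import torch
import torch.nn as nn


class MiniLoRALinear(nn.Module):
    """Toy LoRA wrapper: y = base(x) + up(down(x))."""

    def __init__(self, base: nn.Linear, rank: int):
        super().__init__()
        self.base = base
        # Create the LoRA matrices on the same device as the base layer,
        # but always in float32 so they stay trainable even when the rest
        # of the model was loaded in float16.
        device = base.weight.device
        self.down = nn.Linear(base.in_features, rank, bias=False, device=device, dtype=torch.float32)
        self.up = nn.Linear(rank, base.out_features, bias=False, device=device, dtype=torch.float32)
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the low-rank branch in float32, then cast back to the input dtype.
        lora_out = self.up(self.down(x.to(self.down.weight.dtype))).to(x.dtype)
        return self.base(x) + lora_out


# The rank never has to be stored separately: it is the first dimension of
# any "down" weight in a saved LoRA state dict (hypothetical keys below).
state_dict = {
    "to_q_lora.down.weight": torch.zeros(4, 320),  # (rank, in_features)
    "to_q_lora.up.weight": torch.zeros(320, 4),    # (out_features, rank)
}
rank = state_dict["to_q_lora.down.weight"].shape[0]  # -> 4
```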
Showing 6 changed files with 437 additions and 377 deletions.
151 changes: 74 additions & 77 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -23,6 +23,7 @@
import shutil
import warnings
from pathlib import Path
from typing import Dict

import numpy as np
import torch
@@ -50,7 +51,10 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.loaders import (
LoraLoaderMixin,
text_encoder_lora_state_dict,
)
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -60,7 +64,7 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
return prompt_embeds


def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
r"""
Returns:
a state dict containing just the attention processor parameters.
"""
attn_processors = unet.attn_processors

attn_processors_state_dict = {}

for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter

return attn_processors_state_dict


def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

@@ -833,6 +853,7 @@ def main(args):

# Set correct lora layers
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
@@ -850,35 +871,18 @@
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
unet_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
)

module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet_lora_attn_procs[name] = module
unet_lora_parameters.extend(module.parameters())

unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
# we first load a dummy pipeline with the text encoder and then do the monkey-patching.
text_encoder_lora_layers = None
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_proj.out_features,
cross_attention_dim=None,
rank=args.rank,
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder
)
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
text_encoder = temp_pipeline.text_encoder
del temp_pipeline
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -887,23 +891,13 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None

if args.train_text_encoder:
text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()

for model in models:
state_dict = model.state_dict()

if (
text_encoder_lora_layers is not None
and text_encoder_keys is not None
and state_dict.keys() == text_encoder_keys
):
# text encoder
text_encoder_lora_layers_to_save = state_dict
elif state_dict.keys() == unet_keys:
# unet
unet_lora_layers_to_save = state_dict
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
@@ -915,27 +909,24 @@ def save_model_hook(models, weights, output_dir):
)

def load_model_hook(models, input_dir):
# Note we DON'T pass the unet and text encoder here an purpose
# so that the we don't accidentally override the LoRA layers of
# unet_lora_layers and text_encoder_lora_layers which are stored in `models`
# with new torch.nn.Modules / weights. We simply use the pipeline class as
# an easy way to load the lora checkpoints
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
torch_dtype=weight_dtype,
)
temp_pipeline.load_lora_weights(input_dir)
unet_ = None
text_encoder_ = None

# load lora weights into models
models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
if len(models) > 1:
models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
while len(models) > 0:
model = models.pop()

# delete temporary pipeline and pop models
del temp_pipeline
for _ in range(len(models)):
models.pop()
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -965,9 +956,9 @@ def load_model_hook(models, input_dir):

# Optimizer creation
params_to_optimize = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_layers.parameters()
else unet_lora_parameters
)
optimizer = optimizer_class(
params_to_optimize,
@@ -1056,12 +1047,12 @@ def compute_text_embeddings(prompt):

# Prepare everything with our `accelerator`.
if args.train_text_encoder:
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, optimizer, train_dataloader, lr_scheduler
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1210,9 +1201,9 @@ def compute_text_embeddings(prompt):
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_layers.parameters()
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -1301,15 +1292,17 @@ def compute_text_embeddings(prompt):
pipeline_args = {"prompt": args.validation_prompt}

if args.validation_images is None:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
images = []
for _ in range(args.num_validation_images):
with torch.cuda.amp.autocast():
image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
else:
images = []
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
with torch.cuda.amp.autocast():
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
@@ -1332,12 +1325,16 @@ def compute_text_embeddings(prompt):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
unet_lora_layers = unet_attn_processors_state_dict(unet)

if text_encoder is not None:
if text_encoder is not None and args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
else:
text_encoder_lora_layers = None

LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
(Diffs for the remaining 5 changed files are not shown here.)
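
After training, the refactored script saves plain LoRA state dicts through `LoraLoaderMixin.save_lora_weights` (last hunk above), and the resulting directory can be consumed with the pipeline-level `load_lora_weights` loader. A minimal inference sketch under that assumption — the base model id, output path, and prompt are placeholders, not part of this commit:

```python
import torch
from diffusers import DiffusionPipeline

# Assumes `lora_output_dir` is the --output_dir used by train_dreambooth_lora.py,
# i.e. the directory that LoraLoaderMixin.save_lora_weights wrote into.
lora_output_dir = "path/to/lora_output_dir"  # placeholder

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example base model
    torch_dtype=torch.float16,
)
pipeline.load_lora_weights(lora_output_dir)
pipeline = pipeline.to("cuda")

image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
```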
