Commit 0302446

UmerHA and sayakpaul authored
Implements Blockwise lora (huggingface#7352)
* Initial commit
* Implemented block lora
  - implemented block lora
  - updated docs
  - added tests
* Finishing up
* Reverted unrelated changes made by make style
* Fixed typo
* Fixed bug + Made text_encoder_2 scalable
* Integrated some review feedback
* Incorporated review feedback
* Fix tests
* Made every module configurable
* Adapted to new lora test structure
* Final cleanup
* Some more final fixes
  - Included examples in `using_peft_for_inference.md`
  - Added hint that only attns are scaled
  - Removed NoneTypes
  - Added test to check mismatching lens of adapter names / weights raise error
* Update using_peft_for_inference.md
* Update using_peft_for_inference.md
* Make style, quality, fix-copies
* Updated tutorial; warning if scale/adapter mismatch
* floats are forwarded as-is; changed tutorial scale
* make style, quality, fix-copies
* Fixed typo in tutorial
* Moved some warnings into `lora_loader_utils.py`
* Moved scale/lora mismatch warnings back
* Integrated final review suggestions
* Empty commit to trigger CI
* Reverted empty commit to trigger CI

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 4d39b74 commit 0302446

File tree

7 files changed: +553 -21 lines changed


docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 56 additions & 0 deletions
@@ -133,6 +133,62 @@ image
 
 ![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)
 
+### Customize adapters strength
+For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
+
+For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
+```python
+pipe.enable_lora()  # enable lora again, after we disabled it above
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-text-and-down](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_down.png)
+
+Let's see how turning off the `down` part and turning on the `mid` and `up` parts, respectively, changes the image.
+```python
+adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-text-and-mid](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mid.png)
+
+```python
+adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-text-and-up](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_up.png)
+
+Looks cool!
+
+This is a really powerful feature. You can use it to control the adapter strengths down to the per-transformer level, and you can even use it with multiple adapters.
+```python
+adapter_weight_scales_toy = 0.5
+adapter_weight_scales_pixel = {
+    "unet": {
+        "down": 0.9,  # all transformers in the down-part will use scale 0.9
+        # "mid"  # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
+        "up": {
+            "block_0": 0.6,  # all 3 transformers in the 0th block in the up-part will use scale 0.6
+            "block_1": [0.4, 0.8, 1.0],  # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
+        }
+    }
+}
+pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png)
+
 ## Manage active adapters
 
 You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
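
Editorial note: the tutorial snippets above assume the `toy` and `pixel` adapters from earlier in the guide are already loaded on `pipe`. As a quick, hedged illustration of the scale formats introduced by this commit (a sketch, not part of the diff): plain floats are forwarded as-is, floats and scale dicts can be mixed across adapters, and mismatched lengths of names and weights raise an error.

```python
# Illustrative sketch only - not part of commit 0302446.
# Assumes `pipe` is the tutorial's pipeline with the "toy" and "pixel" adapters already loaded.

# A bare float per adapter is forwarded as-is and applied uniformly.
pipe.set_adapters(["toy", "pixel"], [1.0, 0.5])

# Floats and scale dicts can be mixed; parts missing from a dict default to scale 1.0.
pipe.set_adapters(["toy", "pixel"], [0.5, {"unet": {"down": 0.9}}])

# Passing fewer (or more) weights than adapter names raises a ValueError.
try:
    pipe.set_adapters(["toy", "pixel"], [0.5])
except ValueError as err:
    print(err)  # length of adapter names (2) does not match length of weights (1)
```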

docs/source/en/using-diffusers/loading_adapters.md

Lines changed: 31 additions & 6 deletions
@@ -153,18 +153,43 @@ image
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
 </div>
 
-<Tip>
-
-For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
-
-</Tip>
-
 To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
 
 ```py
 pipeline.unload_lora_weights()
 ```
 
+### Adjust LoRA weight scale
+
+For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
+
+For more granular control over the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying by how much to scale the weights in each layer.
+```python
+pipe = ...  # create pipeline
+pipe.load_lora_weights(..., adapter_name="my_adapter")
+scales = {
+    "text_encoder": 0.5,
+    "text_encoder_2": 0.5,  # only usable if pipe has a 2nd text encoder
+    "unet": {
+        "down": 0.9,  # all transformers in the down-part will use scale 0.9
+        # "mid"  # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
+        "up": {
+            "block_0": 0.6,  # all 3 transformers in the 0th block in the up-part will use scale 0.6
+            "block_1": [0.4, 0.8, 1.0],  # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
+        }
+    }
+}
+pipe.set_adapters("my_adapter", scales)
+```
+
+This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.
+
+<Tip warning={true}>
+
+Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.
+
+</Tip>
+
 ### Kohya and TheLastBen
 
 Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
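
As a hedged aside (not part of the diff), the multi-adapter case that the new paragraph links to follows the same pattern; the adapter names "style" and "detail" below are hypothetical placeholders for two LoRAs loaded via `load_lora_weights`.

```python
# Illustrative sketch only - not part of commit 0302446.
# Assumes `pipe` already has two adapters loaded, e.g.
# pipe.load_lora_weights(..., adapter_name="style") and pipe.load_lora_weights(..., adapter_name="detail").

scales_style = {
    "text_encoder": 0.8,
    "unet": {"down": 1.0, "mid": 0.7, "up": 0.7},
}
scales_detail = 0.5  # a plain float scales all of this adapter's attention layers uniformly

# One weight entry per adapter name; mismatched lengths raise a ValueError.
pipe.set_adapters(["style", "detail"], [scales_style, scales_detail])
```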

src/diffusers/loaders/lora.py

Lines changed: 75 additions & 9 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 from pathlib import Path
@@ -985,7 +986,7 @@ def set_adapters_for_text_encoder(
         self,
         adapter_names: Union[List[str], str],
         text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
-        text_encoder_weights: List[float] = None,
+        text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
     ):
         """
         Sets the adapter layers for the text encoder.
@@ -1003,15 +1004,20 @@ def set_adapters_for_text_encoder(
             raise ValueError("PEFT backend is required for this method.")
 
         def process_weights(adapter_names, weights):
-            if weights is None:
-                weights = [1.0] * len(adapter_names)
-            elif isinstance(weights, float):
-                weights = [weights]
+            # Expand weights into a list, one entry per adapter
+            # e.g. for 2 adapters:  7 -> [7,7] ; [3, None] -> [3, None]
+            if not isinstance(weights, list):
+                weights = [weights] * len(adapter_names)
 
             if len(adapter_names) != len(weights):
                 raise ValueError(
                     f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
                 )
+
+            # Set None values to default of 1.0
+            # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
+            weights = [w if w is not None else 1.0 for w in weights]
+
             return weights
 
         adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1059,17 +1065,77 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        adapter_weights: Optional[List[float]] = None,
+        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        adapter_weights = copy.deepcopy(adapter_weights)
+
+        # Expand weights into a list, one entry per adapter
+        if not isinstance(adapter_weights, list):
+            adapter_weights = [adapter_weights] * len(adapter_names)
+
+        if len(adapter_names) != len(adapter_weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+            )
+
+        # Decompose weights into weights for unet, text_encoder and text_encoder_2
+        unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
+
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        all_adapters = {
+            adapter for adapters in list_adapters.values() for adapter in adapters
+        }  # eg ["adapter1", "adapter2"]
+        invert_list_adapters = {
+            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+            for adapter in all_adapters
+        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+
+        for adapter_name, weights in zip(adapter_names, adapter_weights):
+            if isinstance(weights, dict):
+                unet_lora_weight = weights.pop("unet", None)
+                text_encoder_lora_weight = weights.pop("text_encoder", None)
+                text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
+
+                if len(weights) > 0:
+                    raise ValueError(
+                        f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
+                    )
+
+                if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
+                    logger.warning(
+                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
+                    )
+
+                # warn if adapter doesn't have parts specified by adapter_weights
+                for part_weight, part_name in zip(
+                    [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
+                    ["unet", "text_encoder", "text_encoder_2"],
+                ):
+                    if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
+                        logger.warning(
+                            f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+                        )
+
+            else:
+                unet_lora_weight = weights
+                text_encoder_lora_weight = weights
+                text_encoder_2_lora_weight = weights
+
+            unet_lora_weights.append(unet_lora_weight)
+            text_encoder_lora_weights.append(text_encoder_lora_weight)
+            text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
+
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        unet.set_adapters(adapter_names, adapter_weights)
+        unet.set_adapters(adapter_names, unet_lora_weights)
 
         # Handle the Text Encoder
        if hasattr(self, "text_encoder"):
-            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
         if hasattr(self, "text_encoder_2"):
-            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)
 
     def disable_lora(self):
         if not USE_PEFT_BACKEND:
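
To make the new decomposition logic easier to follow, here is a small standalone sketch (editorial, not from the commit); `decompose_weight` is a hypothetical helper that only mirrors the per-adapter loop body in `set_adapters` above.

```python
import copy
from typing import Any, Tuple

def decompose_weight(weights: Any) -> Tuple[Any, Any, Any]:
    # Hypothetical helper mirroring the loop body above: a dict is split into its
    # "unet" / "text_encoder" / "text_encoder_2" entries (missing keys become None),
    # while a float (or None) is forwarded unchanged to all three parts.
    weights = copy.deepcopy(weights)
    if isinstance(weights, dict):
        unet_weight = weights.pop("unet", None)
        text_encoder_weight = weights.pop("text_encoder", None)
        text_encoder_2_weight = weights.pop("text_encoder_2", None)
        if len(weights) > 0:
            raise ValueError(f"Got invalid key '{weights.keys()}' in lora weight dict.")
        return unet_weight, text_encoder_weight, text_encoder_2_weight
    return weights, weights, weights

print(decompose_weight(0.5))                                           # (0.5, 0.5, 0.5)
print(decompose_weight({"unet": {"down": 0.9}, "text_encoder": 0.5}))  # ({'down': 0.9}, 0.5, None)
print(decompose_weight(None))                                          # (None, None, None)
```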

src/diffusers/loaders/unet.py

Lines changed: 12 additions & 4 deletions
@@ -47,6 +47,7 @@
     infer_stable_cascade_single_file_config,
     load_single_file_model_checkpoint,
 )
+from .unet_loader_utils import _maybe_expand_lora_scales
 from .utils import AttnProcsLayers
 
 
@@ -564,7 +565,7 @@ def _unfuse_lora_apply(self, module):
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        weights: Optional[Union[List[float], float]] = None,
+        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
     ):
         """
         Set the currently active adapters for use in the UNet.
@@ -597,16 +598,23 @@ def set_adapters(
 
         adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
 
-        if weights is None:
-            weights = [1.0] * len(adapter_names)
-        elif isinstance(weights, float):
+        # Expand weights into a list, one entry per adapter
+        # examples for e.g. 2 adapters:  [{...}, 7] -> [{...}, 7] ; None -> [None, None]
+        if not isinstance(weights, list):
             weights = [weights] * len(adapter_names)
 
         if len(adapter_names) != len(weights):
             raise ValueError(
                 f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
             )
 
+        # Set None values to default of 1.0
+        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
+        weights = [w if w is not None else 1.0 for w in weights]
+
+        # e.g. [{...}, 7] -> [{expanded dict...}, 7]
+        weights = _maybe_expand_lora_scales(self, weights)
+
         set_weights_and_activate_adapters(self, adapter_names, weights)
 
     def disable_lora(self):
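
For illustration (editorial, not part of the commit), the weight normalization that now runs before `_maybe_expand_lora_scales` can be sketched in isolation; `normalize_weights` is a hypothetical stand-in for the inline logic above. The block-to-layer expansion itself is then handled by `_maybe_expand_lora_scales` from the new `unet_loader_utils` module.

```python
from typing import Any, List

def normalize_weights(adapter_names: List[str], weights: Any) -> List[Any]:
    # Hypothetical stand-in for the inline logic in set_adapters above:
    # broadcast a non-list weight (float, dict, or None) to one entry per adapter,
    # validate the lengths, then replace None entries with the default scale 1.0.
    if not isinstance(weights, list):
        weights = [weights] * len(adapter_names)
    if len(adapter_names) != len(weights):
        raise ValueError(
            f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
        )
    return [w if w is not None else 1.0 for w in weights]

print(normalize_weights(["toy", "pixel"], None))                   # [1.0, 1.0]
print(normalize_weights(["toy", "pixel"], 0.7))                    # [0.7, 0.7]
print(normalize_weights(["toy", "pixel"], [{"down": 0.9}, None]))  # [{'down': 0.9}, 1.0]
```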
