forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[research_projects] add flux training script with quantization (huggi…
…ngface#9754) * add flux training script with quantization * remove exclamation
- Loading branch information
Showing
5 changed files
with
1,496 additions
and
0 deletions.
There are no files selected for viewing
166 changes: 166 additions & 0 deletions
166
examples/research_projects/flux_lora_quantization/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
## LoRA fine-tuning Flux.1 Dev with quantization | ||
|
||
> [!NOTE] | ||
> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further. | ||
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow: | ||
|
||
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file. | ||
* `train_dreambooth_lora_flux_miniature.py` takes care of training: | ||
* Since we already precomputed the text embeddings, we don't load the text encoders. | ||
* We load the VAE and use it to precompute the image latents and we then delete it. | ||
* Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training. | ||
* Add LoRA adapter layers to it and then ensure they are kept in FP32 precision. | ||
* Train! | ||
|
||
To run training in a memory-optimized manner, we additionally use: | ||
|
||
* 8Bit Adam | ||
* Gradient checkpointing | ||
|
||
We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow. | ||
|
||
## Training | ||
|
||
Ensure you have installed the required libraries: | ||
|
||
```bash | ||
pip install -U transformers accelerate bitsandbytes peft datasets | ||
pip install git+https://github.com/huggingface/diffusers -U | ||
``` | ||
|
||
Now, compute the text embeddings: | ||
|
||
```bash | ||
python compute_embeddings.py | ||
``` | ||
|
||
It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model: | ||
|
||
```bash | ||
huggingface-cli | ||
``` | ||
|
||
Then launch: | ||
|
||
```bash | ||
accelerate launch --config_file=accelerate.yaml \ | ||
train_dreambooth_lora_flux_miniature.py \ | ||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ | ||
--data_df_path="embeddings.parquet" \ | ||
--output_dir="yarn_art_lora_flux_nf4" \ | ||
--mixed_precision="fp16" \ | ||
--use_8bit_adam \ | ||
--weighting_scheme="none" \ | ||
--resolution=1024 \ | ||
--train_batch_size=1 \ | ||
--repeats=1 \ | ||
--learning_rate=1e-4 \ | ||
--guidance_scale=1 \ | ||
--report_to="wandb" \ | ||
--gradient_accumulation_steps=4 \ | ||
--gradient_checkpointing \ | ||
--lr_scheduler="constant" \ | ||
--lr_warmup_steps=0 \ | ||
--cache_latents \ | ||
--rank=4 \ | ||
--max_train_steps=700 \ | ||
--seed="0" | ||
``` | ||
|
||
We can direcly pass a quantized checkpoint path, too: | ||
|
||
```diff | ||
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg" | ||
``` | ||
|
||
Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`. | ||
|
||
We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed: | ||
|
||
```bash | ||
pip install -Uq deepspeed | ||
``` | ||
|
||
And then launch: | ||
|
||
```bash | ||
accelerate launch --config_file=ds2.yaml \ | ||
train_dreambooth_lora_flux_miniature.py \ | ||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ | ||
--data_df_path="embeddings.parquet" \ | ||
--output_dir="yarn_art_lora_flux_nf4" \ | ||
--mixed_precision="no" \ | ||
--use_8bit_adam \ | ||
--weighting_scheme="none" \ | ||
--resolution=1024 \ | ||
--train_batch_size=1 \ | ||
--repeats=1 \ | ||
--learning_rate=1e-4 \ | ||
--guidance_scale=1 \ | ||
--report_to="wandb" \ | ||
--gradient_accumulation_steps=4 \ | ||
--gradient_checkpointing \ | ||
--lr_scheduler="constant" \ | ||
--lr_warmup_steps=0 \ | ||
--cache_latents \ | ||
--rank=4 \ | ||
--max_train_steps=700 \ | ||
--seed="0" | ||
``` | ||
|
||
## Inference | ||
|
||
When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example: | ||
|
||
1. First, load the original model and merge the LoRA params into it: | ||
|
||
```py | ||
from diffusers import FluxPipeline | ||
import torch | ||
|
||
ckpt_id = "black-forest-labs/FLUX.1-dev" | ||
pipeline = FluxPipeline.from_pretrained( | ||
ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16 | ||
) | ||
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors") | ||
pipeline.fuse_lora() | ||
pipeline.unload_lora_weights() | ||
|
||
pipeline.transformer.save_pretrained("fused_transformer") | ||
``` | ||
|
||
2. Quantize the model and run inference | ||
|
||
```py | ||
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig | ||
import torch | ||
|
||
ckpt_id = "black-forest-labs/FLUX.1-dev" | ||
bnb_4bit_compute_dtype = torch.float16 | ||
nf4_config = BitsAndBytesConfig( | ||
load_in_4bit=True, | ||
bnb_4bit_quant_type="nf4", | ||
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, | ||
) | ||
transformer = FluxTransformer2DModel.from_pretrained( | ||
"fused_transformer", | ||
quantization_config=nf4_config, | ||
torch_dtype=bnb_4bit_compute_dtype, | ||
) | ||
pipeline = AutoPipelineForText2Image.from_pretrained( | ||
ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype | ||
) | ||
pipeline.enable_model_cpu_offload() | ||
|
||
image = pipeline( | ||
"a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768 | ||
).images[0] | ||
image.save("yarn_merged.png") | ||
``` | ||
|
||
| Dequantize, merge, quantize | Merging directly into quantized model | | ||
|-------|-------| | ||
|  |  | | ||
|
||
As we can notice the first column result follows the style more closely. |
17 changes: 17 additions & 0 deletions
17
examples/research_projects/flux_lora_quantization/accelerate.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
compute_environment: LOCAL_MACHINE | ||
debug: false | ||
distributed_type: NO | ||
downcast_bf16: 'no' | ||
enable_cpu_affinity: true | ||
gpu_ids: all | ||
machine_rank: 0 | ||
main_training_function: main | ||
mixed_precision: bf16 | ||
num_machines: 1 | ||
num_processes: 1 | ||
rdzv_backend: static | ||
same_network: true | ||
tpu_env: [] | ||
tpu_use_cluster: false | ||
tpu_use_sudo: false | ||
use_cpu: false |
107 changes: 107 additions & 0 deletions
107
examples/research_projects/flux_lora_quantization/compute_embeddings.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
|
||
import pandas as pd | ||
import torch | ||
from datasets import load_dataset | ||
from huggingface_hub.utils import insecure_hashlib | ||
from tqdm.auto import tqdm | ||
from transformers import T5EncoderModel | ||
|
||
from diffusers import FluxPipeline | ||
|
||
|
||
MAX_SEQ_LENGTH = 77 | ||
OUTPUT_PATH = "embeddings.parquet" | ||
|
||
|
||
def generate_image_hash(image): | ||
return insecure_hashlib.sha256(image.tobytes()).hexdigest() | ||
|
||
|
||
def load_flux_dev_pipeline(): | ||
id = "black-forest-labs/FLUX.1-dev" | ||
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto") | ||
pipeline = FluxPipeline.from_pretrained( | ||
id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced" | ||
) | ||
return pipeline | ||
|
||
|
||
@torch.no_grad() | ||
def compute_embeddings(pipeline, prompts, max_sequence_length): | ||
all_prompt_embeds = [] | ||
all_pooled_prompt_embeds = [] | ||
all_text_ids = [] | ||
for prompt in tqdm(prompts, desc="Encoding prompts."): | ||
( | ||
prompt_embeds, | ||
pooled_prompt_embeds, | ||
text_ids, | ||
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length) | ||
all_prompt_embeds.append(prompt_embeds) | ||
all_pooled_prompt_embeds.append(pooled_prompt_embeds) | ||
all_text_ids.append(text_ids) | ||
|
||
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 | ||
print(f"Max memory allocated: {max_memory:.3f} GB") | ||
return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids | ||
|
||
|
||
def run(args): | ||
dataset = load_dataset("Norod78/Yarn-art-style", split="train") | ||
image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset} | ||
all_prompts = list(image_prompts.values()) | ||
print(f"{len(all_prompts)=}") | ||
|
||
pipeline = load_flux_dev_pipeline() | ||
all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings( | ||
pipeline, all_prompts, args.max_sequence_length | ||
) | ||
|
||
data = [] | ||
for i, (image_hash, _) in enumerate(image_prompts.items()): | ||
data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i])) | ||
print(f"{len(data)=}") | ||
|
||
# Create a DataFrame | ||
embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"] | ||
df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols) | ||
print(f"{len(df)=}") | ||
|
||
# Convert embedding lists to arrays (for proper storage in parquet) | ||
for col in embedding_cols: | ||
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist()) | ||
|
||
# Save the dataframe to a parquet file | ||
df.to_parquet(args.output_path) | ||
print(f"Data successfully serialized to {args.output_path}") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--max_sequence_length", | ||
type=int, | ||
default=MAX_SEQ_LENGTH, | ||
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.", | ||
) | ||
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.") | ||
args = parser.parse_args() | ||
|
||
run(args) |
23 changes: 23 additions & 0 deletions
23
examples/research_projects/flux_lora_quantization/ds2.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
compute_environment: LOCAL_MACHINE | ||
debug: false | ||
deepspeed_config: | ||
gradient_accumulation_steps: 1 | ||
gradient_clipping: 1.0 | ||
offload_optimizer_device: cpu | ||
offload_param_device: cpu | ||
zero3_init_flag: false | ||
zero_stage: 2 | ||
distributed_type: DEEPSPEED | ||
downcast_bf16: 'no' | ||
enable_cpu_affinity: false | ||
machine_rank: 0 | ||
main_training_function: main | ||
mixed_precision: 'no' | ||
num_machines: 1 | ||
num_processes: 1 | ||
rdzv_backend: static | ||
same_network: true | ||
tpu_env: [] | ||
tpu_use_cluster: false | ||
tpu_use_sudo: false | ||
use_cpu: false |
Oops, something went wrong.