[research_projects] add flux training script with quantization (huggingface#9754)

* add flux training script with quantization

* remove exclamation
sayakpaul authored Oct 25, 2024
1 parent 94643fa commit df073ba
Showing 5 changed files with 1,496 additions and 0 deletions.
166 changes: 166 additions & 0 deletions examples/research_projects/flux_lora_quantization/README.md
## LoRA fine-tuning Flux.1 Dev with quantization

> [!NOTE]
> This example is educational in nature and fixes some arguments to keep things simple. It should serve as a reference to build on further.

This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. The steps below summarize the workflow:

* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
  * Since we already precomputed the text embeddings, we don't load the text encoders.
  * We load the VAE, use it to precompute the image latents, and then delete it to free memory.
  * We load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, and prepare it for 4-bit training.
  * We add LoRA adapter layers and ensure they are kept in FP32 precision (see the sketch after this list).
  * Train!
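
To make the quantization and LoRA steps concrete, here is a minimal sketch of what they can look like with `bitsandbytes` and `peft`. It is not the training script itself; the LoRA hyperparameters and `target_modules` below are illustrative assumptions, and the full logic lives in `train_dreambooth_lora_flux_miniature.py`.

```py
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
from peft import LoraConfig

# Load the Flux transformer directly in 4-bit NF4 precision via bitsandbytes.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

# Attach LoRA adapter layers; the target modules and rank here are illustrative.
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(lora_config)

# Keep the trainable LoRA parameters in FP32 for numerical stability;
# only these parameters receive gradients.
for param in transformer.parameters():
    if param.requires_grad:
        param.data = param.data.to(torch.float32)
```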

To run training in a memory-optimized manner, we additionally use the following (a short sketch follows this list):

* 8-bit Adam
* Gradient checkpointing
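
Below is a rough sketch of how these two options can be wired up in code, continuing from the `transformer` object in the previous sketch. In practice, the training script enables them via the `--use_8bit_adam` and `--gradient_checkpointing` flags shown in the launch command further down.

```py
import bitsandbytes as bnb

# 8-bit Adam keeps optimizer state in 8-bit, which significantly reduces memory.
trainable_params = [p for p in transformer.parameters() if p.requires_grad]
optimizer = bnb.optim.AdamW8bit(trainable_params, lr=1e-4)

# Gradient checkpointing trades extra compute for lower activation memory.
transformer.enable_gradient_checkpointing()
```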

We have tested the scripts on a 24GB RTX 4090. They also run on a free-tier Colab Notebook, but training there is extremely slow.

## Training

Ensure you have installed the required libraries:

```bash
pip install -U transformers accelerate bitsandbytes peft datasets
pip install git+https://github.com/huggingface/diffusers -U
```

Now, compute the text embeddings:

```bash
python compute_embeddings.py
```

It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:

```bash
huggingface-cli login
```

Then launch:

```bash
accelerate launch --config_file=accelerate.yaml \
  train_dreambooth_lora_flux_miniature.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --data_df_path="embeddings.parquet" \
  --output_dir="yarn_art_lora_flux_nf4" \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --weighting_scheme="none" \
  --resolution=1024 \
  --train_batch_size=1 \
  --repeats=1 \
  --learning_rate=1e-4 \
  --guidance_scale=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --cache_latents \
  --rank=4 \
  --max_train_steps=700 \
  --seed="0"
```

We can directly pass a quantized checkpoint path, too:

```diff
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
```

Training time will vary depending on the machine; in our case, it was about 1.5 hours. It may be possible to speed this up by using `torch.bfloat16`.
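
We have not benchmarked this, but assuming the script handles `bf16` the same way other Diffusers training scripts do, the only change to the launch command above would be the mixed-precision flag:

```diff
- --mixed_precision="fp16" \
+ --mixed_precision="bf16" \
```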

We also support training with DeepSpeed ZeRO Stage 2. To use it, first install DeepSpeed:

```bash
pip install -Uq deepspeed
```

And then launch:

```bash
accelerate launch --config_file=ds2.yaml \
  train_dreambooth_lora_flux_miniature.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --data_df_path="embeddings.parquet" \
  --output_dir="yarn_art_lora_flux_nf4" \
  --mixed_precision="no" \
  --use_8bit_adam \
  --weighting_scheme="none" \
  --resolution=1024 \
  --train_batch_size=1 \
  --repeats=1 \
  --learning_rate=1e-4 \
  --guidance_scale=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --cache_latents \
  --rank=4 \
  --max_train_steps=700 \
  --seed="0"
```

## Inference

When loading LoRA params that were obtained on a quantized base model and merging them into that base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. Merging directly into a 4-bit quantized model can introduce rounding errors. Below, we provide an end-to-end example:

1. First, load the original model and merge the LoRA params into it:

```py
from diffusers import FluxPipeline
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"
pipeline = FluxPipeline.from_pretrained(
    ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
)
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()

pipeline.transformer.save_pretrained("fused_transformer")
```

2. Quantize the model and run inference:

```py
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"
bnb_4bit_compute_dtype = torch.float16
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "fused_transformer",
    quantization_config=nf4_config,
    torch_dtype=bnb_4bit_compute_dtype,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
    ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
)
pipeline.enable_model_cpu_offload()

image = pipeline(
    "a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
).images[0]
image.save("yarn_merged.png")
```

| Dequantize, merge, quantize | Merging directly into quantized model |
|-------|-------|
| ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) |

As we can see, the result in the first column follows the style more closely.
17 changes: 17 additions & 0 deletions examples/research_projects/flux_lora_quantization/accelerate.yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: NO
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
107 changes: 107 additions & 0 deletions examples/research_projects/flux_lora_quantization/compute_embeddings.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub.utils import insecure_hashlib
from tqdm.auto import tqdm
from transformers import T5EncoderModel

from diffusers import FluxPipeline


MAX_SEQ_LENGTH = 77
OUTPUT_PATH = "embeddings.parquet"


def generate_image_hash(image):
    return insecure_hashlib.sha256(image.tobytes()).hexdigest()


def load_flux_dev_pipeline():
    id = "black-forest-labs/FLUX.1-dev"
    text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
    pipeline = FluxPipeline.from_pretrained(
        id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
    )
    return pipeline


@torch.no_grad()
def compute_embeddings(pipeline, prompts, max_sequence_length):
    all_prompt_embeds = []
    all_pooled_prompt_embeds = []
    all_text_ids = []
    for prompt in tqdm(prompts, desc="Encoding prompts."):
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
        all_prompt_embeds.append(prompt_embeds)
        all_pooled_prompt_embeds.append(pooled_prompt_embeds)
        all_text_ids.append(text_ids)

    max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
    print(f"Max memory allocated: {max_memory:.3f} GB")
    return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids


def run(args):
    dataset = load_dataset("Norod78/Yarn-art-style", split="train")
    image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
    all_prompts = list(image_prompts.values())
    print(f"{len(all_prompts)=}")

    pipeline = load_flux_dev_pipeline()
    all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
        pipeline, all_prompts, args.max_sequence_length
    )

    data = []
    for i, (image_hash, _) in enumerate(image_prompts.items()):
        data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
    print(f"{len(data)=}")

    # Create a DataFrame
    embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
    df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
    print(f"{len(df)=}")

    # Convert embedding lists to arrays (for proper storage in parquet)
    for col in embedding_cols:
        df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())

    # Save the dataframe to a parquet file
    df.to_parquet(args.output_path)
    print(f"Data successfully serialized to {args.output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=MAX_SEQ_LENGTH,
        help="Maximum sequence length to use for computing the embeddings. Higher values increase computational cost.",
    )
    parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
    args = parser.parse_args()

    run(args)
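
For reference, the flattened embedding columns written above can be read back into tensors along the following lines (a sketch assuming a batch size of 1 and the same `max_sequence_length` used when computing the embeddings):

```py
import pandas as pd
import torch

max_sequence_length = 77  # must match the value used in compute_embeddings.py

df = pd.read_parquet("embeddings.parquet")
row = df.iloc[0]

# The columns were flattened to plain lists, so restore the expected shapes.
prompt_embeds = torch.tensor(row["prompt_embeds"]).reshape(1, max_sequence_length, -1)
pooled_prompt_embeds = torch.tensor(row["pooled_prompt_embeds"]).reshape(1, -1)
text_ids = torch.tensor(row["text_ids"]).reshape(max_sequence_length, -1)
```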
23 changes: 23 additions & 0 deletions examples/research_projects/flux_lora_quantization/ds2.yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py (contents not expanded)
