FluxEdit

This project attempts to teach Flux.1 Dev language-guided image editing using the Flux Control framework. We train on the high-quality TIGER-Lab/OmniEdit-Filtered-1.2M dataset.


Important

Since the official Flux Control training details are not available, this project should be considered experimental, and we welcome contributions from the community to make it better 🤗

Setup

Install the dependencies from requirements.txt and perform any other configuration that might be needed.

The scripts were tested using PyTorch 2.5.1 and NVIDIA GPUs (H100).
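
As a quick, optional sanity check before launching anything, you can confirm that PyTorch and CUDA are visible (the versions below are simply what the scripts were tested with, not hard requirements):

import torch

# Optional environment check; the scripts were tested with PyTorch 2.5.1 on NVIDIA H100 GPUs.
print(torch.__version__)
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))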

Training

We first converted the original OmniEdit dataset into Webdataset shards for efficiency using this script. The script prepares the Webdataset shards and pushes them to an S3 bucket, but you can configure this as needed.
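
For reference, the sketch below shows roughly what such a conversion could look like with the webdataset library. The column names (src_img, edited_img, edited_prompt_list) and the shard layout are assumptions for illustration; the conversion script linked above is the source of truth.

import io

import webdataset as wds
from datasets import load_dataset

# Stream the dataset so everything doesn't have to be downloaded up front.
ds = load_dataset("TIGER-Lab/OmniEdit-Filtered-1.2M", split="train", streaming=True)

# Write shards of roughly 10k samples each; adjust the pattern/destination as needed.
with wds.ShardWriter("omniedit-%05d.tar", maxcount=10_000) as sink:
    for i, sample in enumerate(ds):
        src_buf, edited_buf = io.BytesIO(), io.BytesIO()
        sample["src_img"].save(src_buf, format="JPEG")        # column names are assumptions
        sample["edited_img"].save(edited_buf, format="JPEG")
        sink.write({
            "__key__": f"{i:09d}",
            "src.jpg": src_buf.getvalue(),
            "edited.jpg": edited_buf.getvalue(),
            "prompt.txt": sample["edited_prompt_list"][0],    # assumed field name
        })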

Training Command
export LR=1e-4
export WEIGHT_DECAY=1e-4
export GUIDANCE_SCALE=30.0
export CAPTION_DROPOUT=0.0
export LR_SCHEDULER="constant"

srun --wait=60 --kill-on-bad-exit=1 accelerate launch --config_file=./misc/accelerate_ds2.yaml train_control_flux_wds.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --output_dir="omniflux-lr_${LR}-wd_${WEIGHT_DECAY}-gs_${GUIDANCE_SCALE}-cd_${CAPTION_DROPOUT}-scheduler_${LR_SCHEDULER}-sim_flow-no8bitadam" \
  --mixed_precision="bf16" \
  --per_gpu_batch_size=4 \
  --dataloader_num_workers=4 \
  --gradient_accumulation_steps=4 \
  --quality_threshold=10.0 \
  --simplified_flow \
  --gradient_checkpointing \
  --proportion_empty_prompts=$CAPTION_DROPOUT \
  --learning_rate=$LR \
  --adam_weight_decay=$WEIGHT_DECAY \
  --guidance_scale=$GUIDANCE_SCALE \
  --validation_guidance_scale="10.,20.,30.,40." \
  --report_to="wandb" --log_dataset_samples \
  --lr_scheduler=$LR_SCHEDULER \
  --lr_warmup_steps=0 \
  --checkpointing_steps=4000 \
  --resume_from_checkpoint="latest" --checkpoints_total_limit=2 \
  --max_train_steps=20000 \
  --validation_steps=400 \
  --validation_image "assets/car.jpg" "assets/green_creature.jpg" "assets/norte_dam.jpg" "assets/mushroom.jpg" \
  --validation_prompt "Give this the look of a traditional Japanese woodblock print." "transform the setting to a winter scene" "Change it to look like it's in the style of an impasto painting." "turn the color of mushroom to gray" \
  --seed="0" \
  --push_to_hub

echo "END TIME: $(date)"

Training runs on 8 GPUs using DeepSpeed. You can edit the accelerate config file to change that.

Refer to args.py for the full set of supported configuration options. Training was conducted on a single node of 8 H100s. If you prefer Slurm, refer to this Slurm script for scheduling training.

You can also use this minimal version of the train.py script with a minimal version of the OmniEdit dataset for quicker prototyping.
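
To give a sense of what the training objective does, here is a minimal, self-contained sketch of the core Flux Control idea using dummy tensors. The real implementation lives in train_control_flux_wds.py; the shapes and the sigma sampling scheme below are simplified assumptions.

import torch

batch, channels, height, width = 2, 16, 64, 64  # illustrative latent shapes

target_latents = torch.randn(batch, channels, height, width)   # VAE latents of the edited image
control_latents = torch.randn(batch, channels, height, width)  # VAE latents of the source image
noise = torch.randn_like(target_latents)

# Sample a per-example noise level (the actual script uses its own sigma sampling scheme).
sigmas = torch.rand(batch).view(-1, 1, 1, 1)
noisy_latents = (1.0 - sigmas) * target_latents + sigmas * noise

# The control latents are concatenated channel-wise, doubling the transformer's input channels.
model_input = torch.cat([noisy_latents, control_latents], dim=1)

# Flow-matching target: the velocity pointing from the data towards noise.
target = noise - target_latents

model_pred = torch.randn_like(target)  # stand-in for the transformer's prediction
loss = torch.mean((model_pred.float() - target.float()) ** 2)
print(loss)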

Inference

from diffusers import FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
import torch 

path = "sayakpaul/FLUX.1-dev-edit-v0"
edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("./assets/mushroom.jpg") # resize as needed.
print(image.size)

prompt = "turn the color of mushroom to gray"
image = pipeline(
    control_image=image,
    prompt=prompt,
    guidance_scale=30., # change this as needed.
    num_inference_steps=50, # change this as needed.
    max_sequence_length=512,
    height=image.height,
    width=image.width,
    generator=torch.manual_seed(0)
).images[0]
image.save("edited_image.png")
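
If the full pipeline does not fit on your GPU, model CPU offloading is a simple way to trade speed for memory. A small sketch, replacing the .to("cuda") call above:

# Offload sub-models to CPU when they are not in use to lower peak GPU memory usage.
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()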

Speeding inference with a turbo LoRA

We can speed up inference by using a turbo LoRA such as ByteDance/Hyper-SD, which lets us reduce num_inference_steps while still producing a good image.

Make sure to install peft before running the code below: pip install -U peft.

Code
from diffusers import FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
import torch

path = "sayakpaul/FLUX.1-dev-edit-v0" 
edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
control_pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
).to("cuda")

# load the turbo LoRA
control_pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])

image = load_image("./assets/mushroom.jpg") # resize as needed.
print(image.size)

prompt = "turn the color of mushroom to gray"
image = control_pipe(
    control_image=image,
    prompt=prompt,
    guidance_scale=30., # change this as needed.
    num_inference_steps=8, # change this as needed.
    max_sequence_length=512,
    height=image.height,
    width=image.width,
    generator=torch.manual_seed(0)
).images[0]
image.save("edited_image.png")
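
If you intend to keep the turbo LoRA active for every generation, you can optionally fuse it into the transformer weights so the adapter layers add no overhead at inference time. A small sketch; behaviour can vary slightly across diffusers versions:

# Fuse the LoRA (at the same 0.125 scale) into the base weights, then drop the separate adapter.
control_pipe.fuse_lora(lora_scale=0.125)
control_pipe.unload_lora_weights()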

Comparison

Side-by-side results on the validation images, comparing outputs generated with 50 inference steps (without the LoRA) against 8 steps with the Hyper-SD turbo LoRA.

You can also quantize the model if the memory requirements still cannot be met on your hardware. Refer to the Diffusers documentation to learn more.
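
As an illustration, here is a rough sketch of loading the edit transformer in 4-bit NF4 precision, assuming a recent diffusers release with quantization support and bitsandbytes installed; exact arguments may differ across versions.

import torch
from diffusers import BitsAndBytesConfig, FluxControlPipeline, FluxTransformer2DModel

# 4-bit NF4 quantization config; requires `pip install bitsandbytes`.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
edit_transformer = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/FLUX.1-dev-edit-v0",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=edit_transformer, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()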
