Open
Description
Hi, I want to run fluxcontrolpipeline with transformer_fp8 reference the code :
https://huggingface.co/docs/diffusers/api/pipelines/flux#quantization
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxControlPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="text_encoder_2",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder_2=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
)
prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt, guidance_scale=3.5, height=768, width=1360, num_inference_steps=50).images[0]
image.save("flux.png")
but when I load lora after build a pipeline
pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder_2=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
)
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
There a error:
not support fp8 weight , how to fix it??
Metadata
Metadata
Assignees
Labels
No labels