
Adds FLUX attention masking and additional time embedder #9

Merged · 4 commits · Jul 2, 2025
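
The diff adds two optional inputs to FluxTransformer2DModel: an attention_mask that is threaded through every FluxTransformerBlock and FluxSingleTransformerBlock into their attention layers, and an additional_timestep_embeds config flag that registers a second Timesteps/TimestepEmbedding pair whose output is summed into the time embedding temb. A minimal usage sketch follows; it is not part of the PR, and the tiny config values, tensor shapes, and the (batch, 1, 1, text_len + image_len) key-padding mask layout are assumptions, since this file's diff does not show how the attention processor consumes the mask.

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# Deliberately tiny config so the sketch runs quickly; real checkpoints use
# num_layers=19, num_single_layers=38, attention_head_dim=128, etc.
model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=16,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    guidance_embeds=False,
    additional_timestep_embeds=True,  # new flag added by this PR
    axes_dims_rope=(4, 6, 6),         # must sum to attention_head_dim
)

batch, txt_len, img_len = 2, 6, 16
hidden_states = torch.randn(batch, img_len, 4)            # packed latent tokens
encoder_hidden_states = torch.randn(batch, txt_len, 32)   # text tokens
pooled_projections = torch.randn(batch, 16)
timestep = torch.tensor([1.0, 1.0])
txt_ids = torch.zeros(txt_len, 3)
img_ids = torch.zeros(img_len, 3)

# Assumed layout: boolean key-padding mask over the joint text+image sequence,
# broadcastable inside scaled_dot_product_attention (True = attend).
attention_mask = torch.ones(batch, 1, 1, txt_len + img_len, dtype=torch.bool)
attention_mask[1, 0, 0, txt_len - 2 : txt_len] = False    # e.g. mask two padded text tokens

# One extra conditioning scalar per sample; forward() projects it sinusoidally
# to 256 channels, maps it to inner_dim, and adds it to temb.
additional_timestep_embeddings = torch.tensor([0.0, 0.5])

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    additional_timestep_embeddings=additional_timestep_embeddings,
    attention_mask=attention_mask,
).sample
print(out.shape)  # torch.Size([2, 16, 4]) -> (batch, img_len, patch_size**2 * out_channels)
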
178 changes: 144 additions & 34 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -20,7 +20,11 @@
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...loaders import (
FluxTransformer2DLoadersMixin,
FromOriginalModelMixin,
PeftAdapterMixin,
)
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
@@ -30,21 +34,42 @@
FusedFluxAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...models.normalization import (
AdaLayerNormContinuous,
AdaLayerNormZero,
AdaLayerNormZeroSingle,
)
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
FluxPosEmbed,
TimestepEmbedding,
Timesteps,
)
from ..modeling_outputs import Transformer2DModelOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)

@@ -82,6 +107,7 @@ def forward(
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
@@ -90,6 +116,7 @@ def forward(
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)

@@ -106,7 +133,12 @@ def forward(
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
):
super().__init__()

@@ -131,7 +163,9 @@ def __init__(
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.ff_context = FeedForward(
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)

def forward(
self,
@@ -140,18 +174,22 @@ def forward(
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, emb=temb
)

norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
self.norm1_context(encoder_hidden_states, emb=temb)
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)

@@ -165,7 +203,9 @@ def forward(
hidden_states = hidden_states + attn_output

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_hidden_states = (
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
)

ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -180,18 +220,28 @@ def forward(
encoder_hidden_states = encoder_hidden_states + context_attn_output

norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
norm_encoder_hidden_states = (
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+ c_shift_mlp[:, None]
)

context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
encoder_hidden_states = (
encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
)
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

return encoder_hidden_states, hidden_states


class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
):
"""
The Transformer model introduced in Flux.
@@ -241,6 +291,7 @@ def __init__(
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
additional_timestep_embeds: bool = False,
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
@@ -250,12 +301,23 @@ def __init__(
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
CombinedTimestepGuidanceTextProjEmbeddings
if guidance_embeds
else CombinedTimestepTextProjEmbeddings
)

self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)

if additional_timestep_embeds:
self.additional_time_proj = Timesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
)
self.additional_timestep_embedder = TimestepEmbedding(
in_channels=256, time_embed_dim=self.inner_dim, sample_proj_bias=False
)

self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)

@@ -281,8 +343,12 @@ def __init__(
]
)

self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.norm_out = AdaLayerNormContinuous(
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
)
self.proj_out = nn.Linear(
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
)

self.gradient_checkpointing = False

@@ -297,7 +363,11 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]:
# set recursively
processors = {}

def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()

@@ -312,7 +382,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
return processors

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
r"""
Sets the attention processor to use to compute attention.

@@ -362,7 +434,9 @@ def fuse_qkv_projections(self):

for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
raise ValueError(
"`fuse_qkv_projections()` is not supported for models having added KV projections."
)

self.original_attn_processors = self.attn_processors

@@ -395,11 +469,13 @@ def forward(
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
additional_timestep_embeddings: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
@@ -437,7 +513,10 @@ def forward(
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
if (
joint_attention_kwargs is not None
and joint_attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
@@ -455,6 +534,16 @@ def forward(
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)

if additional_timestep_embeddings is not None:
additional_timestep_embed_proj = self.additional_time_proj(
additional_timestep_embeddings
)
additional_class_emb = self.additional_timestep_embedder(
additional_timestep_embed_proj.to(dtype=temb.dtype)
)
temb = temb + additional_class_emb

encoder_hidden_states = self.context_embedder(encoder_hidden_states)

if txt_ids.ndim == 3:
@@ -473,19 +562,27 @@
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)

if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
if (
joint_attention_kwargs is not None
and "ip_adapter_image_embeds" in joint_attention_kwargs
):
ip_adapter_image_embeds = joint_attention_kwargs.pop(
"ip_adapter_image_embeds"
)
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_hidden_states, hidden_states = (
self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask=attention_mask,
)
)

else:
@@ -495,19 +592,28 @@ def forward(
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
attention_mask=attention_mask,
)

# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = len(self.transformer_blocks) / len(
controlnet_block_samples
)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
hidden_states
+ controlnet_block_samples[
index_block % len(controlnet_block_samples)
]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = (
hidden_states
+ controlnet_block_samples[index_block // interval_control]
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

for index_block, block in enumerate(self.single_transformer_blocks):
@@ -517,6 +623,7 @@ def forward(
hidden_states,
temb,
image_rotary_emb,
attention_mask=attention_mask,
)

else:
@@ -525,11 +632,14 @@ def forward(
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
attention_mask=attention_mask,
)

# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = len(self.single_transformer_blocks) / len(
controlnet_single_block_samples
)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
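
For reference, the new conditioning path added in __init__ and forward can be reproduced in isolation. This is an illustrative sketch using the stock diffusers embedding classes, with inner_dim hard-coded to the default Flux value as an assumption:

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

inner_dim = 24 * 128  # num_attention_heads * attention_head_dim in the default config

# Mirrors self.additional_time_proj / self.additional_timestep_embedder above.
additional_time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
additional_timestep_embedder = TimestepEmbedding(
    in_channels=256, time_embed_dim=inner_dim, sample_proj_bias=False
)

temb = torch.zeros(2, inner_dim)   # stands in for the usual time/text embedding
extra = torch.tensor([0.0, 0.5])   # one extra conditioning scalar per sample
proj = additional_time_proj(extra)                               # (2, 256) sinusoidal features
temb = temb + additional_timestep_embedder(proj.to(temb.dtype))  # (2, 3072)
print(temb.shape)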