diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 87537890d246..b64920c374f4 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -20,7 +20,11 @@
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+from ...loaders import (
+    FluxTransformer2DLoadersMixin,
+    FromOriginalModelMixin,
+    PeftAdapterMixin,
+)
 from ...models.attention import FeedForward
 from ...models.attention_processor import (
     Attention,
@@ -30,21 +34,42 @@
     FusedFluxAttnProcessor2_0,
 )
 from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...models.normalization import (
+    AdaLayerNormContinuous,
+    AdaLayerNormZero,
+    AdaLayerNormZeroSingle,
+)
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    FluxPosEmbed,
+    TimestepEmbedding,
+    Timesteps,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
-    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+    ):
         super().__init__()
 
         self.mlp_hidden_dim = int(dim * mlp_ratio)
@@ -82,6 +107,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
@@ -90,6 +116,7 @@ def forward(
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )
 
@@ -106,7 +133,12 @@ def forward(
 @maybe_allow_in_graph
 class FluxTransformerBlock(nn.Module):
     def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        qk_norm: str = "rms_norm",
+        eps: float = 1e-6,
     ):
         super().__init__()
 
@@ -131,7 +163,9 @@ def __init__(
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+        self.ff_context = FeedForward(
+            dim=dim, dim_out=dim, activation_fn="gelu-approximate"
+        )
 
     def forward(
         self,
@@ -140,11 +174,14 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+            hidden_states, emb=temb
+        )
 
-        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
-            encoder_hidden_states, emb=temb
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
+            self.norm1_context(encoder_hidden_states, emb=temb)
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
@@ -152,6 +189,7 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )
 
@@ -165,7 +203,9 @@ def forward(
         hidden_states = hidden_states + attn_output
 
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = (
+            norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        )
 
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -180,10 +220,15 @@ def forward(
         encoder_hidden_states = encoder_hidden_states + context_attn_output
 
         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        norm_encoder_hidden_states = (
+            norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+            + c_shift_mlp[:, None]
+        )
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
-        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        encoder_hidden_states = (
+            encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        )
 
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
@@ -191,7 +236,12 @@
 
 
 class FluxTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
 ):
     """
     The Transformer model introduced in Flux.
@@ -241,6 +291,7 @@ def __init__(
         joint_attention_dim: int = 4096,
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
+        additional_timestep_embeds: bool = False,
         axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
@@ -250,12 +301,23 @@ def __init__(
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
         text_time_guidance_cls = (
-            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+            CombinedTimestepGuidanceTextProjEmbeddings
+            if guidance_embeds
+            else CombinedTimestepTextProjEmbeddings
         )
+
         self.time_text_embed = text_time_guidance_cls(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
 
+        if additional_timestep_embeds:
+            self.additional_time_proj = Timesteps(
+                num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0
+            )
+            self.additional_timestep_embedder = TimestepEmbedding(
+                in_channels=256, time_embed_dim=self.inner_dim, sample_proj_bias=False
+            )
+
         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
 
@@ -281,8 +343,12 @@ def __init__(
             ]
         )
 
-        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
-        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+        self.norm_out = AdaLayerNormContinuous(
+            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
+        )
+        self.proj_out = nn.Linear(
+            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
+        )
 
         self.gradient_checkpointing = False
 
@@ -297,7 +363,11 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]:
         # set recursively
         processors = {}
 
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
             if hasattr(module, "get_processor"):
                 processors[f"{name}.processor"] = module.get_processor()
 
@@ -312,7 +382,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
         return processors
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -362,7 +434,9 @@ def fuse_qkv_projections(self):
 
         for _, attn_processor in self.attn_processors.items():
             if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+                raise ValueError(
+                    "`fuse_qkv_projections()` is not supported for models having added KV projections."
+                )
 
         self.original_attn_processors = self.attn_processors
 
@@ -395,11 +469,13 @@ def forward(
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
         guidance: torch.Tensor = None,
+        additional_timestep_embeddings: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
@@ -437,7 +513,10 @@ def forward(
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+            if (
+                joint_attention_kwargs is not None
+                and joint_attention_kwargs.get("scale", None) is not None
+            ):
                 logger.warning(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
@@ -455,6 +534,16 @@ def forward(
             if guidance is None
             else self.time_text_embed(timestep, guidance, pooled_projections)
         )
+
+        if additional_timestep_embeddings is not None:
+            additional_timestep_embed_proj = self.additional_time_proj(
+                additional_timestep_embeddings
+            )
+            additional_class_emb = self.additional_timestep_embedder(
+                additional_timestep_embed_proj.to(dtype=temb.dtype)
+            )
+            temb = temb + additional_class_emb
+
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         if txt_ids.ndim == 3:
@@ -473,19 +562,27 @@ def forward(
         ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)
 
-        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
-            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+        if (
+            joint_attention_kwargs is not None
+            and "ip_adapter_image_embeds" in joint_attention_kwargs
+        ):
+            ip_adapter_image_embeds = joint_attention_kwargs.pop(
+                "ip_adapter_image_embeds"
+            )
             ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                encoder_hidden_states, hidden_states = (
+                    self._gradient_checkpointing_func(
+                        block,
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        image_rotary_emb,
+                        attention_mask=attention_mask,
+                    )
                 )
 
             else:
@@ -495,19 +592,28 @@ def forward(
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
+                    attention_mask=attention_mask,
                 )
 
             # controlnet residual
             if controlnet_block_samples is not None:
-                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = len(self.transformer_blocks) / len(
+                    controlnet_block_samples
+                )
                 interval_control = int(np.ceil(interval_control))
                 # For Xlabs ControlNet.
                 if controlnet_blocks_repeat:
                     hidden_states = (
-                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                        hidden_states
+                        + controlnet_block_samples[
+                            index_block % len(controlnet_block_samples)
+                        ]
                     )
                 else:
-                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+                    hidden_states = (
+                        hidden_states
+                        + controlnet_block_samples[index_block // interval_control]
+                    )
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
@@ -517,6 +623,7 @@ def forward(
                     hidden_states,
                     temb,
                     image_rotary_emb,
+                    attention_mask=attention_mask,
                 )
 
             else:
@@ -525,11 +632,14 @@ def forward(
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
+                    attention_mask=attention_mask,
                 )
 
             # controlnet residual
             if controlnet_single_block_samples is not None:
-                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = len(self.single_transformer_blocks) / len(
+                    controlnet_single_block_samples
+                )
                 interval_control = int(np.ceil(interval_control))
                 hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                     hidden_states[:, encoder_hidden_states.shape[1] :, ...]
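
Usage sketch (not part of the patch): the snippet below exercises the two new forward arguments introduced above, `additional_timestep_embeddings` and `attention_mask`, on a small randomly initialized model, assuming a diffusers build that includes this diff. The tiny config values, tensor shapes, and the boolean mask layout are illustrative assumptions, and how (or whether) the mask is actually consumed depends on the attention processor in use.

# Minimal sketch, not a reference implementation.
import torch

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    guidance_embeds=False,
    additional_timestep_embeds=True,  # new config flag from this diff
    axes_dims_rope=(2, 2, 4),
)

batch, img_seq, txt_seq = 1, 16, 8
hidden_states = torch.randn(batch, img_seq, 4)          # packed image latents
encoder_hidden_states = torch.randn(batch, txt_seq, 32)  # text tokens
pooled_projections = torch.randn(batch, 32)
timestep = torch.tensor([1.0])
img_ids = torch.zeros(img_seq, 3)
txt_ids = torch.zeros(txt_seq, 3)

# Extra scalar condition routed through Timesteps + TimestepEmbedding and
# added to `temb` (see the `additional_timestep_embeddings` branch above).
additional_timestep_embeddings = torch.tensor([0.5])

# Assumed layout: a boolean mask over the joint text+image sequence; the
# patch itself does not fix a convention, it only forwards the tensor.
attention_mask = torch.ones(batch, txt_seq + img_seq, dtype=torch.bool)

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    additional_timestep_embeddings=additional_timestep_embeddings,
    attention_mask=attention_mask,
    return_dict=False,
)[0]  # (batch, img_seq, patch_size * patch_size * out_channels)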