
Why doesn't cross-attention use normalization in PixArtMSBlock? #160

Open

binbinsh opened this issue Jan 10, 2025 · 0 comments

Comments

binbinsh commented Jan 10, 2025

I noticed that in the PixArtMSBlock implementation there is no normalization layer before cross-attention, while norm1 and norm2 are applied before self-attention and the MLP, respectively:

self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # for self-attention
self.attn = AttentionKVCompress(...)
self.cross_attn = MultiHeadCrossAttention(...)  # no norm layer before/after
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # for MLP
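
For context, here is roughly what the block's forward pass does (a paraphrased sketch of the linked PixArtMS.py, so names are approximate and drop_path plus the HW argument are omitted): norm1 and norm2 feed the adaLN-modulated self-attention and MLP branches, while cross-attention receives the residual stream directly.

# Simplified sketch of PixArtMSBlock.forward (paraphrased from the linked
# PixArtMS.py; only the norm placement matters for this question)
B = x.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
    (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa))  # norm1 before self-attention
x = x + self.cross_attn(x, y, mask)                                              # raw x, no norm before cross-attention
x = x + gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))   # norm2 before MLP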

For comparison, the diffusers implementation explicitly skips norm2 for the PixArt (ada_norm_single) path (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py#L541):

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

I'm curious about the reasoning behind this design: why normalize the inputs to self-attention and the MLP, but not the input to cross-attention?

Thanks for this great work!

binbinsh reopened this Jan 11, 2025
binbinsh changed the title from "Question about missing normalization layer for cross-attention in PixArtMSBlock" to "Why doesn't cross-attention use normalization in PixArtMSBlock?" Jan 11, 2025