Commit: tweak to atten probe
Alexander Ororbia committed Mar 6, 2025
1 parent f38373f commit 012395b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ngclearn/utils/analysis/attentive_probe.py
@@ -70,7 +70,7 @@ def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax
     score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
     score = score.astype(q.dtype) # (B, H, T, S)
     if dropout_rate > 0.:
-        score = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
+        score, _ = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
     attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, T, H, E)
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout # (B, T, Dq)
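The one-line fix unpacks a tuple: the `drop_out` utility evidently returns more than just the masked array, so binding the whole return value to `score` would have passed a tuple into the subsequent `einsum`. Below is a minimal sketch of the pattern, assuming (as the unpacking suggests) that `drop_out` returns an `(output, mask)` pair; this stand-in `drop_out` is hypothetical and not ngclearn's actual implementation.

```python
import jax
import jax.numpy as jnp

def drop_out(dkey, input, rate):
    # hypothetical stand-in for ngclearn's drop_out utility, assumed to
    # return both the dropped-out array and the binary keep-mask
    mask = jax.random.bernoulli(dkey, p=1. - rate, shape=input.shape)
    output = (input * mask) / (1. - rate)  # inverted-dropout scaling
    return output, mask

key = jax.random.PRNGKey(0)
score = jnp.ones((2, 4))
# the fix in this commit: unpack the tuple, discarding the mask,
# instead of binding the whole (output, mask) pair to `score`
score, _ = drop_out(key, input=score, rate=0.5)
print(score.shape)  # still (2, 4), not a 2-tuple
```

Without the unpacking, `score` becomes a tuple and the later `jnp.einsum("BHTS,BHSE->BHTE", score, v)` fails, which is the bug this commit addresses.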
