
Commit 012395b

Author: Alexander Ororbia (committed)
Commit message: tweak to atten probe
1 parent: f38373f

File tree: 1 file changed (+1, -1 lines changed)

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def cross_attention(dkey, params: tuple, x1: jax.Array, x2: jax.Array, mask: jax
     score = jax.nn.softmax(score, axis=-1) # (B, H, T, S)
     score = score.astype(q.dtype) # (B, H, T, S)
     if dropout_rate > 0.:
-        score = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
+        score, _ = drop_out(dkey, input=score, rate=dropout_rate) ## NOTE: normally you apply dropout here
     attention = jnp.einsum("BHTS,BHSE->BHTE", score, v) # (B, T, H, E)
     attention = attention.transpose([0, 2, 1, 3]).reshape((B, T, -1)) # (B, T, H, E) => (B, T, D)
     return attention @ Wout + bout # (B, T, Dq)
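
The one-line change unpacks the value returned by drop_out rather than assigning the whole return to score, which implies drop_out hands back a tuple whose first element is the dropped-out tensor. The sketch below is a minimal, self-contained illustration of that pattern; the body of drop_out here is an assumed inverted-dropout stand-in returning an (output, keep_mask) pair, not ngclearn's actual implementation. Only the call signature and the "score, _ = ..." unpacking are taken from the diff.

# Hypothetical stand-in for drop_out, assumed to return (output, keep_mask);
# only the keyword call signature mirrors the patched line in attentive_probe.py.
import jax
import jax.numpy as jnp

def drop_out(dkey, input, rate=0.5):
    ## inverted dropout: zero entries with probability `rate`, rescale survivors
    keep_prob = 1. - rate
    mask = jax.random.bernoulli(dkey, p=keep_prob, shape=input.shape)
    output = jnp.where(mask, input / keep_prob, 0.)
    return output, mask  ## tuple return is the assumption motivating the unpacking fix

## usage mirroring the patched call site: keep the dropped scores, discard the mask
dkey = jax.random.PRNGKey(0)
score = jax.nn.softmax(jax.random.normal(dkey, (2, 4, 8, 8)), axis=-1)  # (B, H, T, S)
score, _ = drop_out(dkey, input=score, rate=0.1)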
