Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
danielegrattarola committed Nov 5, 2021
2 parents 6c77083 + 028cc95 commit d077e57
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion spektral/layers/convolutional/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def _call_dense(self, x, a):
attn_coef = attn_for_self + attn_for_neighs
attn_coef = tf.nn.leaky_relu(attn_coef, alpha=0.2)

mask = -10e9 * (1.0 - a)
mask = tf.where(a == 0.0, -10e9, 0.0)
mask = tf.cast(mask, dtype=attn_coef.dtype)
attn_coef += mask[..., None, :]
attn_coef = tf.nn.softmax(attn_coef, axis=-1)
attn_coef_drop = self.dropout(attn_coef)
Expand Down

0 comments on commit d077e57

Please sign in to comment.