Skip to content

Commit

Permalink
fix MSA module bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed Aug 19, 2024
1 parent 712ec7a commit 70b4854
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/models/msa_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def forward(
if msa_mask is not None:
v = v * msa_mask.unsqueeze(-1).unsqueeze(-1)
new_v_shape = (v.shape[:-4] + (n_seq, n_res, n_res, self.no_heads, self.c_hidden))
v = v.unsqueeze(-3).expand(new_v_shape) # (*, seq, res, res, heads, c_hidden)
v = v.unsqueeze(-4).expand(new_v_shape) # (*, seq, res, res, heads, c_hidden)

# Weighted average with gating
weights = self.softmax(b)
Expand Down

0 comments on commit 70b4854

Please sign in to comment.