distogram bug fix for bs > 1
ardagoreci committed Aug 14, 2024
1 parent d87a740 commit a67a67a
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/models/components/attention_pair_bias.py
@@ -60,7 +60,7 @@ def __init__(

         # Attention
         self.attention = Attention(
-            c_q=dim,
+            c_q=dim,  # TODO: this Q needs to be projected from a linear layer with bias
             c_k=dim,
             c_v=dim,
             c_hidden=dim // no_heads,
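Note: the TODO added above concerns the query projection inside Attention. As a hedged sketch of the intent (the names and the exact bias split are assumptions drawn from the comment and from AlphaFold 3's AttentionPairBias, where only the query projection carries a bias, not from this repo's Attention internals):

import torch.nn as nn

dim = 128  # illustrative model width, not taken from the repo

# Hypothetical projections: only Q carries a bias, K and V stay bias-free,
# matching the TODO's "linear layer with bias" for the query.
linear_q = nn.Linear(dim, dim, bias=True)
linear_k = nn.Linear(dim, dim, bias=False)
linear_v = nn.Linear(dim, dim, bias=False)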
2 changes: 1 addition & 1 deletion src/utils/loss.py
@@ -180,8 +180,8 @@ def distogram_loss(
     # (eps + torch.sum(square_mask, dim=(-1, -2))))
     denom = eps + torch.sum(square_mask, dim=(-1, -2))
     mean = errors * square_mask
-    mean = mean / denom[..., None]
     mean = torch.sum(mean, dim=-1)
+    mean = mean / denom[..., None]

     # Average over the batch dimensions
     mean = torch.mean(mean)
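Note: the fix moves the normalization so the division by denom happens after the reduction over the last dimension. A minimal sketch of why the old order breaks for batch size > 1 (tensor names and sizes here are illustrative, not the repo's):

import torch

batch, n_res, eps = 4, 16, 1e-8                # any batch > 1 triggers the bug
errors = torch.rand(batch, n_res, n_res)       # stand-in for per-pair bin losses
square_mask = torch.ones(batch, n_res, n_res)  # stand-in for the pair mask

denom = eps + torch.sum(square_mask, dim=(-1, -2))  # shape (batch,)
mean = errors * square_mask                         # shape (batch, n_res, n_res)

# Old order: denom[..., None] has shape (batch, 1), which right-aligns against
# the trailing (n_res, n_res) dims and raises a broadcast error whenever
# batch != n_res; a batch of 1 broadcast fine, which is why the bug only
# surfaced for bs > 1.
# mean = mean / denom[..., None]

# New order: reduce first, then divide; (batch, n_res) / (batch, 1)
# broadcasts correctly for any batch size.
mean = torch.sum(mean, dim=-1)
mean = mean / denom[..., None]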
