Add the residual connections
ardagoreci committed Aug 29, 2024
1 parent 8cbfb72 commit 0e817b7
Showing 3 changed files with 8 additions and 12 deletions.
14 changes: 5 additions & 9 deletions src/models/components/atom_attention_naive.py
@@ -2,14 +2,10 @@
We did early experiments with a PyTorch-native implementation that is supposed to use memory more efficiently,
but they did not show much benefit since attention implementations in PyTorch were much slower despite
adding considerable clutter and complexity. We fall back to the Deepspeed4Science optimized attention kernel, which reduces
-the memory consumption to linear anyway.
-However, this is not recommended for large scale training.
-The smart move here will be to migrate to FlexAttention once there is bias gradient support.
-AtomTransformer
-AtomAttentionEncoder
-AtomAttentionDecoder
+the memory consumption to linear anyway. In practice, we only observe about a 20% increase in runtime and comparable memory usage.
+This is not recommended for large scale training.
+The smart move here will be to migrate to FlexAttention once there is bias gradient support or to ScaleFold's kernels if they become available.
"""
import torch
from torch import Tensor
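The docstring above weighs a PyTorch-native attention path against the Deepspeed4Science kernel. As a rough illustration only (this is not the repository's code, and the tensor shapes are assumed), a native biased-attention fallback can be expressed with `torch.nn.functional.scaled_dot_product_attention`, which accepts the pair bias as an additive float mask:

```python
# Illustrative sketch, not the repository's implementation.
# Assumed shapes: q, k, v are (batch, heads, n_atoms, head_dim);
# pair_bias is broadcastable to (batch, heads, n_atoms, n_atoms).
import torch
import torch.nn.functional as F

def naive_biased_attention(q: torch.Tensor,
                           k: torch.Tensor,
                           v: torch.Tensor,
                           pair_bias: torch.Tensor) -> torch.Tensor:
    # A float attn_mask is added to the attention logits before softmax,
    # so the pair bias can be passed through directly.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=pair_bias)
```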
@@ -101,12 +97,12 @@ def forward(
betas = self._prep_betas(n_atoms, atom_single.device) # (1, n_atoms, n_atoms)

# AttentionPairBias
-b = self.attention(
+atom_single = atom_single + self.attention(
atom_single, atom_proj, atom_pair, mask, betas,
use_deepspeed_evo_attention=use_deepspeed_evo_attention
)
# ConditionedTransitionBlock
-atom_single = b + self.transition(atom_single, atom_proj)
+atom_single = atom_single + self.transition(atom_single, atom_proj)
return atom_single, atom_proj, atom_pair


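For readability, here is a small self-contained sketch of the residual pattern this commit introduces around the two sub-layers of the atom transformer block. The module names mirror the diff above; everything else (constructor, signatures) is illustrative, not the repository's exact code.

```python
import torch
from torch import nn

class ResidualAtomBlockSketch(nn.Module):
    """Sketch only: both sub-layer outputs are added back onto atom_single."""

    def __init__(self, attention: nn.Module, transition: nn.Module):
        super().__init__()
        self.attention = attention    # attention-with-pair-bias style module
        self.transition = transition  # conditioned transition style module

    def forward(self, atom_single, atom_proj, atom_pair, mask, betas):
        # Before this commit: b = attention(...); atom_single = b + transition(...)
        # After this commit, both updates are residual on atom_single:
        atom_single = atom_single + self.attention(
            atom_single, atom_proj, atom_pair, mask, betas
        )
        atom_single = atom_single + self.transition(atom_single, atom_proj)
        return atom_single
```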
2 changes: 1 addition & 1 deletion src/models/components/outer_product_mean.py
@@ -50,7 +50,7 @@ def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
self.layer_norm = nn.LayerNorm(c_m)
self.linear_1 = Linear(c_m, c_hidden)
self.linear_2 = Linear(c_m, c_hidden)
-self.linear_out = Linear(c_hidden ** 2, c_z, init="default")
+self.linear_out = Linear(c_hidden ** 2, c_z, init="final")

def _opm(self, a, b):
# [*, N_res, N_res, C, C]
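The only change here swaps the output projection's initializer from "default" to "final". In OpenFold-style `Linear` wrappers, `init="final"` zero-initializes the layer so that a residual update starts as a no-op; that convention is assumed in the sketch below, since this repository's `Linear` wrapper is not shown in the diff.

```python
import torch
from torch import nn

def make_final_linear(in_features: int, out_features: int) -> nn.Linear:
    # Zero-initialised output projection (assumed meaning of init="final"):
    # the module initially contributes nothing, so training starts from the
    # unmodified residual stream.
    linear = nn.Linear(in_features, out_features)
    nn.init.zeros_(linear.weight)
    nn.init.zeros_(linear.bias)
    return linear
```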
4 changes: 2 additions & 2 deletions src/models/diffusion_transformer.py
@@ -60,7 +60,7 @@ def forward(
TODO: the single_proj and pair_repr do not actually change as a result of this function.
Returning them here is a bit misleading. Also, saving them between blocks is unnecessary.
"""
-b = self.attention_block(
+single_repr = single_repr + self.attention_block(
single_repr=single_repr,
single_proj=single_proj,
pair_repr=pair_repr,
@@ -69,7 +69,7 @@
)

single_repr = add(
-b,
+single_repr,
self.conditioned_transition_block(single_repr, single_proj),
inplace=False
)
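Putting the two hunks above together, the DiffusionTransformer block now applies a residual around the attention update as well. The sketch below is illustrative only: the extra keyword arguments elided in the diff are omitted, and `add` is assumed to be a simple addition helper matching its use here.

```python
import torch

def add(a: torch.Tensor, b: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    # Assumed behaviour of the add() helper used in the diff.
    if inplace:
        a += b
        return a
    return a + b

def block_forward(single_repr, single_proj, pair_repr,
                  attention_block, conditioned_transition_block):
    # Residual around attention with pair bias (introduced by this commit).
    single_repr = single_repr + attention_block(
        single_repr=single_repr,
        single_proj=single_proj,
        pair_repr=pair_repr,
    )
    # Residual around the conditioned transition block.
    single_repr = add(
        single_repr,
        conditioned_transition_block(single_repr, single_proj),
        inplace=False,
    )
    return single_repr
```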
