Commit c383351
removed residual connection from Diffusion module components
ardagoreci committed Aug 15, 2024
1 parent 19a5325 commit c383351
Showing 3 changed files with 11 additions and 15 deletions.
configs/train_af3.yaml (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ defaults:
   - data: erebor
   - model: alphafold3
   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
-  - trainer: cpu
+  - trainer: ddp
   - paths: default
   - extras: default
   - hydra: default
src/models/components/atom_attention.py (6 changes: 1 addition & 5 deletions)
@@ -303,11 +303,7 @@ def forward(
         atom_pair_local: Tensor,
         mask: Optional[Tensor] = None
     ) -> Tuple[Tensor, Tensor, Tensor]:
-        a = add(
-            atom_single,
-            self.atom_attention(atom_single, atom_proj, atom_pair_local, mask),
-            inplace=False
-        )
+        a = self.atom_attention(atom_single, atom_proj, atom_pair_local, mask)
         atom_single = add(a, self.transition(atom_single, atom_proj), inplace=False)
         return atom_single, atom_proj, atom_pair_local

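For context, a minimal sketch (not the repository's actual classes) of what the atom-level block computes after this change: the local attention output is no longer added back onto atom_single, so the only sum left in the block is the attention output plus the conditioned transition of the input. nn.Linear modules stand in for the repo's pair-biased atom attention and conditioned transition, whose real constructors and signatures are assumptions here.

import torch
import torch.nn as nn

class AtomBlockSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.atom_attention = nn.Linear(dim, dim)  # stand-in for the pair-biased atom attention
        self.transition = nn.Linear(dim, dim)      # stand-in for the conditioned transition

    def forward(self, atom_single: torch.Tensor) -> torch.Tensor:
        # Before this commit: a = atom_single + atom_attention(atom_single, ...)
        # After this commit:  a = atom_attention(atom_single, ...)
        a = self.atom_attention(atom_single)
        # The only remaining sum: attention output plus the transition of the input.
        return a + self.transition(atom_single)

atoms = torch.randn(1, 8, 64)            # (batch, atoms, channels); shapes are illustrative
print(AtomBlockSketch(64)(atoms).shape)  # torch.Size([1, 8, 64])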
src/models/diffusion_transformer.py (18 changes: 9 additions & 9 deletions)
@@ -4,6 +4,7 @@
 import torch.nn as nn
 from src.models.components.transition import ConditionedTransitionBlock
 from src.models.components.attention_pair_bias import AttentionPairBias
+from src.models.components.primitives import LayerNorm
 from typing import Optional
 from functools import partial
 from src.utils.checkpointing import checkpoint_blocks
@@ -58,21 +59,20 @@ def forward(
         TODO: the single_proj and pair_repr do not actually change as a result of this function.
         Returning them here is a bit misleading. Also, saving them between blocks is unnecessary.
         """
-        b = add(  # TODO: this residual connection does not exist in the paper!
-            single_repr,
-            self.attention_block(
-                single_repr=single_repr,
-                single_proj=single_proj,
-                pair_repr=pair_repr,
-                mask=mask,
-                use_deepspeed_evo_attention=use_deepspeed_evo_attention),
-            inplace=False
-        )
+        b = self.attention_block(
+            single_repr=single_repr,
+            single_proj=single_proj,
+            pair_repr=pair_repr,
+            mask=mask,
+            use_deepspeed_evo_attention=use_deepspeed_evo_attention
+        )
 
         single_repr = add(
             b,
             self.conditioned_transition_block(single_repr, single_proj),
             inplace=False
         )
 
         return single_repr, single_proj, pair_repr


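The resulting per-block update matches what the removed TODO comment asks for: b is the pair-biased attention output with no residual around it, and the only sum is b plus the conditioned transition of the input. A minimal sketch of that wiring, with nn.Linear stand-ins for AttentionPairBias and ConditionedTransitionBlock (their real signatures, pair bias, and masking are omitted here as assumptions):

import torch
import torch.nn as nn

class DiffusionBlockSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attention_block = nn.Linear(dim, dim)                   # stand-in for AttentionPairBias
        self.conditioned_transition_block = nn.Linear(2 * dim, dim)  # stand-in for ConditionedTransitionBlock

    def forward(self, single_repr: torch.Tensor, single_proj: torch.Tensor) -> torch.Tensor:
        # No `single_repr + ...` residual around the attention after this commit.
        b = self.attention_block(single_repr)
        # Only sum in the block: attention output + conditioned transition of the input.
        single_repr = b + self.conditioned_transition_block(
            torch.cat([single_repr, single_proj], dim=-1)
        )
        return single_repr

s = torch.randn(1, 5, 128)  # token representation (illustrative shape)
c = torch.randn(1, 5, 128)  # conditioning projection (illustrative shape)
print(DiffusionBlockSketch(128)(s, c).shape)  # torch.Size([1, 5, 128])

This mirrors the DiffusionTransformer update described in the AlphaFold3 supplement, where the attention output feeds the transition sum directly rather than being added back onto the input.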
