Commit 732ebe0: full AlphaFold3 training
ardagoreci committed Aug 10, 2024
1 parent 12887c4 commit 732ebe0
Showing 21 changed files with 793 additions and 304 deletions.
92 changes: 46 additions & 46 deletions configs/data/erebor.yaml
@@ -1,15 +1,15 @@
 _target_: src.data.data_modules.OpenFoldDataModule
 
 # Train data paths
-template_mmcif_dir: "pdb_data/mmcif_files"
+template_mmcif_dir: "../pdb_data/mmcif_files"
 max_template_date: "2021-10-10"
-train_data_dir: "pdb_data/mmcif_files"
-train_alignment_dir: "alignment_data/alignments"
-train_chain_data_cache_path: "pdb_data/chain_data_cache.json"
-template_release_dates_cache_path: "pdb_data/mmcif_cache.json"
-train_mmcif_data_cache_path: "pdb_data/mmcif_cache.json"
-alignment_index_path: null # "alignment_data/alignment_dbs/alignment_db.index"
-obsolete_pdbs_file_path: "pdb_data/obsolete.dat"
+train_data_dir: "../pdb_data/mmcif_files"
+train_alignment_dir: "../alignment_data/alignment_dbs"
+train_chain_data_cache_path: "../data_caches/chain_data_cache.json"
+template_release_dates_cache_path: "../data_caches/mmcif_cache.json"
+train_mmcif_data_cache_path: "../data_caches/mmcif_cache.json"
+alignment_index_path: "../alignment_data/alignment_dbs/alignment_db.index"
+obsolete_pdbs_file_path: "../pdb_data/obsolete.dat"
 train_filter_path: null
 
 # Distillation data paths
@@ -107,98 +107,98 @@ config:
   common:
     feat: # Features for AlphaFold 3, single chain, backbone only coordinates
       aatype: # [NUM_RES]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
       all_atom_mask: # [NUM_RES, 4]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
       all_atom_positions: # [NUM_RES, 4, 3]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
         - null
       ref_pos: # [NUM_RES, 4, 3]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
         - null
       ref_mask: # [NUM_RES, 4]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
       ref_element: # [NUM_RES, 4, 4]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
         - null
       ref_charge: # [NUM_RES, 4]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
       ref_atom_name_chars: # [NUM_RES, 4, 4]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
         - null
       ref_space_uid: # [NUM_RES, 4]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
       atom_to_token: # [NUM_RES, 4]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
       is_distillation: [ ]
       msa_feat: # [NUM_MSA_SEQ, NUM_RES, 49]
-        - ${placeholders.NUM_MSA_SEQ}
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_MSA_SEQ}
+        - ${data.placeholders.NUM_RES}
         - null
       msa_mask: # [NUM_MSA_SEQ, NUM_RES]
-        - ${placeholders.NUM_MSA_SEQ}
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_MSA_SEQ}
+        - ${data.placeholders.NUM_RES}
       residue_index: # [NUM_RES]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
       residx_atom14_to_atom37:
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
       residx_atom37_to_atom14:
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
         - null
       resolution: [ ]
       seq_length: [ ]
       seq_mask: # [NUM_RES]
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
       template_aatype: # [NUM_TEMPLATES, NUM_RES]
-        - ${placeholders.NUM_TEMPLATES},
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_TEMPLATES},
+        - ${data.placeholders.NUM_RES}
       template_all_atom_mask: # [NUM_TEMPLATES, NUM_RES, 4]
-        - ${placeholders.NUM_TEMPLATES},
-        - ${placeholders.NUM_RES},
+        - ${data.placeholders.NUM_TEMPLATES},
+        - ${data.placeholders.NUM_RES},
         - null
       template_all_atom_positions:
-        - ${placeholders.NUM_TEMPLATES}
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_RES}
         - null
         - null
       template_backbone_rigid_mask:
-        - ${placeholders.NUM_TEMPLATES}
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_RES}
       template_backbone_rigid_tensor:
-        - ${placeholders.NUM_TEMPLATES}
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_RES}
         - null
         - null
       template_mask:
-        - ${placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_TEMPLATES}
       template_pseudo_beta:
-        - ${placeholders.NUM_TEMPLATES}
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_RES}
         - null
       template_pseudo_beta_mask:
-        - ${placeholders.NUM_TEMPLATES}
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_RES}
       template_sum_probs:
-        - ${placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_TEMPLATES}
         - null
       valid_residues:
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
       valid_templates:
-        - ${placeholders.NUM_TEMPLATES}
+        - ${data.placeholders.NUM_TEMPLATES}
       valid_msa:
-        - ${placeholders.NUM_MSA_SEQ}
+        - ${data.placeholders.NUM_MSA_SEQ}
       valid_backbone_atoms:
-        - ${placeholders.NUM_RES}
+        - ${data.placeholders.NUM_RES}
     block_delete_msa:
       msa_fraction_per_block: 0.3
       randomize_num_blocks: false
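The substantive change in erebor.yaml is mechanical: every placeholder interpolation is now an absolute path from the config root (`${data.placeholders.*}` rather than `${placeholders.*}`), so the references resolve once this file is composed under Hydra's `data` group; the data paths likewise move up one directory and switch to the packed alignment database with its index. A minimal OmegaConf sketch of how such an interpolation resolves, with hypothetical sizes standing in for the real placeholder values:

```python
from omegaconf import OmegaConf

# Hypothetical sizes; the real values come from data.placeholders at compose time.
cfg = OmegaConf.create({
    "data": {
        "placeholders": {"NUM_RES": 384, "NUM_MSA_SEQ": 128, "NUM_TEMPLATES": 4},
        "config": {
            "common": {
                "feat": {"msa_feat": ["${data.placeholders.NUM_MSA_SEQ}",
                                      "${data.placeholders.NUM_RES}", None]}
            }
        },
    }
})
print(OmegaConf.to_container(cfg, resolve=True)["data"]["config"]["common"]["feat"])
# {'msa_feat': [128, 384, None]}
```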
8 changes: 3 additions & 5 deletions configs/model/alphafold3.yaml
@@ -87,7 +87,7 @@ model:
 
   # MSA module
   msa_module:
-    no_blocks: 1 # 4
+    no_blocks: 4
     c_msa: 64
     c_token: ${model.model.c_token}
     c_z: ${model.model.c_pair}
@@ -107,7 +107,7 @@ model:
   pairformer_stack:
     c_s: ${model.model.c_token}
     c_z: ${model.model.c_pair}
-    no_blocks: 1 # 48
+    no_blocks: 48
     c_hidden_mul: ${model.model.c_hidden_tri_mul}
     c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
     no_heads_tri_attn: ${model.model.no_heads_tri_attn}
@@ -140,7 +140,7 @@ model:
     p: 7.0
     clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
     blocks_per_ckpt: ${model.model.blocks_per_ckpt}
-    compile_model: false # TODO: this parameter does not belong here, will be removed in future iterations.
+    compile_module: false # TODO: this parameter does not belong here, will be removed in future iterations.
 
   confidence_head:
     c_s: 384 # ${model.c_token}
@@ -168,8 +168,6 @@ globals:
   chunk_size: null # 4
   # Use DeepSpeed memory-efficient attention kernel in supported modules.
   use_deepspeed_evo_attention: true
-  # Use FlashAttention in selected modules.
-  use_flash: true
   samples_per_trunk: 48 # Number of diffusion module replicas per trunk
   rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
   eps: 0.00000001
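This file takes the trunk from debugging depth to the paper's depth (4 MSA module blocks, 48 Pairformer blocks), renames `compile_model` to `compile_module`, and drops the separate `use_flash` flag in favor of the DeepSpeed kernel alone. At 48 blocks, activation checkpointing via `blocks_per_ckpt` is what keeps training memory manageable. A simplified PyTorch sketch of the pattern, not the repository's implementation (here each block is checkpointed individually, whereas a real `blocks_per_ckpt: k` would group k blocks per checkpoint segment):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class BlockStack(nn.Module):
    """Toy stand-in for a Pairformer-style stack of `no_blocks` blocks."""

    def __init__(self, no_blocks: int, c_z: int, blocks_per_ckpt=None):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.LayerNorm(c_z), nn.Linear(c_z, c_z))
            for _ in range(no_blocks)
        )
        self.blocks_per_ckpt = blocks_per_ckpt

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.blocks_per_ckpt and self.training:
                # Recompute this block's activations in backward instead of storing them.
                z = z + checkpoint(block, z, use_reentrant=False)
            else:
                z = z + block(z)
        return z

stack = BlockStack(no_blocks=48, c_z=128, blocks_per_ckpt=1).train()
out = stack(torch.randn(4, 4, 128, requires_grad=True))
```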
173 changes: 173 additions & 0 deletions configs/model/mini-af3.yaml
@@ -0,0 +1,173 @@
# AlphaFold3 configs
# _target_: src.models.model_wrapper.AlphaFoldWrapper
# config:
optimizer:
  _target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam
  _partial_: true
  lr: 0.0018
  betas:
    - 0.9
    - 0.95
  eps: 1e-08
  weight_decay: 0.0

scheduler:
  _target_: src.utils.lr_schedulers.AlphaFoldLRScheduler
  _partial_: true
  last_epoch: -1
  verbose: false
  base_lr: 0.0 # the starting learning rate
  max_lr: 0.0018
  warmup_no_steps: 1000
  start_decay_after_n_steps: 50_000
  decay_every_n_steps: 50_000
  decay_factor: 0.95

# Loss configs
loss:
  diffusion_loss:
    # mse:
    #   weight_dna: 5.0
    #   weight_rna: 5.0
    #   weight_ligand: 10.0
    #   weight_protein: 1.0
    # smooth_lddt:
    #   weight: 1.0
    use_smooth_lddt: false
    sd_data: 16.0
    weight: 1.0

  distogram:
    min_bin: 2.3125
    max_bin: 21.6875
    no_bins: 64
    eps: 0.000001 # 1e-6
    weight: 0.03

  experimentally_resolved:
    eps: 0.00000001 # 1e-8
    # min_resolution: 0.1
    # max_resolution: 3.0
    weight: 0.0004

  plddt_loss:
    min_resolution: 0.1
    max_resolution: 3.0
    cutoff: 15.0
    no_bins: 50
    eps: 0.0000000001 # 1e-10
    weight: 0.0004

# TODO: fix the model.model notation in interpolation
model:
  c_token: 64 # the token representation dim
  c_pair: 16 # the pair representation dim
  c_atom: 16 # the atom representation dim
  c_atompair: 16 # the atom pair representation dim

  # Pair stack parameters (used in Pairformer, MSA module, and Confidence head)
  c_hidden_tri_mul: 16 # the hidden dim for the triangle multiplicative update
  c_hidden_pair_attn: 16 # the hidden dim for the pair attention ${common.c_hidden_pair_attn}
  no_heads_tri_attn: 1
  transition_n: 1
  pair_dropout: 0.25
  fuse_projection_weights: false
  blocks_per_ckpt: null # number of blocks per checkpoint; if null, no checkpointing
  clear_cache_between_blocks: false # whether to clear GPU memory cache between blocks
  # Pairformer attention pair bias
  no_heads_single_attn: 1

  # Input Embedder
  input_embedder:
    c_token: ${model.model.c_token}
    c_trunk_pair: ${model.model.c_pair}
    c_atom: ${model.model.c_atom}
    c_atompair: ${model.model.c_atompair}

  # MSA module
  msa_module:
    no_blocks: 1 # 4
    c_msa: 16
    c_token: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    c_hidden: 8
    no_heads: 1
    c_hidden_tri_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    inf: 1e8

  # PairformerStack
  pairformer_stack:
    c_s: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 1 # 48
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    clear_cache_between_blocks: false
    inf: 1e8

  # Diffusion module
  diffusion_module:
    c_atom: ${model.model.c_atom}
    c_atompair: ${model.model.c_atompair}
    c_token: ${model.model.c_token}
    c_tokenpair: ${model.model.c_pair}
    atom_encoder_blocks: 1
    atom_encoder_heads: 1
    dropout: 0.0
    atom_attention_n_queries: 32 # TODO: with sliding window attention this is not used.
    atom_attention_n_keys: 128
    atom_decoder_blocks: 1
    atom_decoder_heads: 1
    token_transformer_blocks: 1 # 24
    token_transformer_heads: 1
    sd_data: 16.0
    s_max: 160.0
    s_min: 4e-4
    p: 7.0
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    compile_module: false # TODO: this parameter does not belong here, will be removed in future iterations.

  confidence_head:
    c_s: 384 # ${model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 4
    no_bins_pde: 64
    no_bins_plddt: 64
    no_bins_pae: 64
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}

  distogram_head:
    c_z: ${model.model.c_pair}
    no_bins: 64

# Exponential moving average decay rate
ema_decay: 0.999

globals:
  chunk_size: null # 4
  # Use DeepSpeed memory-efficient attention kernel in supported modules.
  use_deepspeed_evo_attention: true
  samples_per_trunk: 48 # Number of diffusion module replicas per trunk
  rollout_samples_per_trunk: 1 # Number of mini rollouts per trunk
  eps: 0.00000001
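Three hyperparameter groups in this new config have simple closed forms worth spelling out. First, the scheduler block describes linear warmup to max_lr followed by a flat plateau and stepwise decay; a sketch under assumed semantics (the authoritative version is src.utils.lr_schedulers.AlphaFoldLRScheduler):

```python
def af3_lr(step, base_lr=0.0, max_lr=0.0018, warmup_no_steps=1000,
           start_decay_after_n_steps=50_000, decay_every_n_steps=50_000,
           decay_factor=0.95):
    # Assumed semantics: linear warmup, flat plateau, then multiplicative decay.
    if step < warmup_no_steps:
        return base_lr + (max_lr - base_lr) * step / warmup_no_steps
    if step <= start_decay_after_n_steps:
        return max_lr
    n = (step - start_decay_after_n_steps) // decay_every_n_steps + 1
    return max_lr * decay_factor ** n

# af3_lr(500) -> 0.0009, af3_lr(10_000) -> 0.0018, af3_lr(150_000) -> 0.0018 * 0.95**3
```

Second, the distogram settings define 64 distance bins spanning 2.3125 to 21.6875 Angstrom, AlphaFold/OpenFold-style: 63 interior edges, with a pairwise distance assigned to the index of the last edge it exceeds:

```python
import torch

boundaries = torch.linspace(2.3125, 21.6875, 64 - 1)  # 63 edges, 0.3125 A apart
d = torch.tensor([2.0, 10.0, 30.0])                   # hypothetical pairwise distances
print((d.unsqueeze(-1) > boundaries).sum(-1))         # tensor([ 0, 25, 63])
```

Third, `sd_data`, `s_max`, `s_min`, and `p` in diffusion_module parameterize the EDM-style noise-level schedule from the AlphaFold3 paper, sigma(t) = sd_data * (s_max^(1/p) + t * (s_min^(1/p) - s_max^(1/p)))^p, which this repo is assumed to follow:

```python
def noise_level(t, sd_data=16.0, s_max=160.0, s_min=4e-4, p=7.0):
    return sd_data * (s_max ** (1 / p) + t * (s_min ** (1 / p) - s_max ** (1 / p))) ** p

# noise_level(0.0) -> 2560.0 (= sd_data * s_max); noise_level(1.0) -> 0.0064 (= sd_data * s_min)
```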
6 changes: 3 additions & 3 deletions configs/train.yaml → configs/train_af3.yaml
@@ -5,10 +5,10 @@
 defaults:
   - _self_
   - callbacks: default
-  - data: protein
-  - model: proteus # alphafold3
+  - data: erebor
+  - model: mini-af3
   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
-  - trainer: gpu
+  - trainer: cpu
   - paths: default
   - extras: default
   - hydra: default
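With these defaults, a bare `python train.py` composes the erebor data config with the mini-af3 model on CPU, a lightweight smoke-test configuration; any entry can be overridden per run. A sketch using Hydra's compose API, assuming the configs/ layout above:

```python
from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    # Equivalent to: python train.py model=alphafold3 trainer=gpu
    cfg = compose(config_name="train_af3", overrides=["model=alphafold3", "trainer=gpu"])
print(cfg.data._target_)  # src.data.data_modules.OpenFoldDataModule
```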
(The remaining 17 changed files are not shown here.)