added small alphafold3 configs
ardagoreci committed Aug 18, 2024
1 parent f547bf8 commit 712ec7a
Showing 1 changed file with 184 additions and 0 deletions.
configs/model/small-alphafold3.yaml (184 additions, 0 deletions)
@@ -0,0 +1,184 @@
# AlphaFold3 configs
# _target_: src.models.model_wrapper.AlphaFoldWrapper
# config:
optimizer:
  _target_: torch.optim.Adam # deepspeed.ops.adam.FusedAdam
  _partial_: true
  lr: 0.00018
  betas:
    - 0.9
    - 0.95
  eps: 1e-08
  weight_decay: 0.0
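With _partial_: true, Hydra's instantiate() returns a functools.partial of torch.optim.Adam carrying these keyword arguments, so the training wrapper can bind the model parameters later. A minimal sketch of that pattern; the tiny module and the inline config below are illustrative, not the repo's actual wrapper code:

import hydra
import torch
from omegaconf import OmegaConf

# Illustrative stand-in config mirroring the optimizer block above.
cfg = OmegaConf.create({
    "_target_": "torch.optim.Adam", "_partial_": True,
    "lr": 0.00018, "betas": [0.9, 0.95], "eps": 1e-8, "weight_decay": 0.0,
})
model = torch.nn.Linear(4, 4)                       # placeholder module
optimizer_partial = hydra.utils.instantiate(cfg)    # functools.partial(torch.optim.Adam, ...)
optimizer = optimizer_partial(model.parameters())   # parameters are bound later by the trainer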

scheduler:
  _target_: src.utils.lr_schedulers.AlphaFoldLRScheduler
  _partial_: true
  last_epoch: -1
  verbose: false
  base_lr: 0.0 # the starting learning rate
  max_lr: 0.00018
  warmup_no_steps: 1000
  start_decay_after_n_steps: 50_000
  decay_every_n_steps: 50_000
  decay_factor: 0.95
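These values describe, roughly, a linear warmup from base_lr to max_lr over warmup_no_steps, a hold at max_lr, and then a multiplicative decay by decay_factor every decay_every_n_steps once start_decay_after_n_steps is reached. A standalone sketch of that shape, which may differ in detail from src.utils.lr_schedulers.AlphaFoldLRScheduler:

def lr_at_step(step: int, base_lr=0.0, max_lr=0.00018, warmup_no_steps=1000,
               start_decay_after_n_steps=50_000, decay_every_n_steps=50_000,
               decay_factor=0.95) -> float:
    # Linear warmup from base_lr to max_lr.
    if step < warmup_no_steps:
        return base_lr + (max_lr - base_lr) * step / warmup_no_steps
    # Hold at max_lr until decay begins.
    if step < start_decay_after_n_steps:
        return max_lr
    # Multiply by decay_factor every decay_every_n_steps thereafter.
    n_decays = (step - start_decay_after_n_steps) // decay_every_n_steps + 1
    return max_lr * decay_factor ** n_decays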

# Loss configs
loss:
  diffusion_loss:
    # mse:
    #   weight_dna: 5.0
    #   weight_rna: 5.0
    #   weight_ligand: 10.0
    #   weight_protein: 1.0
    # smooth_lddt:
    #   weight: 1.0
    sd_data: 16.0
    weight: 4.0
  smooth_lddt_loss:
    weight: 1.0
    epsilon: 1e-5

  distogram:
    min_bin: 0.0
    max_bin: 32.0
    no_bins: 64
    eps: 0.000006 # 1e-6
    weight: 0.03

  experimentally_resolved:
    eps: 0.00000001 # 1e-8
    # min_resolution: 0.1
    # max_resolution: 3.0
    weight: 0.0004

  plddt_loss:
    min_resolution: 0.1
    max_resolution: 3.0
    cutoff: 15.0
    no_bins: 50
    eps: 0.0000000001 # 1e-10
    weight: 0.0004
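Each term above carries a weight key, which suggests the total training objective is a weighted sum of the individual losses. A hypothetical sketch of that combination (names and helper are illustrative, not the repo's code):

def total_loss(losses: dict, weights: dict):
    # e.g. losses  = {"diffusion": ..., "distogram": ..., "experimentally_resolved": ..., "plddt": ...}
    #      weights = {"diffusion": 4.0, "distogram": 0.03, "experimentally_resolved": 0.0004, "plddt": 0.0004}
    return sum(weights[name] * value for name, value in losses.items())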


# TODO: fix the model.model notation in interpolation
model:
  c_token: 384 # the token representation dim
  c_pair: 128 # the pair representation dim
  c_atom: 128 # the atom representation dim
  c_atompair: 16 # the atom pair representation dim

  # Pair stack parameters (used in Pairformer, MSA module, and Confidence head)
  c_hidden_tri_mul: 128 # the hidden dim for the triangle multiplicative update
  c_hidden_pair_attn: 32 # the hidden dim for the pair attention ${common.c_hidden_pair_attn}
  no_heads_tri_attn: 4
  transition_n: 4
  pair_dropout: 0.25
  fuse_projection_weights: false
  blocks_per_ckpt: 1 # number of blocks per checkpoint; if null, no checkpointing
  clear_cache_between_blocks: false # whether to clear GPU memory cache between blocks
  # Pairformer attention pair bias
  no_heads_single_attn: 16

  # Input Embedder
  input_embedder:
    c_token: ${model.model.c_token}
    c_trunk_pair: ${model.model.c_pair}
    c_atom: ${model.model.c_atom}
    c_atompair: ${model.model.c_atompair}
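The ${model.model.*} references are OmegaConf interpolations resolved from the root of the composed Hydra config: because this file is mounted under a top-level model group and the network config inside it is also named model, the path repeats, which is what the TODO above refers to. A self-contained reproduction with illustrative values:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "model": {                      # Hydra config group ("model")
        "model": {                  # the network config defined in this file
            "c_token": 384,
            "input_embedder": {"c_token": "${model.model.c_token}"},
        }
    }
})
print(cfg.model.model.input_embedder.c_token)  # 384, resolved from the root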

  # MSA module
  msa_module:
    no_blocks: 4
    c_msa: 64
    c_token: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    c_hidden: 32
    no_heads: 8
    c_hidden_tri_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    inf: 1e8

  # Template Embedder
  template_embedder:
    no_blocks: 2
    c_template: 64
    c_z: ${model.model.c_pair}
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}

  # PairformerStack
  pairformer_stack:
    c_s: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 24
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    clear_cache_between_blocks: false
    inf: 1e8

  # Diffusion module
  diffusion_module:
    c_atom: ${model.model.c_atom}
    c_atompair: ${model.model.c_atompair}
    c_token: ${model.model.c_token}
    c_tokenpair: ${model.model.c_pair}
    atom_encoder_blocks: 3
    atom_encoder_heads: 16
    dropout: 0.0
    atom_attention_n_queries: 32 # TODO: not used with sliding window attention
    atom_attention_n_keys: 128
    atom_decoder_blocks: 3
    atom_decoder_heads: 16
    token_transformer_blocks: 12
    token_transformer_heads: 16
    sd_data: 16.0
    s_max: 160.0
    s_min: 0.0004
    p: 7.0
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    compile_module: false # TODO: this parameter does not belong here and will be removed in a future iteration
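sd_data, s_max, s_min, and p parameterize the diffusion noise schedule. Assuming the EDM-style schedule described in the AlphaFold3 paper (this sketch is not necessarily the exact code path in this repo), the noise levels implied by these values would look like:

import numpy as np

def noise_schedule(n_steps, sd_data=16.0, s_max=160.0, s_min=0.0004, p=7.0):
    # Interpolate between s_max and s_min in sigma**(1/p) space and scale by sd_data.
    t = np.linspace(0.0, 1.0, n_steps)
    return sd_data * (s_max ** (1 / p) + t * (s_min ** (1 / p) - s_max ** (1 / p))) ** p

# noise_schedule(200) decays monotonically from 16 * 160 = 2560 down to 16 * 0.0004 = 0.0064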

  confidence_head:
    c_s: 384 # ${model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 4
    no_bins_pde: 64
    no_bins_plddt: 64
    no_bins_pae: 64
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}

  distogram_head:
    c_z: ${model.model.c_pair}
    no_bins: 64

# Exponential moving average decay rate
ema_decay: 0.999
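ema_decay controls an exponential moving average of the model weights. A minimal sketch of the standard update (illustrative helper, not the repo's own EMA wrapper):

import torch

@torch.no_grad()
def ema_update(ema_params, params, decay=0.999):
    # ema <- decay * ema + (1 - decay) * current weights
    for ema_p, p in zip(ema_params, params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)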

globals:
  chunk_size: null # 4
  # Use DeepSpeed memory-efficient attention kernel in supported modules.
  use_deepspeed_evo_attention: true
  samples_per_trunk: 48 # number of diffusion module replicas per trunk
  rollout_samples_per_trunk: 1 # number of mini rollouts per trunk
  eps: 0.00000001
  # Internal precision of float32 matrix multiplications; "high" or "medium" will utilize Tensor Cores.
  matmul_precision: "high"
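matmul_precision presumably maps onto PyTorch's global float32 matmul setting, where "high" (or "medium") allows TF32 Tensor Core kernels on supported GPUs:

import torch

torch.set_float32_matmul_precision("high")  # or "medium"; "highest" keeps full float32 precision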
