forked from Ligo-Biosciences/AlphaFold3

Commit 712ec7a (parent: f547bf8)
ardagoreci committed Aug 18, 2024
Showing 1 changed file with 184 additions and 0 deletions.
# AlphaFold3 configs
# _target_: src.models.model_wrapper.AlphaFoldWrapper
# config:
optimizer:
  _target_: torch.optim.Adam  # deepspeed.ops.adam.FusedAdam
  _partial_: true
  lr: 0.00018
  betas:
    - 0.9
    - 0.95
  eps: 1e-08
  weight_decay: 0.0
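  # With `_partial_: true`, hydra.utils.instantiate returns
  # functools.partial(torch.optim.Adam, lr=0.00018, ...) instead of a built
  # optimizer, so the training wrapper can supply model.parameters() later.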

scheduler:
  _target_: src.utils.lr_schedulers.AlphaFoldLRScheduler
  _partial_: true
  last_epoch: -1
  verbose: false
  base_lr: 0.0  # the starting learning rate
  max_lr: 0.00018
  warmup_no_steps: 1000
  start_decay_after_n_steps: 50_000
  decay_every_n_steps: 50_000
  decay_factor: 0.95
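  # Worked schedule (assuming linear warmup followed by stepwise exponential
  # decay, as in the OpenFold scheduler of the same name):
  #   step       0: lr = 0.0             (base_lr)
  #   step   1,000: lr = 1.8e-4          (max_lr; flat until decay starts)
  #   step  50,000: lr = 1.8e-4 * 0.95   ~ 1.71e-4
  #   step 100,000: lr = 1.8e-4 * 0.95^2 ~ 1.62e-4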

# Loss configs
loss:
  diffusion_loss:
    # mse:
    #   weight_dna: 5.0
    #   weight_rna: 5.0
    #   weight_ligand: 10.0
    #   weight_protein: 1.0
    # smooth_lddt:
    #   weight: 1.0
    sd_data: 16.0
    weight: 4.0
  smooth_lddt_loss:
    weight: 1.0
    epsilon: 1e-5
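  # sd_data corresponds to sigma_data in the EDM framework (Karras et al., 2022),
  # where the denoising MSE is weighted by
  # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2;
  # the outer weight of 4.0 appears to mirror the alpha_diffusion loss
  # weighting reported in the AF3 paper.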

  distogram:
    min_bin: 0.0
    max_bin: 32.0
    no_bins: 64
    eps: 0.000006  # 1e-6
    weight: 0.03
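  # 64 bins spanning 0-32 Å gives a distance-bin width of 32 / 64 = 0.5 Å.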

  experimentally_resolved:
    eps: 0.00000001  # 1e-8
    # min_resolution: 0.1
    # max_resolution: 3.0
    weight: 0.0004

  plddt_loss:
    min_resolution: 0.1
    max_resolution: 3.0
    cutoff: 15.0
    no_bins: 50
    eps: 0.0000000001  # 1e-10
    weight: 0.0004
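  # 50 bins over the lDDT range [0, 1] gives 0.02 per bin; min/max_resolution
  # presumably gate the loss to structures solved at 0.1-3.0 Å, as in the
  # AlphaFold2 confidence losses.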


# TODO: fix the model.model notation in interpolation
model:
  c_token: 384  # the token representation dim
  c_pair: 128  # the pair representation dim
  c_atom: 128  # the atom representation dim
  c_atompair: 16  # the atom pair representation dim

  # Pair stack parameters (used in Pairformer, MSA module, and Confidence head)
  c_hidden_tri_mul: 128  # the hidden dim for the triangle multiplicative update
  c_hidden_pair_attn: 32  # the hidden dim for the pair attention ${common.c_hidden_pair_attn}
  no_heads_tri_attn: 4
  transition_n: 4
  pair_dropout: 0.25
  fuse_projection_weights: false
  blocks_per_ckpt: 1  # number of blocks per checkpoint; if None, no checkpointing
  clear_cache_between_blocks: false  # whether to clear the GPU memory cache between blocks
  # Pairformer attention pair bias
  no_heads_single_attn: 16
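  # Assuming c_hidden_pair_attn is a per-head width, the triangle attention
  # recovers the full pair dim: 4 heads x 32 = 128 = c_pair.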

  # Input Embedder
  input_embedder:
    c_token: ${model.model.c_token}
    c_trunk_pair: ${model.model.c_pair}
    c_atom: ${model.model.c_atom}
    c_atompair: ${model.model.c_atompair}
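  # The ${model.model.*} references resolve against the composed Hydra config,
  # where this file is mounted under the model config group (hence the doubled
  # prefix flagged in the TODO above); e.g. ${model.model.c_token} -> 384.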

  # MSA module
  msa_module:
    no_blocks: 4
    c_msa: 64
    c_token: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    c_hidden: 32
    no_heads: 8
    c_hidden_tri_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    inf: 1e8
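    # inf: large finite stand-in for infinity, presumably used to fill masked
    # attention logits (approximately -1e8) before the softmax.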

  # Template Embedder
  template_embedder:
    no_blocks: 2
    c_template: 64
    c_z: ${model.model.c_pair}
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}

  # PairformerStack
  pairformer_stack:
    c_s: ${model.model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 24
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    clear_cache_between_blocks: false
    inf: 1e8

  # Diffusion module
  diffusion_module:
    c_atom: ${model.model.c_atom}
    c_atompair: ${model.model.c_atompair}
    c_token: ${model.model.c_token}
    c_tokenpair: ${model.model.c_pair}
    atom_encoder_blocks: 3
    atom_encoder_heads: 16
    dropout: 0.0
    atom_attention_n_queries: 32  # TODO: with sliding window attention this is not used.
    atom_attention_n_keys: 128
    atom_decoder_blocks: 3
    atom_decoder_heads: 16
    token_transformer_blocks: 12
    token_transformer_heads: 16
    sd_data: 16.0
    s_max: 160.0
    s_min: 0.0004
    p: 7.0
    clear_cache_between_blocks: ${model.model.clear_cache_between_blocks}
    blocks_per_ckpt: ${model.model.blocks_per_ckpt}
    compile_module: false  # TODO: this parameter does not belong here, will be removed in future iterations.
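    # sd_data, s_max, s_min, and p parameterize a Karras-style noise schedule;
    # assuming the AF3 form
    #   sigma(t) = sd_data * (s_max^(1/p) + t * (s_min^(1/p) - s_max^(1/p)))^p
    # for t in [0, 1], sampling sweeps the noise level from 16 * 160 = 2560
    # down to 16 * 0.0004 = 0.0064.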

  confidence_head:
    c_s: 384  # ${model.c_token}
    c_z: ${model.model.c_pair}
    no_blocks: 4
    no_bins_pde: 64
    no_bins_plddt: 64
    no_bins_pae: 64
    c_hidden_mul: ${model.model.c_hidden_tri_mul}
    c_hidden_pair_attn: ${model.model.c_hidden_pair_attn}
    no_heads_tri_attn: ${model.model.no_heads_tri_attn}
    no_heads_single_attn: ${model.model.no_heads_single_attn}
    transition_n: ${model.model.transition_n}
    pair_dropout: ${model.model.pair_dropout}
    fuse_projection_weights: ${model.model.fuse_projection_weights}

  distogram_head:
    c_z: ${model.model.c_pair}
    no_bins: 64

# Exponential moving average decay rate
ema_decay: 0.999
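# After each optimizer step the EMA weights are updated as
# w_ema <- 0.999 * w_ema + (1 - 0.999) * w, an averaging horizon of roughly
# 1 / (1 - 0.999) = 1000 steps.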

globals:
  chunk_size: null  # 4
  # Use the DeepSpeed memory-efficient attention kernel in supported modules.
  use_deepspeed_evo_attention: true
  samples_per_trunk: 48  # Number of diffusion module replicas per trunk
  rollout_samples_per_trunk: 1  # Number of mini rollouts per trunk
  eps: 0.00000001
  # Internal precision of float32 matrix multiplications; "high" or "medium" will utilize Tensor Cores.
  matmul_precision: "high"
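
Below is a minimal sketch, not part of this commit, of how a Hydra-based training script would typically consume the optimizer and scheduler entries above. The config path and the stand-in model are illustrative, and instantiating the scheduler partial assumes the repository's src package is importable.

```python
import hydra
import torch
from omegaconf import OmegaConf

# Load this file directly; OmegaConf resolves interpolations lazily, and the
# optimizer/scheduler nodes used below contain none.
cfg = OmegaConf.load("configs/model/alphafold3.yaml")  # illustrative path

# `_partial_: true` makes instantiate() return functools.partial objects
# rather than constructed instances.
optim_fn = hydra.utils.instantiate(cfg.optimizer)  # partial(torch.optim.Adam, lr=0.00018, ...)
sched_fn = hydra.utils.instantiate(cfg.scheduler)  # partial(AlphaFoldLRScheduler, ...)

model = torch.nn.Linear(8, 8)  # stand-in for the real model
optimizer = optim_fn(model.parameters())
scheduler = sched_fn(optimizer=optimizer)

# Apply the matmul precision requested in the globals block.
torch.set_float32_matmul_precision(cfg.globals.matmul_precision)
```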