Skip to content

Commit

Permalink
Merge branch 'master' into add_dude
Browse files Browse the repository at this point in the history
  • Loading branch information
drewnutt authored Nov 19, 2024
2 parents ebb38ed + caa9de3 commit de561d3
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ Links to download pre-trained models are in `checkpoints/README.md`.

Once downloaded, just `gunzip` the file to get the ready-to-use model checkpoint.

# Download MERGED dataset
Script to download splits and data:
```
cd data/MERGED/huge_data/
bash download.sh
cd -
```

# Embed proteins and molecules
```
# Get target embeddings with pre-trained model
Expand Down
5 changes: 3 additions & 2 deletions checkpoints/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Pre-trained models
You can download pretrained models from the google-drive links below. Once downloaded, just `gunzip` the file to get the model checkpoint.
- SaProt Contrastive Model: https://drive.google.com/file/d/1W7qOc4CkiX6NhSj17PRcz8j5iJckdSIw/view?usp=sharing
- ConPLex Model (Retrained): https://drive.google.com/file/d/1Mhv30_q6GIBgN2hab0atLeaK0d3AoNm1/view?usp=sharing
- SPRINT-xs: https://drive.google.com/file/d/1pKct70RIPyXByilstxGE7CWw5CRKnKPf/view?usp=sharing
- SPRINT-sm: https://drive.google.com/file/d/1Fh4XmC5y5kRUV4fQXPdOJGs-UrVIFTuE/view?usp=sharing
- ConPLex Model (Retrained): https://drive.google.com/file/d/1Mhv30\_q6GIBgN2hab0atLeaK0d3AoNm1/view?usp=sharing
6 changes: 2 additions & 4 deletions data/MERGED/huge_data/run_mmseq.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# This script runs MMSeq2 on the LIT-PCBA dataset to only retain uniprot id's with sequence similarity <90% to LIT-PCBA

# TODO: get mmseq results.
# TODO: concat train / val / test for this.

import json
import numpy as np
import subprocess
import os
import sys

def run_mmseqs2(query_file, target_file, output_file, tmp_dir, threshold):
commands = [
Expand Down Expand Up @@ -69,4 +67,4 @@ def main(threshold=0.9):
subprocess.run("rm targetDB*; rm queryDB*; rm resultDB*; rm -rf tmp", shell=True, check=True)

if __name__ == "__main__":
main(threshold=0.9)
main(threshold=float(sys.argv[1]))
3 changes: 2 additions & 1 deletion ultrafast/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def __init__(

def forward(self, drug, target):
model_size = self.args.model_size
sigmoid_scalar = self.args.sigmoid_scalar
if model_size == 'huge' or model_size == 'mega':
y = self.drug_projector['proj'](drug)
for layer in self.drug_projector['res']:
Expand All @@ -383,7 +384,7 @@ def forward(self, drug, target):
target_projection = self.target_projector(target)

if self.classify:
similarity = 4 * F.cosine_similarity(
similarity = sigmoid_scalar * F.cosine_similarity(
drug_projection, target_projection
)
else:
Expand Down
16 changes: 13 additions & 3 deletions ultrafast/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def train_cli():
parser.add_argument("--ship-model", help="Train a final to ship model, while excluding the uniprot id's specified by this argument.", dest="ship_model")
parser.add_argument("--eval-pcba", action="store_true", help="Evaluate PCBA during validation")
parser.add_argument("--eval-dude", action="store_true", help="Evaluate DUDe during validation")
parser.add_argument("--sigmoid-scalar", type=int, default=5, dest="sigmoid_scalar")

args = parser.parse_args()
train(**vars(args))
Expand Down Expand Up @@ -124,6 +125,7 @@ def train(
ship_model: str,
eval_pcba: bool,
eval_dude: bool,
sigmoid_scalar: int,
):
args = argparse.Namespace(
experiment_id=experiment_id,
Expand Down Expand Up @@ -156,6 +158,7 @@ def train(
ship_model=ship_model,
eval_pcba=eval_pcba,
eval_dude=eval_dude,
sigmoid_scalar=sigmoid_scalar,
)
config = OmegaConf.load(args.config)
args_overrides = {k: v for k, v in vars(args).items() if v is not None}
Expand Down Expand Up @@ -285,10 +288,12 @@ def train(
wandb_logger.watch(model)
if hasattr(wandb_logger.experiment.config, 'update'):
wandb_logger.experiment.config.update(OmegaConf.to_container(config, resolve=True, throw_on_missing=True))
wandb_logger.experiment.tags = [config.task, config.experiment_id, config.target_featurizer, config.model_size]

wandb_logger.experiment.tags = [config.task, 'arxiv', config.target_featurizer]

if config.task == 'merged':
if config.task == 'merged' and args.ship_model:
# save every epoch
checkpoint_callback = pl.callbacks.ModelCheckpoint(
save_top_k=-1,
dirpath=save_dir,
Expand Down Expand Up @@ -323,15 +328,20 @@ def train(
)

if ship_model:
# Train on all data
# Train on all data and test with best weights
trainer.fit(model, datamodule=datamodule)
trainer.test(datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path)
# Save the final model
trainer.save_checkpoint(f"{save_dir}/ship_model.ckpt")
else:
# Regular training with validation
trainer.fit(model, datamodule=datamodule)
# Test model using best weights
trainer.test(datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path)
if config.epochs == 0:
ckpt = config.checkpoint
else:
ckpt = checkpoint_callback.best_model_path
trainer.test(datamodule=datamodule, ckpt_path=ckpt)


if __name__ == '__main__':
Expand Down

0 comments on commit de561d3

Please sign in to comment.