forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[example] Add fastfold tutorial (hpcaitech#2528)
* add fastfold example * pre-commit polish * pre-commit polish readme and add empty test ci * Add test_ci and reduce the default sequence length
- Loading branch information
Showing
3 changed files
with
210 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# FastFold Inference | ||
|
||
## Table of contents | ||
|
||
- [Overview](#📚-overview) | ||
- [Quick Start](#🚀-quick-start) | ||
- [Dive into FastFold](#🔍-dive-into-fastfold) | ||
|
||
## 📚 Overview | ||
|
||
This example lets you quickly try out FastFold inference. | ||
|
||
**NOTE: We use random data and random parameters in this example.** | ||
|
||
|
||
## 🚀 Quick Start | ||
|
||
1. Install FastFold | ||
|
||
We highly recommend creating an Anaconda or Miniconda environment and installing PyTorch with conda. | ||
|
||
``` | ||
git clone https://github.com/hpcaitech/FastFold | ||
cd FastFold | ||
conda env create --name=fastfold -f environment.yml | ||
conda activate fastfold | ||
python setup.py install | ||
``` | ||
|
||
2. Run the inference scripts. | ||
|
||
```bash | ||
python inference.py --gpus=1 --n_res=256 --chunk_size=None --inplace | ||
``` | ||
+ `gpus` means the DAP size | ||
+ `n_res` means the length of residue sequence | ||
+ `chunk_size` enables a memory-saving technique at the cost of speed; `None` disables it, and 16 may be a good trade-off for long sequences. | ||
+ `inplace` introduces another memory-saving technology with zero cost, drop `--inplace` if you do not want it. | ||
|
||
## 🔍 Dive into FastFold | ||
|
||
FastFold provides other features as well, such as: | ||
+ more optimized kernels based on Triton | ||
+ much faster data processing based on Ray | ||
+ support for training | ||
|
||
More detailed information can be seen [here](https://github.com/hpcaitech/FastFold/). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2023 HPC-AI Tech Inc. | ||
# Copyright 2021 AlQuraishi Laboratory | ||
# Copyright 2021 DeepMind Technologies Limited | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import os | ||
import time | ||
|
||
import fastfold | ||
import numpy as np | ||
import torch | ||
import torch.multiprocessing as mp | ||
from fastfold.config import model_config | ||
from fastfold.data import data_transforms | ||
from fastfold.model.fastnn import set_chunk_size | ||
from fastfold.model.hub import AlphaFold | ||
from fastfold.utils.inject_fastnn import inject_fastnn | ||
from fastfold.utils.tensor_utils import tensor_tree_map | ||
|
||
if int(torch.__version__.split(".")[0]) >= 1 and int(torch.__version__.split(".")[1]) > 11: | ||
torch.backends.cuda.matmul.allow_tf32 = True | ||
|
||
|
||
def random_template_feats(n_templ, n):
    """Build a dict of random template features for ``n_templ`` templates
    over a length-``n`` residue sequence.

    All arrays are float32 except ``template_aatype``, which is int64
    because it is consumed as an index tensor.
    """
    rand = np.random.rand
    randint = np.random.randint
    feats = {
        "template_mask": randint(0, 2, (n_templ,)),
        "template_pseudo_beta_mask": randint(0, 2, (n_templ, n)),
        "template_pseudo_beta": rand(n_templ, n, 3),
        "template_aatype": randint(0, 22, (n_templ, n)),
        "template_all_atom_mask": randint(0, 2, (n_templ, n, 37)),
        "template_all_atom_positions": rand(n_templ, n, 37, 3) * 10,
        "template_torsion_angles_sin_cos": rand(n_templ, n, 7, 2),
        "template_alt_torsion_angles_sin_cos": rand(n_templ, n, 7, 2),
        "template_torsion_angles_mask": rand(n_templ, n, 7),
    }
    feats = {name: arr.astype(np.float32) for name, arr in feats.items()}
    # Restore integer typing clobbered by the blanket float32 cast above.
    feats["template_aatype"] = feats["template_aatype"].astype(np.int64)
    return feats
|
||
|
||
def random_extra_msa_feats(n_extra, n):
    """Build a dict of random extra-MSA features: ``n_extra`` MSA rows over
    a length-``n`` sequence (int64 ``extra_msa``, float32 everything else)."""
    shape = (n_extra, n)
    return {
        "extra_msa": np.random.randint(0, 22, shape).astype(np.int64),
        "extra_has_deletion": np.random.randint(0, 2, shape).astype(np.float32),
        "extra_deletion_value": np.random.rand(*shape).astype(np.float32),
        "extra_msa_mask": np.random.randint(0, 2, shape).astype(np.float32),
    }
|
||
|
||
def generate_batch(n_res):
    """Assemble a full random monomer input batch for ``n_res`` residues.

    Every feature gains a trailing recycling dimension of size 3 at the
    end, matching the model's expectation of one slice per recycling
    iteration. Magic sizes: 128 MSA rows, 4 templates, 5120 extra-MSA
    rows, 49 MSA feature channels.
    """
    raw_aatype = torch.randint(21, size=(n_res,))
    batch = {"target_feat": torch.nn.functional.one_hot(raw_aatype, 22).float()}
    batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
    batch["residue_index"] = torch.arange(n_res)
    batch["msa_feat"] = torch.rand((128, n_res, 49))

    # Merge in the random template and extra-MSA feature dictionaries.
    for feats in (random_template_feats(4, n_res), random_extra_msa_feats(5120, n_res)):
        batch.update({name: torch.tensor(arr) for name, arr in feats.items()})

    batch["msa_mask"] = torch.randint(low=0, high=2, size=(128, n_res)).float()
    batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
    batch.update(data_transforms.make_atom14_masks(batch))
    batch["no_recycling_iters"] = torch.tensor(2.)

    def add_recycling_dims(t):
        return t.unsqueeze(-1).expand(*t.shape, 3)

    return tensor_tree_map(add_recycling_dims, batch)
|
||
|
||
def inference_model(rank, world_size, result_q, batch, args):
    """Per-GPU worker: set up Dynamic Axial Parallelism, build the model,
    run one forward pass over ``batch``, and push numpy outputs to
    ``result_q``.

    Args:
        rank: this worker's process rank (also its CUDA device index).
        world_size: total number of workers (the DAP size).
        result_q: multiprocessing queue receiving the output dict.
        batch: input feature dict of tensors/arrays (moved to GPU here).
        args: parsed CLI namespace (model_name, chunk_size, inplace).
    """
    # Distributed init reads these environment variables.
    for var, value in (("RANK", rank), ("LOCAL_RANK", rank), ("WORLD_SIZE", world_size)):
        os.environ[var] = str(value)
    # init distributed for Dynamic Axial Parallelism
    fastfold.distributed.init_dap()
    torch.cuda.set_device(rank)

    cfg = model_config(args.model_name)
    if args.chunk_size:
        cfg.globals.chunk_size = args.chunk_size
    cfg.globals.inplace = args.inplace
    cfg.globals.is_multimer = False

    # Build AlphaFold and swap in FastFold's optimized modules.
    net = AlphaFold(cfg)
    net = inject_fastnn(net)
    net = net.eval().cuda()

    set_chunk_size(net.globals.chunk_size)

    with torch.no_grad():
        device_batch = {key: torch.as_tensor(val).cuda() for key, val in batch.items()}
        t = time.perf_counter()
        out = net(device_batch)
        print(f"Inference time: {time.perf_counter() - t}")
        out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

    result_q.put(out)

    torch.distributed.barrier()
    torch.cuda.synchronize()
|
||
|
||
def inference_monomer_model(args):
    """Spawn ``args.gpus`` worker processes over a shared random batch and
    block until the first worker's prediction arrives."""
    batch = generate_batch(args.n_res)

    # Keep the Manager bound: its process owns the queue for the workers.
    manager = mp.Manager()
    result_q = manager.Queue()
    mp.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
    out = result_q.get()

    # To save an unrelaxed PDB from the prediction, post-process ``out``:
    # batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
    # plddt = out["plddt"]
    # plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
    # unrelaxed_protein = protein.from_prediction(features=batch,
    #                                             result=out,
    #                                             b_factors=plddt_b_factors)
    # with open('demo_unrelaxed.pdb', 'w+') as fp:
    #     fp.write(unrelaxed_protein)
|
||
|
||
def main(args):
    """Entry point: run the monomer inference pipeline with parsed CLI args."""
    inference_monomer_model(args)
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--gpus", type=int, default=1, help="""Number of GPUs with which to run inference""") | ||
parser.add_argument("--n_res", type=int, default=50, help="virtual residue number of random data") | ||
parser.add_argument("--model_name", type=str, default="model_1", help="model name of alphafold") | ||
parser.add_argument('--chunk_size', type=int, default=None) | ||
parser.add_argument('--inplace', default=False, action='store_true') | ||
|
||
args = parser.parse_args() | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/usr/bin/env bash
# CI smoke test: build FastFold from source and run the example inference
# script with its default (small) settings.
#
# `set -o pipefail` is a bashism, so the shebang above is required — the
# script fails under plain POSIX sh without it.
set -euxo pipefail

git clone https://github.com/hpcaitech/FastFold
cd FastFold
pip install -r requirements/requirements.txt
python setup.py install
pip install -r requirements/test_requirements.txt
cd ..

python inference.py