forked from sokrypton/ColabFold
-
Notifications
You must be signed in to change notification settings - Fork 3
/
download.py
54 lines (44 loc) · 1.94 KB
/
download.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import logging
import tarfile
from pathlib import Path
import appdirs
import tqdm
logger = logging.getLogger(__name__)

# alphafold always appends "params" to whatever data dir it is handed, which is
# why the location logic switches between a path with "params" and one without.
# (Patching alphafold would probably be the cleaner fix.)
default_data_dir = Path(
    appdirs.user_cache_dir(__package__ if __package__ else "colabfold")
)
def download_alphafold_params(model_type: str, data_dir: Path = default_data_dir):
    """Download and extract the AlphaFold2 weights selected by ``model_type``.

    The tar archive is streamed over HTTP and unpacked into
    ``<data_dir>/params``. After a successful extraction a per-archive marker
    file is created there; if that marker already exists the function returns
    immediately without downloading anything.

    :param model_type: ``"AlphaFold2-multimer-v2"`` or
        ``"AlphaFold2-multimer-v1"`` select the matching multimer archive;
        any other value selects the 2021-07-14 archive.
    :param data_dir: Base data directory (``Path`` or ``str``). The weights
        are extracted into its ``params`` subdirectory, because alphafold
        always joins ``"params"`` onto the data dir itself.
    :raises requests.HTTPError: If the download URL answers with an error status.
    """
    import requests

    # model_type -> (archive URL, success-marker file name)
    multimer_archives = {
        "AlphaFold2-multimer-v2": (
            "https://storage.googleapis.com/alphafold/alphafold_params_colab_2022-03-02.tar",
            "download_complexes_multimer-v2_finished.txt",
        ),
        "AlphaFold2-multimer-v1": (
            "https://storage.googleapis.com/alphafold/alphafold_params_colab_2021-10-27.tar",
            "download_complexes_multimer-v1_finished.txt",
        ),
    }
    # Every other model_type falls back to the 2021-07-14 archive.
    url, marker_name = multimer_archives.get(
        model_type,
        (
            "https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar",
            "download_finished.txt",
        ),
    )

    # Accept plain strings as well as Path objects.
    data_dir = Path(data_dir)
    params_dir = data_dir.joinpath("params")
    success_marker = params_dir.joinpath(marker_name)
    if success_marker.is_file():
        return

    params_dir.mkdir(parents=True, exist_ok=True)
    # The context managers release the HTTP connection and the tar stream
    # even if the download or the extraction fails partway through.
    with requests.get(url, stream=True) as response:
        # Fail loudly on 4xx/5xx instead of trying to untar an error page.
        response.raise_for_status()
        file_size = int(response.headers.get("Content-Length", 0))
        with tqdm.tqdm.wrapattr(
            response.raw,
            "read",
            total=file_size,
            desc=f"Downloading alphafold2 weights to {data_dir}",
        ) as response_raw:
            # Open in stream mode ("r|"), as the raw response doesn't support seeking.
            # NOTE(review): extractall trusts member paths inside the archive;
            # on Python >= 3.12 consider passing filter="data" for hardening.
            with tarfile.open(fileobj=response_raw, mode="r|") as archive:
                archive.extractall(path=params_dir)
    # Mark success only after a complete extraction, so an interrupted
    # download is retried on the next call.
    success_marker.touch()
if __name__ == "__main__":
    # Command-line entry point: choose which weights to fetch. With no
    # arguments the same two archives as before are downloaded.
    import argparse

    parser = argparse.ArgumentParser(
        description="Download AlphaFold2 model parameters."
    )
    parser.add_argument(
        "model_types",
        nargs="*",
        default=["AlphaFold2-multimer-v2", "AlphaFold2-ptm"],
        help="Model types whose weights to download "
        "(default: AlphaFold2-multimer-v2 AlphaFold2-ptm).",
    )
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=default_data_dir,
        help="Base data directory; weights are extracted into its 'params' subdirectory.",
    )
    args = parser.parse_args()
    for model_type in args.model_types:
        download_alphafold_params(model_type, args.data_dir)