forked from google-deepmind/alphafold3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfetch_databases.py
127 lines (107 loc) · 4.15 KB
/
fetch_databases.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Downloads the AlphaFold v3.0 databases from GCS and decompresses them.
Curl is used to download the files and Zstandard (zstd) is used to decompress
them. Make sure both are installed on your system before running this script.
"""
import argparse
import concurrent.futures
import functools
import os
import subprocess
import sys
# Zstandard-compressed database files hosted in the public GCS bucket; each is
# downloaded and decompressed into the destination directory.
DATABASE_FILES = (
    'bfd-first_non_consensus_sequences.fasta.zst',
    'mgy_clusters_2022_05.fa.zst',
    'nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta.zst',
    'pdb_2022_09_28_mmcif_files.tar.zst',
    'pdb_seqres_2022_09_28.fasta.zst',
    'rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta.zst',
    'rnacentral_active_seq_id_90_cov_80_linclust.fasta.zst',
    'uniprot_all_2021_04.fa.zst',
    'uniref90_2022_05.fa.zst',
)
# Base URL of the GCS bucket serving the AlphaFold v3.0 databases.
BUCKET_PATH = 'https://storage.googleapis.com/alphafold-databases/v3.0'
def download_and_decompress(
    filename: str, *, bucket_path: str, download_destination: str
) -> None:
  """Downloads and decompresses a single zstd-compressed database file.

  Args:
    filename: Name of the `.zst` file to fetch from the bucket.
    bucket_path: Base URL of the bucket hosting the database files.
    download_destination: Local directory the file is downloaded into.

  Raises:
    subprocess.CalledProcessError: If curl or zstd exits with a non-zero
      status (both are run with check=True).
  """
  print(
      f'STARTING download {filename} from {bucket_path} to'
      f' {download_destination}'
  )
  # `--continue-at -` resumes a partially downloaded file for resumability.
  # --progress-bar is used to show some progress in the terminal.
  # stderr is redirected to stdout so the \r-based progress-bar updates,
  # which can be confusing when multiple calls are made at once, are
  # captured rather than interleaved on the terminal.
  subprocess.run(
      args=(
          'curl',
          '--progress-bar',
          *('--continue-at', '-'),
          *('--output', f'{download_destination}/{filename}'),
          f'{bucket_path}/{filename}',
          *('--stderr', '/dev/stdout'),
      ),
      check=True,
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE,
      # Same as text=True in Python 3.7+, used for backwards compatibility.
      universal_newlines=True,
  )
  print(
      f'FINISHED downloading {filename} from {bucket_path} to'
      f' {download_destination}.'
  )
  print(f'STARTING decompressing of {filename}')
  # The original compressed file is kept so that if it is interrupted it can
  # be resumed, skipping the need to download the file again.
  subprocess.run(
      ['zstd', '--decompress', '--force', f'{download_destination}/{filename}'],
      check=True,
  )
  print(f'FINISHED decompressing of {filename}')
def main(argv=('',)) -> None:
  """Parses flags, then downloads, decompresses, and cleans up all databases.

  Args:
    argv: Command-line arguments (without the program name), passed straight
      to argparse.

  Raises:
    ValueError: If writing to /srv is requested without root privileges.
    subprocess.CalledProcessError: If any curl/zstd invocation fails.
  """
  parser = argparse.ArgumentParser(description='Downloads AlphaFold databases.')
  parser.add_argument(
      '--download_destination',
      default='/srv/alphafold3_data/public_databases',
      help='The directory to download the databases to.',
  )
  args = parser.parse_args(argv)
  # Fail early with a clear message rather than with a permission error
  # halfway through a multi-gigabyte download.
  if os.geteuid() != 0 and args.download_destination.startswith('/srv'):
    raise ValueError(
        'You must run this script as root to be able to write to /srv.'
    )
  destination = os.path.expanduser(args.download_destination)
  print(f'Downloading all data to: {destination}')
  os.makedirs(destination, exist_ok=True)
  # Download each of the files and decompress them in parallel.
  with concurrent.futures.ThreadPoolExecutor(
      max_workers=len(DATABASE_FILES)
  ) as pool:
    # list() fully consumes the map iterator so that any exception raised in
    # a worker thread is re-raised here instead of being silently dropped.
    list(
        pool.map(
            functools.partial(
                download_and_decompress,
                bucket_path=BUCKET_PATH,
                download_destination=destination,
            ),
            DATABASE_FILES,
        )
    )
  # Delete all zstd files at the end (after successfully decompressing them).
  # Use the expanded `destination` path — the same one the files were written
  # to — so cleanup also works when the flag contains '~'.
  for filename in DATABASE_FILES:
    os.remove(f'{destination}/{filename}')
  print('All databases have been downloaded and decompressed.')
if __name__ == '__main__':
  # Drop argv[0] (the script path); argparse expects only the flags.
  main(sys.argv[1:])