-
Notifications
You must be signed in to change notification settings - Fork 272
/
sampling.py
269 lines (220 loc) · 15.4 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import copy
import random
import numpy as np
import torch
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch
from utils.torsion import modify_conformer_torsion_angles
from scipy.spatial.transform import Rotation as R
from utils.utils import crop_beyond
from utils.logging_utils import get_logger
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7,
                       initial_noise_std_proportion=-1.0, choose_residue=False):
    # In-place modification of the list: randomizes each complex's ligand
    # torsion angles, orientation and position before reverse diffusion starts.
    #
    # Args:
    #     data_list: list of complex graphs; each graph's ['ligand'].pos is
    #         overwritten in place.
    #     no_torsion: if True, skip randomizing the torsion angles.
    #     no_random: if True, skip the random translation (the ligand is only
    #         re-centered on the pocket/receptor center after a random rotation).
    #     tr_sigma_max: maximum translation noise scale; used when
    #         initial_noise_std_proportion < 0.
    #     pocket_knowledge: if True, center on the receptor residues within
    #         pocket_cutoff of the ligand's original position instead of the
    #         whole-receptor mean.
    #     pocket_cutoff: distance threshold for pocket residues (presumably in
    #         Angstroms -- TODO confirm units with the data pipeline).
    #     initial_noise_std_proportion: if >= 0, the translation std is this
    #         fraction of the receptor's RMS radius (divided by ~sqrt(3) for a
    #         per-axis std); if < 0, the std is
    #         -initial_noise_std_proportion * tr_sigma_max.
    #     choose_residue: if True, place the ligand near a randomly chosen
    #         receptor residue instead of adding isotropic noise.

    # Default placement target: mean of the first complex's receptor positions.
    center_pocket = data_list[0]['receptor'].pos.mean(dim=0)
    if pocket_knowledge:
        # NOTE(review): the pocket is derived from data_list[0] only and reused
        # for every complex -- assumes all entries share the same receptor and
        # ligand (typical when data_list holds N samples of one complex).
        complex = data_list[0]
        # Distances between each receptor position and each original ligand
        # atom position (shifted into the centered coordinate frame).
        d = torch.cdist(complex['receptor'].pos, torch.from_numpy(complex['ligand'].orig_pos[0]).float() - complex.original_center)
        label = torch.any(d < pocket_cutoff, dim=1)
        if torch.any(label):
            center_pocket = complex['receptor'].pos[label].mean(dim=0)
        else:
            # No residue within the cutoff: fall back to the single closest one.
            print("No pocket residue below minimum distance ", pocket_cutoff, "taking closest at", torch.min(d))
            center_pocket = complex['receptor'].pos[torch.argmin(torch.min(d, dim=1)[0])]

    if not no_torsion:
        # randomize torsion angles: one uniform angle in [-pi, pi) per rotatable bond
        for complex_graph in data_list:
            torsion_updates = np.random.uniform(low=-np.pi, high=np.pi, size=complex_graph['ligand'].edge_mask.sum())
            complex_graph['ligand'].pos = \
                modify_conformer_torsion_angles(complex_graph['ligand'].pos,
                                                complex_graph['ligand', 'ligand'].edge_index.T[
                                                    complex_graph['ligand'].edge_mask],
                                                complex_graph['ligand'].mask_rotate[0], torsion_updates)

    for complex_graph in data_list:
        # randomize position: apply a uniform random rotation about the ligand
        # center, then translate the ligand onto the pocket/receptor center
        molecule_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
        random_rotation = torch.from_numpy(R.random().as_matrix()).float()
        complex_graph['ligand'].pos = (complex_graph['ligand'].pos - molecule_center) @ random_rotation.T + center_pocket
        # base_rmsd = np.sqrt(np.sum((complex_graph['ligand'].pos.cpu().numpy() - orig_complex_graph['ligand'].pos.numpy()) ** 2, axis=1).mean())

        if not no_random:  # note for now the torsion angles are still randomised
            if choose_residue:
                # Tight Gaussian around a uniformly chosen receptor residue.
                idx = random.randint(0, len(complex_graph['receptor'].pos)-1)
                tr_update = torch.normal(mean=complex_graph['receptor'].pos[idx:idx+1], std=0.01)
            elif initial_noise_std_proportion >= 0.0:
                # Std proportional to the receptor's RMS distance from the
                # origin; 1.73 ~ sqrt(3) converts that radial scale into a
                # per-coordinate standard deviation.
                std_rec = torch.sqrt(torch.mean(torch.sum(complex_graph['receptor'].pos ** 2, dim=1)))
                tr_update = torch.normal(mean=0, std=std_rec * initial_noise_std_proportion / 1.73, size=(1, 3))
            else:
                # if initial_noise_std_proportion < 0.0, we use the tr_sigma_max multiplied by -initial_noise_std_proportion
                tr_update = torch.normal(mean=0, std=-initial_noise_std_proportion * tr_sigma_max, size=(1, 3))
            complex_graph['ligand'].pos += tr_update
def is_iterable(arr):
    """Return True if ``arr`` is iterable, i.e. ``iter(arr)`` does not raise.

    Uses the EAFP idiom; note that strings and other single-value sequences
    count as iterable.
    """
    try:
        iter(arr)
        return True
    except TypeError:
        return False
def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_schedule, device, t_to_sigma, model_args,
             no_random=False, ode=False, visualization_list=None, confidence_model=None, confidence_data_list=None, confidence_model_args=None,
             t_schedule=None, batch_size=32, no_final_step_noise=False, pivot=None, return_full_trajectory=False,
             temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5, return_features=False):
    """Run reverse diffusion over ligand translations, rotations and torsions.

    Iteratively denoises every complex in ``data_list`` for ``inference_steps``
    steps using the score model's predicted (tr, rot, tor) scores, optionally
    scoring the final poses with ``confidence_model``.

    Args:
        data_list: list of complex graphs; ligand positions are updated in place.
        model: score model returning (tr_score, rot_score, tor_score, ...).
        inference_steps: number of reverse-diffusion steps.
        tr_schedule / rot_schedule / tor_schedule: per-step diffusion times.
        device: torch device for model evaluation.
        t_to_sigma: callable mapping (t_tr, t_rot, t_tor) -> sigmas.
        model_args: namespace with the score model's training hyperparameters.
        no_random: if True, use zero noise (deterministic updates).
        ode: if True, use the probability-flow ODE update instead of SDE.
        visualization_list: optional per-complex visualizers to record frames.
        confidence_model / confidence_data_list / confidence_model_args:
            optional confidence scoring of the final poses.
        t_schedule: optional shared time schedule passed to set_time.
        batch_size: must be >= len(data_list) (single-batch inference only).
        no_final_step_noise: if True, omit noise on the last step.
        pivot, return_full_trajectory, return_features: not supported in this
            inference version (asserted off).
        temp_sampling / temp_psi / temp_sigma_data: low-temperature sampling
            controls; either scalars (applied to all three groups) or
            3-sequences ordered (tr, rot, tor).

    Returns:
        (data_list, confidence) where confidence is None when no confidence
        model is given, else a tensor with NaNs replaced by -1000.
    """
    N = len(data_list)
    trajectory = []
    logger = get_logger()
    if return_features:
        lig_features, rec_features = [], []
    assert batch_size >= N, "Not implemented yet"
    loader = DataLoader(data_list, batch_size=batch_size)
    assert not (return_full_trajectory or return_features or pivot), "Not implemented yet in new inference version"
    mask_rotate = torch.from_numpy(data_list[0]['ligand'].mask_rotate[0]).to(device)

    # Normalise the low-temperature controls to (tr, rot, tor) triples once,
    # up front. The original code re-ran these idempotent checks on every
    # diffusion step and duplicated the temp_sampling/temp_psi conversions.
    if not is_iterable(temp_sampling):
        temp_sampling = [temp_sampling] * 3
    if not is_iterable(temp_psi):
        temp_psi = [temp_psi] * 3
    if not is_iterable(temp_sigma_data):
        temp_sigma_data = [temp_sigma_data] * 3
    assert len(temp_sampling) == 3
    assert len(temp_psi) == 3
    assert len(temp_sigma_data) == 3

    confidence = None
    if confidence_model is not None:
        confidence_loader = iter(DataLoader(confidence_data_list, batch_size=batch_size))
        confidence = []

    with torch.no_grad():
        for batch_id, complex_graph_batch in enumerate(loader):
            b = complex_graph_batch.num_graphs
            # Number of ligand atoms per complex (all complexes in a batch are
            # assumed to have the same ligand size).
            n = len(complex_graph_batch['ligand'].pos) // b
            complex_graph_batch = complex_graph_batch.to(device)

            for t_idx in range(inference_steps):
                t_tr, t_rot, t_tor = tr_schedule[t_idx], rot_schedule[t_idx], tor_schedule[t_idx]
                # Step sizes; the last step consumes the remaining time.
                dt_tr = tr_schedule[t_idx] - tr_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tr_schedule[t_idx]
                dt_rot = rot_schedule[t_idx] - rot_schedule[t_idx + 1] if t_idx < inference_steps - 1 else rot_schedule[t_idx]
                dt_tor = tor_schedule[t_idx] - tor_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tor_schedule[t_idx]
                tr_sigma, rot_sigma, tor_sigma = t_to_sigma(t_tr, t_rot, t_tor)

                if hasattr(model_args, 'crop_beyond') and model_args.crop_beyond is not None:
                    # Crop receptor context beyond the current noise radius so the
                    # score model only sees relevant surroundings.
                    mod_complex_graph_batch = copy.deepcopy(complex_graph_batch).to_data_list()
                    for batch in mod_complex_graph_batch:
                        crop_beyond(batch, tr_sigma * 3 + model_args.crop_beyond, model_args.all_atoms)
                    mod_complex_graph_batch = Batch.from_data_list(mod_complex_graph_batch)
                else:
                    mod_complex_graph_batch = complex_graph_batch

                set_time(mod_complex_graph_batch, t_schedule[t_idx] if t_schedule is not None else None, t_tr, t_rot, t_tor, b,
                         'all_atoms' in model_args and model_args.all_atoms, device)
                tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3]

                # Guard against NaN/inf scores: replace them with a small value
                # (1% of the mean score magnitude) so the pose is only slightly
                # disturbed and may recover on the next iteration.
                mean_scores = torch.mean(tr_score, dim=-1)
                num_nans = torch.sum(torch.isnan(mean_scores))
                if num_nans > 0:
                    name = complex_graph_batch['name']
                    if isinstance(name, list):
                        name = name[0]
                    logger.warning(f"Complex {name} Batch {batch_id+1} Inference Iteration {t_idx}: "
                                   f"{num_nans} / {mean_scores.numel()} samples failed")
                    # Set the nan values to a small value, just want to disturb slightly
                    # Hopefully won't get nan the next iteration
                    tr_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tr_score.abs())), posinf=eps, neginf=-eps)
                    rot_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(rot_score.abs())), posinf=eps, neginf=-eps)
                    tor_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tor_score.abs())), posinf=eps, neginf=-eps)
                    del eps

                # Diffusion coefficients g(sigma) for the VE-SDE of each group.
                tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
                rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))

                if ode:
                    # Probability-flow ODE: deterministic half-strength drift.
                    # NOTE(review): the low-temperature branches below reference
                    # tr_z/rot_z/tor_z, which are only defined in the SDE branch,
                    # so ode=True with temp_sampling != 1.0 would raise NameError.
                    tr_perturb = (0.5 * tr_g ** 2 * dt_tr * tr_score)
                    rot_perturb = (0.5 * rot_score * dt_rot * rot_g ** 2)
                else:
                    tr_z = torch.zeros((min(batch_size, N), 3), device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
                        else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
                    tr_perturb = (tr_g ** 2 * dt_tr * tr_score + tr_g * np.sqrt(dt_tr) * tr_z)

                    rot_z = torch.zeros((min(batch_size, N), 3), device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
                        else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
                    rot_perturb = (rot_score * dt_rot * rot_g ** 2 + rot_g * np.sqrt(dt_rot) * rot_z)

                if not model_args.no_torsion:
                    tor_g = tor_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tor_sigma_max / model_args.tor_sigma_min)))
                    if ode:
                        tor_perturb = (0.5 * tor_g ** 2 * dt_tor * tor_score)
                    else:
                        tor_z = torch.zeros(tor_score.shape, device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
                            else torch.normal(mean=0, std=1, size=tor_score.shape, device=device)
                        tor_perturb = (tor_g ** 2 * dt_tor * tor_score + tor_g * np.sqrt(dt_tor) * tor_z)
                else:
                    tor_perturb = None

                # Low-temperature sampling: rescale the drift per group and
                # inflate the noise by (1 + psi) when the temperature != 1.0.
                if temp_sampling[0] != 1.0:
                    tr_sigma_data = np.exp(temp_sigma_data[0] * np.log(model_args.tr_sigma_max) + (1 - temp_sigma_data[0]) * np.log(model_args.tr_sigma_min))
                    lambda_tr = (tr_sigma_data + tr_sigma) / (tr_sigma_data + tr_sigma / temp_sampling[0])
                    tr_perturb = (tr_g ** 2 * dt_tr * (lambda_tr + temp_sampling[0] * temp_psi[0] / 2) * tr_score + tr_g * np.sqrt(dt_tr * (1 + temp_psi[0])) * tr_z)

                if temp_sampling[1] != 1.0:
                    rot_sigma_data = np.exp(temp_sigma_data[1] * np.log(model_args.rot_sigma_max) + (1 - temp_sigma_data[1]) * np.log(model_args.rot_sigma_min))
                    lambda_rot = (rot_sigma_data + rot_sigma) / (rot_sigma_data + rot_sigma / temp_sampling[1])
                    rot_perturb = (rot_g ** 2 * dt_rot * (lambda_rot + temp_sampling[1] * temp_psi[1] / 2) * rot_score + rot_g * np.sqrt(dt_rot * (1 + temp_psi[1])) * rot_z)

                if temp_sampling[2] != 1.0:
                    tor_sigma_data = np.exp(temp_sigma_data[2] * np.log(model_args.tor_sigma_max) + (1 - temp_sigma_data[2]) * np.log(model_args.tor_sigma_min))
                    lambda_tor = (tor_sigma_data + tor_sigma) / (tor_sigma_data + tor_sigma / temp_sampling[2])
                    tor_perturb = (tor_g ** 2 * dt_tor * (lambda_tor + temp_sampling[2] * temp_psi[2] / 2) * tor_score + tor_g * np.sqrt(dt_tor * (1 + temp_psi[2])) * tor_z)

                # Apply noise
                complex_graph_batch['ligand'].pos = \
                    modify_conformer_batch(complex_graph_batch['ligand'].pos, complex_graph_batch, tr_perturb, rot_perturb,
                                           tor_perturb if not model_args.no_torsion else None, mask_rotate)

                if visualization_list is not None:
                    for idx_b in range(b):
                        visualization_list[batch_id * batch_size + idx_b].add((
                            complex_graph_batch['ligand'].pos[idx_b*n:n*(idx_b+1)].detach().cpu() +
                            data_list[batch_id * batch_size + idx_b].original_center.detach().cpu()),
                            part=1, order=t_idx + 2)

            # Write the final positions back into the per-complex graphs.
            for i in range(b):
                data_list[batch_id * batch_size + i]['ligand'].pos = complex_graph_batch['ligand'].pos[i*n:n*(i+1)]

            if visualization_list is not None:
                for idx, visualization in enumerate(visualization_list):
                    visualization.add((data_list[idx]['ligand'].pos.detach().cpu() + data_list[idx].original_center.detach().cpu()),
                                      part=1, order=2)

            if confidence_model is not None:
                if confidence_data_list is not None:
                    # Confidence model may use a different featurisation; copy
                    # over the final ligand positions before scoring.
                    confidence_complex_graph_batch = next(confidence_loader)
                    confidence_complex_graph_batch['ligand'].pos = complex_graph_batch['ligand'].pos.cpu()
                    if hasattr(confidence_model_args, 'crop_beyond') and confidence_model_args.crop_beyond is not None:
                        confidence_complex_graph_batch = confidence_complex_graph_batch.to_data_list()
                        for batch in confidence_complex_graph_batch:
                            crop_beyond(batch, confidence_model_args.crop_beyond, confidence_model_args.all_atoms)
                        confidence_complex_graph_batch = Batch.from_data_list(confidence_complex_graph_batch)
                    confidence_complex_graph_batch = confidence_complex_graph_batch.to(device)
                    set_time(confidence_complex_graph_batch, 0, 0, 0, 0, b, confidence_model_args.all_atoms, device)
                    out = confidence_model(confidence_complex_graph_batch)
                else:
                    out = confidence_model(complex_graph_batch)
                if type(out) is tuple:
                    out = out[0]
                confidence.append(out)

    if confidence_model is not None:
        confidence = torch.cat(confidence, dim=0)
        # Failed (NaN) confidence predictions rank last.
        confidence = torch.nan_to_num(confidence, nan=-1000)

    if return_full_trajectory:
        return data_list, confidence, trajectory
    elif return_features:
        lig_features = torch.cat(lig_features, dim=0)
        rec_features = torch.cat(rec_features, dim=0)
        return data_list, confidence, lig_features, rec_features
    return data_list, confidence
def compute_affinity(data_list, affinity_model, affinity_data_list, device, parallel, all_atoms, include_miscellaneous_atoms):
    """Predict binding affinity for ``parallel`` sampled poses of one complex.

    Takes the first ``parallel`` complexes from ``data_list`` (the sampled
    poses), tiles the ligand of the affinity template graph
    ``affinity_data_list[0]`` ``parallel`` times, substitutes the sampled
    positions, and runs ``affinity_model`` on the assembled batch-of-one graph.

    Args:
        data_list: list of sampled complex graphs (length >= parallel).
        affinity_model: model returning (_, affinity); if None, returns None.
        affinity_data_list: list whose first entry is the affinity-featurised
            template graph for the complex.
        device: torch device for inference.
        parallel: number of sampled poses to score jointly.
        all_atoms, include_miscellaneous_atoms: flags forwarded to set_time.

    Returns:
        The affinity prediction tensor, or None when no model is given.
    """
    with torch.no_grad():
        if affinity_model is not None:
            assert parallel <= len(data_list)
            loader = DataLoader(data_list, batch_size=parallel)
            complex_graph_batch = next(iter(loader)).to(device)
            positions = complex_graph_batch['ligand'].pos
            assert affinity_data_list is not None
            # Work on a copy: the previous version mutated affinity_data_list[0]
            # in place, so a second call would tile the already-tiled ligand
            # features again (parallel**2 copies). deepcopy keeps the template
            # reusable across calls.
            complex_graph = copy.deepcopy(affinity_data_list[0])
            N = complex_graph['ligand'].num_nodes
            # Tile ligand node features, edge mask, edges (with node-index
            # offsets) and edge attributes once per sampled pose.
            complex_graph['ligand'].x = complex_graph['ligand'].x.repeat(parallel, 1)
            complex_graph['ligand'].edge_mask = complex_graph['ligand'].edge_mask.repeat(parallel)
            complex_graph['ligand', 'ligand'].edge_index = torch.cat(
                [N * i + complex_graph['ligand', 'ligand'].edge_index for i in range(parallel)], dim=1)
            complex_graph['ligand', 'ligand'].edge_attr = complex_graph['ligand', 'ligand'].edge_attr.repeat(parallel, 1)
            # Substitute the sampled poses for the template positions.
            complex_graph['ligand'].pos = positions
            affinity_loader = DataLoader([complex_graph], batch_size=1)
            affinity_batch = next(iter(affinity_loader)).to(device)
            set_time(affinity_batch, 0, 0, 0, 0, 1, all_atoms, device, include_miscellaneous_atoms=include_miscellaneous_atoms)
            _, affinity = affinity_model(affinity_batch)
        else:
            affinity = None
    return affinity