forked from Layne-Huang/PMDM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeometry.py
59 lines (49 loc) · 2.05 KB
/
geometry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from torch_scatter import scatter_add
def get_distance(pos, edge_index):
    """Return the Euclidean length of every edge.

    Args:
        pos: node coordinates, indexed along dim 0 by ``edge_index``.
        edge_index: index pair whose row 0 holds edge sources and row 1 edge
            targets (presumably a (2, E) LongTensor — TODO confirm with callers).

    Returns:
        Per-edge distances; the norm is taken over the last dimension.
    """
    src, dst = edge_index[0], edge_index[1]
    displacement = pos[src] - pos[dst]
    return torch.norm(displacement, dim=-1)
def eq_transform(score_d, pos, edge_index, edge_length):
    """Turn per-edge distance scores into per-node position scores.

    Each edge (src, dst) contributes ``score_d * (pos[src] - pos[dst]) / edge_length``
    to node ``src`` and the negated vector to node ``dst``. The result is
    therefore built from relative displacements only.

    Args:
        score_d: per-edge scores, multiplied elementwise with the (E, 3)
            displacement vectors (presumably shape (E, 1) — TODO confirm).
        pos: (N, 3) node coordinates; ``pos.size(0)`` fixes the output length.
        edge_index: row 0 = source node ids, row 1 = target node ids.
        edge_length: per-edge lengths, assumed to be the current edge lengths
            (e.g. from ``get_distance``) in a shape broadcastable against (E, 3).

    Returns:
        (N, 3) tensor of accumulated per-node scores.
    """
    num_nodes = pos.size(0)
    src, dst = edge_index[0], edge_index[1]

    inv_length = 1. / edge_length
    direction = inv_length * (pos[src] - pos[dst])  # (E, 3)
    weighted = direction * score_d

    # Equal and opposite contributions to the two endpoints of every edge.
    to_src = scatter_add(weighted, src, dim=0, dim_size=num_nodes)
    to_dst = scatter_add(-weighted, dst, dim=0, dim_size=num_nodes)
    return to_src + to_dst  # (N, 3)
def convert_cluster_score_d(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length, subgraph_index):
    """Map edge scores on a cluster graph to position scores for the original nodes.

    The cluster-level edge scores are first converted to per-cluster position
    scores with ``eq_transform``. The rows selected by ``subgraph_index`` are
    then gathered, presumably one cluster id per original node.

    Args:
        cluster_score_d: (E_c, 1) per-edge scores on the cluster graph.
        cluster_pos: cluster coordinates passed to ``eq_transform``.
        cluster_edge_index: cluster-graph edge index passed to ``eq_transform``.
        cluster_edge_length: cluster-graph edge lengths passed to ``eq_transform``.
        subgraph_index: (N, ) row indices into the per-cluster result.

    Returns:
        The per-cluster scores indexed by ``subgraph_index``.
    """
    per_cluster = eq_transform(
        cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length
    )  # (C, 3)
    return per_cluster[subgraph_index]
def get_angle(pos, angle_index):
    """Compute the angle at the centre node of each (left, centre, right) triple.

    Args:
        pos: (N, 3) node coordinates.
        angle_index: (3, A) index tensor; rows are left, center, right.

    Returns:
        (A, 1) tensor of angles in radians, in [0, pi].

    Notes:
        The cosine is clamped to [-1, 1] before ``acos``. Without the clamp,
        rounding error on (anti)parallel arms can push the ratio slightly past
        +/-1, and ``acos`` then returns NaN. A zero-length arm (coincident
        nodes) makes the cosine 0/0 and is not handled here.
    """
    n1, ctr, n2 = angle_index  # each (A, )
    v1 = pos[n1] - pos[ctr]  # (A, 3)
    v2 = pos[n2] - pos[ctr]
    inner_prod = torch.sum(v1 * v2, dim=-1, keepdim=True)  # (A, 1)
    length_prod = torch.norm(v1, dim=-1, keepdim=True) * torch.norm(v2, dim=-1, keepdim=True)  # (A, 1)
    cos_angle = (inner_prod / length_prod).clamp(-1.0, 1.0)
    return torch.acos(cos_angle)  # (A, 1)
def get_dihedral(pos, dihedral_index):
    """Compute the unsigned dihedral angle of each 4-node chain.

    The dihedral is the angle between the plane (n1, ctr1, ctr2) and the
    plane (ctr1, ctr2, n2). Each plane is represented by its normal, which is
    the cross product with the central axis ctr1 -> ctr2.

    Args:
        pos: (N, 3) node coordinates.
        dihedral_index: (4, A) index tensor; rows are n1, ctr1, ctr2, n2.

    Returns:
        (A, 1) tensor of angles in radians, in [0, pi]. The value comes from
        ``acos`` of the normals' cosine, so the sign (handedness) of the
        torsion is not recovered.

    Notes:
        The cosine is clamped to [-1, 1] so that rounding error on coplanar
        configurations does not yield NaN. A degenerate chain where an arm is
        collinear with the central axis gives a zero normal and a 0/0 cosine;
        that case is not handled here.
    """
    idx_n1, idx_c1, idx_c2, idx_n2 = dihedral_index  # each (A, )
    v_ctr = pos[idx_c2] - pos[idx_c1]  # (A, 3)
    v1 = pos[idx_n1] - pos[idx_c1]
    v2 = pos[idx_n2] - pos[idx_c2]
    # Normal vectors of the two planes.
    normal1 = torch.cross(v_ctr, v1, dim=-1)
    normal2 = torch.cross(v_ctr, v2, dim=-1)
    inner_prod = torch.sum(normal1 * normal2, dim=-1, keepdim=True)  # (A, 1)
    length_prod = torch.norm(normal1, dim=-1, keepdim=True) * torch.norm(normal2, dim=-1, keepdim=True)
    cos_dihedral = (inner_prod / length_prod).clamp(-1.0, 1.0)
    return torch.acos(cos_dihedral)  # (A, 1)