-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgraph_construction.py
27 lines (26 loc) · 1.01 KB
/
graph_construction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import torch
from scipy.spatial import distance
def calcADJ(coord, k=4, distanceType='euclidean', pruneTag='NA', gridThreshold=2.0):
    """Build a directed k-nearest-neighbour adjacency matrix from coordinates.

    For every point ``i`` the ``k`` closest *other* points (measured with
    ``distanceType``) are candidate neighbours. ``adj[i, j]`` is set to 1.0 for
    each candidate ``j`` that survives the pruning rule chosen by ``pruneTag``.
    The matrix is not symmetrised: ``adj[i, j] == 1`` does not imply
    ``adj[j, i] == 1``.

    Args:
        coord: 2-D array-like of shape ``(n_points, n_dims)``. This can be a
            NumPy array, a CPU torch tensor, or a nested sequence.
        k: Number of nearest neighbours per point. ``0`` means "all other
            points". Values larger than ``n_points - 1`` are clamped to it.
        distanceType: Metric name forwarded to ``scipy.spatial.distance.cdist``.
        pruneTag: How candidates are filtered:
            ``'NA'`` keeps every candidate.
            ``'STD'`` keeps candidates whose distance is ``<= mean + std`` of
            that point's candidate distances.
            ``'Grid'`` keeps candidates whose distance is ``<= gridThreshold``.
        gridThreshold: Distance cut-off used when ``pruneTag == 'Grid'``.
            The default is 2.0.

    Returns:
        ``torch.Tensor`` of shape ``(n_points, n_points)`` in torch's default
        dtype. Kept edges are 1.0 and everything else is 0.0. The diagonal is
        always 0.

    Raises:
        ValueError: if ``k`` is negative or ``pruneTag`` is not recognised.

    Ties between equidistant candidates are broken in favour of the lower
    point index.
    """
    if pruneTag not in ('NA', 'STD', 'Grid'):
        raise ValueError(
            f"unknown pruneTag {pruneTag!r}; expected 'NA', 'STD' or 'Grid'"
        )
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    points = np.asarray(coord)
    nodes = points.shape[0]
    adj = torch.zeros((nodes, nodes))

    # k == 0 means "connect to every other point". Never request more
    # neighbours than exist; indexing past the end used to raise IndexError.
    n_neighbors = nodes - 1 if k == 0 else min(k, nodes - 1)
    if n_neighbors <= 0:
        # Zero or one point, so there are no edges. Returning early also
        # avoids mean/std of an empty array.
        return adj

    for i in range(nodes):
        dist = distance.cdist(points[i:i + 1], points, distanceType)[0]
        # Exclude the point itself explicitly rather than assuming it sorts
        # first. With duplicate coordinates, another point can also sit at
        # distance 0 and displace it.
        dist[i] = np.inf
        # A stable sort makes tie-breaking deterministic (lower index wins).
        candidates = np.argsort(dist, kind='stable')[:n_neighbors]
        cand_dist = dist[candidates]
        if pruneTag == 'STD':
            boundary = np.mean(cand_dist) + np.std(cand_dist)
            candidates = candidates[cand_dist <= boundary]
        elif pruneTag == 'Grid':
            candidates = candidates[cand_dist <= gridThreshold]
        # An explicit long dtype keeps an empty selection a valid index.
        adj[i, torch.as_tensor(candidates, dtype=torch.long)] = 1.0
    return adj