Skip to content

Commit

Permalink
Merge pull request #90 from williamsyy/localmap
Browse files Browse the repository at this point in the history
fix: Fixing the Issue #89
  • Loading branch information
hyhuang00 authored Dec 24, 2024
2 parents 9c2640c + fe9a7be commit 9e6bc47
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions source/pacmap/pacmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,7 @@ def sample_FP_nearby(n_samples, maximum, reject_ind, self_ind, Y, low_dist_thres
result = np.empty(n_samples, dtype=np.int32)
for i in range(n_samples):
reject_sample = True
count = 0
while reject_sample:
j = np.random.randint(maximum)
if j == self_ind:
Expand All @@ -1126,15 +1127,20 @@ def sample_FP_nearby(n_samples, maximum, reject_ind, self_ind, Y, low_dist_thres
for k in range(reject_ind.shape[0]):
if j == reject_ind[k]:
break
if euclid_dist(Y[self_ind], Y[j]) > low_dist_thres:
continue
else:
reject_sample = False
if euclid_dist(Y[self_ind], Y[j]) > low_dist_thres:
continue
else:
reject_sample = False
count += 1
if count > 100:
j = -1
reject_sample = False
result[i] = j
return result

@numba.njit("i4[:,:](f4[:,:],i4[:,:],i4,i4, f4[:,:], f4)", parallel=True, nogil=True, cache=True)
def sample_FP_pair_nearby(X, pair_neighbors, n_neighbors, n_FP, Y, low_dist_thres):
@numba.njit("i4[:,:](f4[:,:],i4[:,:],i4[:,:],i4,i4, f4[:,:], f4)", parallel=True, nogil=True, cache=True)
def sample_FP_pair_nearby(X, pair_neighbors, old_pair_FP, n_neighbors, n_FP, Y, low_dist_thres):
'''Resample Further pairs for local graph adjustment'''
n = X.shape[0]
pair_FP = np.empty((n * n_FP, 2), dtype=np.int32)
Expand All @@ -1149,7 +1155,10 @@ def sample_FP_pair_nearby(X, pair_neighbors, n_neighbors, n_FP, Y, low_dist_thre
)
for k in numba.prange(n_FP):
pair_FP[i*n_FP + k][0] = i
pair_FP[i*n_FP + k][1] = FP_index[k]
if FP_index[k] == -1:
pair_FP[i*n_FP + k][1] = old_pair_FP[i*n_FP + k][1]
else:
pair_FP[i*n_FP + k][1] = FP_index[k]
return pair_FP

@numba.njit("f4[:,:](f4[:,:],i4[:,:],i4[:,:],i4[:,:],f4,f4,f4,f4)", parallel=True, nogil=True, cache=True)
Expand Down Expand Up @@ -1281,7 +1290,7 @@ def localmap(
update_embedding_adam(Y, grad, m, v, beta1, beta2, lr, itr)

if (itr > num_iters[0] + num_iters[1]) and (itr % 10 == 0):
pair_FP = sample_FP_pair_nearby(X, pair_neighbors, n_neighbors, n_FP, Y, low_dist_thres)
pair_FP = sample_FP_pair_nearby(X, pair_neighbors, pair_FP, n_neighbors, n_FP, Y, low_dist_thres)

if intermediate:
if (itr + 1) == inter_snapshots[itr_ind]:
Expand Down Expand Up @@ -1462,7 +1471,4 @@ def fit(self, X, init=None, save_pairs=True):
)
if not save_pairs:
self.del_pairs()
return self



return self

0 comments on commit 9e6bc47

Please sign in to comment.