diff --git a/source/pacmap/pacmap.py b/source/pacmap/pacmap.py index d2ad615..3476026 100644 --- a/source/pacmap/pacmap.py +++ b/source/pacmap/pacmap.py @@ -1115,6 +1115,7 @@ def sample_FP_nearby(n_samples, maximum, reject_ind, self_ind, Y, low_dist_thres result = np.empty(n_samples, dtype=np.int32) for i in range(n_samples): reject_sample = True + count = 0 while reject_sample: j = np.random.randint(maximum) if j == self_ind: @@ -1126,15 +1127,20 @@ def sample_FP_nearby(n_samples, maximum, reject_ind, self_ind, Y, low_dist_thres for k in range(reject_ind.shape[0]): if j == reject_ind[k]: break - if euclid_dist(Y[self_ind], Y[j]) > low_dist_thres: - continue else: - reject_sample = False + if euclid_dist(Y[self_ind], Y[j]) > low_dist_thres: + continue + else: + reject_sample = False + count += 1 + if count > 100: + j = -1 + reject_sample = False result[i] = j return result -@numba.njit("i4[:,:](f4[:,:],i4[:,:],i4,i4, f4[:,:], f4)", parallel=True, nogil=True, cache=True) -def sample_FP_pair_nearby(X, pair_neighbors, n_neighbors, n_FP, Y, low_dist_thres): +@numba.njit("i4[:,:](f4[:,:],i4[:,:],i4[:,:],i4,i4, f4[:,:], f4)", parallel=True, nogil=True, cache=True) +def sample_FP_pair_nearby(X, pair_neighbors, old_pair_FP, n_neighbors, n_FP, Y, low_dist_thres): '''Resample Further pairs for local graph adjustment''' n = X.shape[0] pair_FP = np.empty((n * n_FP, 2), dtype=np.int32) @@ -1149,7 +1155,10 @@ def sample_FP_pair_nearby(X, pair_neighbors, n_neighbors, n_FP, Y, low_dist_thre ) for k in numba.prange(n_FP): pair_FP[i*n_FP + k][0] = i - pair_FP[i*n_FP + k][1] = FP_index[k] + if FP_index[k] == -1: + pair_FP[i*n_FP + k][1] = old_pair_FP[i*n_FP + k][1] + else: + pair_FP[i*n_FP + k][1] = FP_index[k] return pair_FP @numba.njit("f4[:,:](f4[:,:],i4[:,:],i4[:,:],i4[:,:],f4,f4,f4,f4)", parallel=True, nogil=True, cache=True) @@ -1281,7 +1290,7 @@ def localmap( update_embedding_adam(Y, grad, m, v, beta1, beta2, lr, itr) if (itr > num_iters[0] + num_iters[1]) and (itr % 10 == 0): - pair_FP = sample_FP_pair_nearby(X, pair_neighbors, n_neighbors, n_FP, Y, low_dist_thres) + pair_FP = sample_FP_pair_nearby(X, pair_neighbors, pair_FP, n_neighbors, n_FP, Y, low_dist_thres) if intermediate: if (itr + 1) == inter_snapshots[itr_ind]: @@ -1462,7 +1471,4 @@ def fit(self, X, init=None, save_pairs=True): ) if not save_pairs: self.del_pairs() - return self - - - \ No newline at end of file + return self \ No newline at end of file