-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3499777
commit 41dd75f
Showing
4 changed files
with
560 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import random | ||
import math | ||
import numpy as np | ||
import time | ||
|
||
import nearest_neighbors.octree as octree | ||
import nearest_neighbors.kdtree as kdtree | ||
from nearest_neighbors.result_set import KNNResultSet, RadiusNNResultSet | ||
|
||
|
||
def main(): | ||
# configuration | ||
db_size = 64000 | ||
dim = 3 | ||
leaf_size = 32 | ||
min_extent = 0.0001 | ||
k = 8 | ||
radius = 1 | ||
|
||
iteration_num = 100 | ||
|
||
print("octree --------------") | ||
construction_time_sum = 0 | ||
knn_time_sum = 0 | ||
radius_time_sum = 0 | ||
brute_time_sum = 0 | ||
for i in range(iteration_num): | ||
db_np = (np.random.rand(db_size, dim) - 0.5) * 100 | ||
|
||
begin_t = time.time() | ||
root = octree.octree_construction(db_np, leaf_size, min_extent) | ||
construction_time_sum += time.time() - begin_t | ||
|
||
query = np.random.rand(3) | ||
|
||
begin_t = time.time() | ||
result_set = KNNResultSet(capacity=k) | ||
octree.octree_knn_search(root, db_np, result_set, query) | ||
knn_time_sum += time.time() - begin_t | ||
|
||
begin_t = time.time() | ||
result_set = RadiusNNResultSet(radius=radius) | ||
octree.octree_radius_search_fast(root, db_np, result_set, query) | ||
radius_time_sum += time.time() - begin_t | ||
|
||
begin_t = time.time() | ||
diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1) | ||
nn_idx = np.argsort(diff) | ||
nn_dist = diff[nn_idx] | ||
brute_time_sum += time.time() - begin_t | ||
print("Octree: build %.3f, knn %.3f, radius %.3f, brute %.3f" % (construction_time_sum*1000/iteration_num, | ||
knn_time_sum*1000/iteration_num, | ||
radius_time_sum*1000/iteration_num, | ||
brute_time_sum*1000/iteration_num)) | ||
|
||
print("kdtree --------------") | ||
construction_time_sum = 0 | ||
knn_time_sum = 0 | ||
radius_time_sum = 0 | ||
brute_time_sum = 0 | ||
for i in range(iteration_num): | ||
db_np = np.random.rand(db_size, dim) | ||
|
||
begin_t = time.time() | ||
root = kdtree.kdtree_construction(db_np, leaf_size) | ||
construction_time_sum += time.time() - begin_t | ||
|
||
query = np.random.rand(3) | ||
|
||
begin_t = time.time() | ||
result_set = KNNResultSet(capacity=k) | ||
kdtree.kdtree_knn_search(root, db_np, result_set, query) | ||
knn_time_sum += time.time() - begin_t | ||
|
||
begin_t = time.time() | ||
result_set = RadiusNNResultSet(radius=radius) | ||
kdtree.kdtree_radius_search(root, db_np, result_set, query) | ||
radius_time_sum += time.time() - begin_t | ||
|
||
begin_t = time.time() | ||
diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1) | ||
nn_idx = np.argsort(diff) | ||
nn_dist = diff[nn_idx] | ||
brute_time_sum += time.time() - begin_t | ||
print("Kdtree: build %.3f, knn %.3f, radius %.3f, brute %.3f" % (construction_time_sum * 1000 / iteration_num, | ||
knn_time_sum * 1000 / iteration_num, | ||
radius_time_sum * 1000 / iteration_num, | ||
brute_time_sum * 1000 / iteration_num)) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import random | ||
import math | ||
import numpy as np | ||
|
||
from nearest_neighbors.result_set import KNNResultSet, RadiusNNResultSet | ||
|
||
class Node: | ||
def __init__(self, key, value=-1): | ||
self.left = None | ||
self.right = None | ||
self.key = key | ||
self.value = value | ||
|
||
def __str__(self): | ||
return "key: %s, value: %s" % (str(self.key), str(self.value)) | ||
|
||
|
||
def insert(root, key, value=-1): | ||
if root is None: | ||
root = Node(key, value) | ||
else: | ||
if key < root.key: | ||
root.left = insert(root.left, key, value) | ||
elif key > root.key: | ||
root.right = insert(root.right, key, value) | ||
else: # don't insert if key already exist in the tree | ||
pass | ||
return root | ||
|
||
|
||
def inorder(root): | ||
# Inorder (Left, Root, Right) | ||
if root is not None: | ||
inorder(root.left) | ||
print(root) | ||
inorder(root.right) | ||
|
||
|
||
def preorder(root): | ||
# Preorder (Root, Left, Right) | ||
if root is not None: | ||
print(root) | ||
preorder(root.left) | ||
preorder(root.right) | ||
|
||
|
||
def postorder(root): | ||
# Postorder (Left, Right, Root) | ||
if root is not None: | ||
postorder(root.left) | ||
postorder(root.right) | ||
print(root) | ||
|
||
|
||
def knn_search(root: Node, result_set: KNNResultSet, key): | ||
if root is None: | ||
return False | ||
|
||
# compare the root itself | ||
result_set.add_point(math.fabs(root.key - key), root.value) | ||
if result_set.worstDist() == 0: | ||
return True | ||
|
||
if root.key >= key: | ||
# iterate left branch first | ||
if knn_search(root.left, result_set, key): | ||
return True | ||
elif math.fabs(root.key-key) < result_set.worstDist(): | ||
return knn_search(root.right, result_set, key) | ||
return False | ||
else: | ||
# iterate right branch first | ||
if knn_search(root.right, result_set, key): | ||
return True | ||
elif math.fabs(root.key-key) < result_set.worstDist(): | ||
return knn_search(root.left, result_set, key) | ||
return False | ||
|
||
|
||
def radius_search(root: Node, result_set: RadiusNNResultSet, key): | ||
if root is None: | ||
return False | ||
|
||
# compare the root itself | ||
result_set.add_point(math.fabs(root.key - key), root.value) | ||
|
||
if root.key >= key: | ||
# iterate left branch first | ||
if radius_search(root.left, result_set, key): | ||
return True | ||
elif math.fabs(root.key-key) < result_set.worstDist(): | ||
return radius_search(root.right, result_set, key) | ||
return False | ||
else: | ||
# iterate right branch first | ||
if radius_search(root.right, result_set, key): | ||
return True | ||
elif math.fabs(root.key-key) < result_set.worstDist(): | ||
return radius_search(root.left, result_set, key) | ||
return False | ||
|
||
|
||
def search_recursive(root, key): | ||
if root is None or root.key == key: | ||
return root | ||
if key < root.key: | ||
return search_recursive(root.left, key) | ||
elif key > root.key: | ||
return search_recursive(root.right, key) | ||
|
||
|
||
def search_iterative(root, key): | ||
current_node = root | ||
while current_node is not None: | ||
if current_node.key == key: | ||
return current_node | ||
elif key < current_node.key: | ||
current_node = current_node.left | ||
elif key > current_node.key: | ||
current_node = current_node.right | ||
return current_node | ||
|
||
|
||
def main(): | ||
# configuration | ||
db_size = 100 | ||
k = 5 | ||
radius = 2.0 | ||
|
||
data = np.random.permutation(db_size).tolist() | ||
|
||
root = None | ||
for i, point in enumerate(data): | ||
root = insert(root, point, i) | ||
|
||
query_key = 6 | ||
result_set = KNNResultSet(capacity=k) | ||
knn_search(root, result_set, query_key) | ||
print('kNN Search:') | ||
print('index - distance') | ||
print(result_set) | ||
|
||
result_set = RadiusNNResultSet(radius=radius) | ||
radius_search(root, result_set, query_key) | ||
print('Radius NN Search:') | ||
print('index - distance') | ||
print(result_set) | ||
|
||
|
||
# print("inorder") | ||
# inorder(root) | ||
# print("preorder") | ||
# preorder(root) | ||
# print("postorder") | ||
# postorder(root) | ||
|
||
|
||
|
||
# node = search_recursive(root, 2) | ||
# print(node) | ||
# | ||
# node = search_iterative(root, 2) | ||
# print(node) | ||
|
||
|
||
|
||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.