Skip to content

Commit

Permalink
Add BST and kd-tree.
Browse files Browse the repository at this point in the history
  • Loading branch information
lijx-nutonomy committed Mar 11, 2020
1 parent 3499777 commit 41dd75f
Show file tree
Hide file tree
Showing 4 changed files with 560 additions and 0 deletions.
93 changes: 93 additions & 0 deletions benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import random
import math
import numpy as np
import time

import nearest_neighbors.octree as octree
import nearest_neighbors.kdtree as kdtree
from nearest_neighbors.result_set import KNNResultSet, RadiusNNResultSet


def main():
# configuration
db_size = 64000
dim = 3
leaf_size = 32
min_extent = 0.0001
k = 8
radius = 1

iteration_num = 100

print("octree --------------")
construction_time_sum = 0
knn_time_sum = 0
radius_time_sum = 0
brute_time_sum = 0
for i in range(iteration_num):
db_np = (np.random.rand(db_size, dim) - 0.5) * 100

begin_t = time.time()
root = octree.octree_construction(db_np, leaf_size, min_extent)
construction_time_sum += time.time() - begin_t

query = np.random.rand(3)

begin_t = time.time()
result_set = KNNResultSet(capacity=k)
octree.octree_knn_search(root, db_np, result_set, query)
knn_time_sum += time.time() - begin_t

begin_t = time.time()
result_set = RadiusNNResultSet(radius=radius)
octree.octree_radius_search_fast(root, db_np, result_set, query)
radius_time_sum += time.time() - begin_t

begin_t = time.time()
diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
nn_idx = np.argsort(diff)
nn_dist = diff[nn_idx]
brute_time_sum += time.time() - begin_t
print("Octree: build %.3f, knn %.3f, radius %.3f, brute %.3f" % (construction_time_sum*1000/iteration_num,
knn_time_sum*1000/iteration_num,
radius_time_sum*1000/iteration_num,
brute_time_sum*1000/iteration_num))

print("kdtree --------------")
construction_time_sum = 0
knn_time_sum = 0
radius_time_sum = 0
brute_time_sum = 0
for i in range(iteration_num):
db_np = np.random.rand(db_size, dim)

begin_t = time.time()
root = kdtree.kdtree_construction(db_np, leaf_size)
construction_time_sum += time.time() - begin_t

query = np.random.rand(3)

begin_t = time.time()
result_set = KNNResultSet(capacity=k)
kdtree.kdtree_knn_search(root, db_np, result_set, query)
knn_time_sum += time.time() - begin_t

begin_t = time.time()
result_set = RadiusNNResultSet(radius=radius)
kdtree.kdtree_radius_search(root, db_np, result_set, query)
radius_time_sum += time.time() - begin_t

begin_t = time.time()
diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
nn_idx = np.argsort(diff)
nn_dist = diff[nn_idx]
brute_time_sum += time.time() - begin_t
print("Kdtree: build %.3f, knn %.3f, radius %.3f, brute %.3f" % (construction_time_sum * 1000 / iteration_num,
knn_time_sum * 1000 / iteration_num,
radius_time_sum * 1000 / iteration_num,
brute_time_sum * 1000 / iteration_num))



if __name__ == '__main__':
main()
170 changes: 170 additions & 0 deletions bst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import random
import math
import numpy as np

from nearest_neighbors.result_set import KNNResultSet, RadiusNNResultSet

class Node:
def __init__(self, key, value=-1):
self.left = None
self.right = None
self.key = key
self.value = value

def __str__(self):
return "key: %s, value: %s" % (str(self.key), str(self.value))


def insert(root, key, value=-1):
if root is None:
root = Node(key, value)
else:
if key < root.key:
root.left = insert(root.left, key, value)
elif key > root.key:
root.right = insert(root.right, key, value)
else: # don't insert if key already exist in the tree
pass
return root


def inorder(root):
# Inorder (Left, Root, Right)
if root is not None:
inorder(root.left)
print(root)
inorder(root.right)


def preorder(root):
# Preorder (Root, Left, Right)
if root is not None:
print(root)
preorder(root.left)
preorder(root.right)


def postorder(root):
# Postorder (Left, Right, Root)
if root is not None:
postorder(root.left)
postorder(root.right)
print(root)


def knn_search(root: Node, result_set: KNNResultSet, key):
if root is None:
return False

# compare the root itself
result_set.add_point(math.fabs(root.key - key), root.value)
if result_set.worstDist() == 0:
return True

if root.key >= key:
# iterate left branch first
if knn_search(root.left, result_set, key):
return True
elif math.fabs(root.key-key) < result_set.worstDist():
return knn_search(root.right, result_set, key)
return False
else:
# iterate right branch first
if knn_search(root.right, result_set, key):
return True
elif math.fabs(root.key-key) < result_set.worstDist():
return knn_search(root.left, result_set, key)
return False


def radius_search(root: Node, result_set: RadiusNNResultSet, key):
if root is None:
return False

# compare the root itself
result_set.add_point(math.fabs(root.key - key), root.value)

if root.key >= key:
# iterate left branch first
if radius_search(root.left, result_set, key):
return True
elif math.fabs(root.key-key) < result_set.worstDist():
return radius_search(root.right, result_set, key)
return False
else:
# iterate right branch first
if radius_search(root.right, result_set, key):
return True
elif math.fabs(root.key-key) < result_set.worstDist():
return radius_search(root.left, result_set, key)
return False


def search_recursive(root, key):
if root is None or root.key == key:
return root
if key < root.key:
return search_recursive(root.left, key)
elif key > root.key:
return search_recursive(root.right, key)


def search_iterative(root, key):
current_node = root
while current_node is not None:
if current_node.key == key:
return current_node
elif key < current_node.key:
current_node = current_node.left
elif key > current_node.key:
current_node = current_node.right
return current_node


def main():
# configuration
db_size = 100
k = 5
radius = 2.0

data = np.random.permutation(db_size).tolist()

root = None
for i, point in enumerate(data):
root = insert(root, point, i)

query_key = 6
result_set = KNNResultSet(capacity=k)
knn_search(root, result_set, query_key)
print('kNN Search:')
print('index - distance')
print(result_set)

result_set = RadiusNNResultSet(radius=radius)
radius_search(root, result_set, query_key)
print('Radius NN Search:')
print('index - distance')
print(result_set)


# print("inorder")
# inorder(root)
# print("preorder")
# preorder(root)
# print("postorder")
# postorder(root)



# node = search_recursive(root, 2)
# print(node)
#
# node = search_iterative(root, 2)
# print(node)





if __name__ == '__main__':
main()
Loading

0 comments on commit 41dd75f

Please sign in to comment.