forked from lijx10/NN-Trees
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bst.py
170 lines (131 loc) · 4.14 KB
/
bst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import random
import math
import numpy as np
from nearest_neighbors.result_set import KNNResultSet, RadiusNNResultSet
class Node:
    """A binary-search-tree node holding a ``key`` and an associated ``value``.

    ``left`` and ``right`` are child links and start out as ``None``.
    ``value`` defaults to -1 when the caller does not supply one.
    """

    def __init__(self, key, value=-1):
        # Chained assignment binds left first, then right (same order as before).
        self.left = self.right = None
        self.key, self.value = key, value

    def __str__(self):
        # %s applies str() to each operand itself.
        return "key: %s, value: %s" % (self.key, self.value)
def insert(root, key, value=-1):
    """Insert ``key`` with ``value`` into the BST rooted at ``root``.

    Args:
        root: Root node of the tree, or ``None`` for an empty tree.
        key: Key to insert; must support ``<`` / ``>`` against existing keys.
        value: Payload stored on the new node (default -1).

    Returns:
        The root of the tree. When ``root`` is ``None`` this is the newly
        created ``Node``; otherwise it is ``root`` itself.

    A key that compares neither less nor greater than an existing node key
    (a duplicate, or an unordered value such as NaN) is ignored. In that case
    the tree, including the existing node's value, is left untouched.

    The walk is iterative, so inserting already-sorted keys (which builds a
    degenerate, list-shaped tree) cannot exceed Python's recursion limit.
    """
    if root is None:
        return Node(key, value)

    node = root
    while True:
        if key < node.key:
            if node.left is None:
                node.left = Node(key, value)
                return root
            node = node.left
        elif key > node.key:
            if node.right is None:
                node.right = Node(key, value)
                return root
            node = node.right
        else:
            # Key already present (or unordered): don't insert.
            return root
def inorder(root):
    """Print every node in in-order sequence (left, root, right).

    For a tree built with ``insert`` this prints nodes in ascending key order.
    An explicit stack replaces recursion, so degenerate trees deeper than
    Python's recursion limit can still be printed.
    """
    stack = []
    node = root
    while stack or node is not None:
        # Descend as far left as possible, remembering the path back up.
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        print(node)
        node = node.right
def preorder(root):
    """Print every node in pre-order sequence (root, left, right).

    An explicit stack replaces recursion, so degenerate trees deeper than
    Python's recursion limit can still be printed.
    """
    stack = [] if root is None else [root]
    while stack:
        node = stack.pop()
        print(node)
        # Push right before left so the left subtree is popped (visited) first.
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
def postorder(root):
    """Print every node in post-order sequence (left, right, root).

    Nodes are first gathered in (root, right, left) order with an explicit
    stack; reversing that sequence yields post-order. This avoids recursion,
    so degenerate trees deeper than Python's recursion limit can be printed.
    """
    collected = []
    stack = [] if root is None else [root]
    while stack:
        node = stack.pop()
        collected.append(node)
        # Push left before right so right is popped first (root, right, left).
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)
    for node in reversed(collected):
        print(node)
def knn_search(root: Node, result_set: KNNResultSet, key):
    """k-nearest-neighbour search for ``key`` in the BST rooted at ``root``.

    Each visited node is offered to ``result_set`` via
    ``add_point(|node.key - key|, node.value)``. The subtree on the query's
    side of a node is explored first. The opposite subtree is explored only
    when the node's own distance is below ``result_set.worstDist()``: every
    key over there lies beyond this node, so it is farther from ``key``.

    Returns:
        True when the search stopped early because ``result_set.worstDist()``
        reached 0 (nothing closer can exist); otherwise False.
    """
    if root is None:
        return False

    gap = math.fabs(root.key - key)
    result_set.add_point(gap, root.value)
    if result_set.worstDist() == 0:
        return True

    # Pick the subtree on the query's side ("near") and the other one ("far").
    if root.key >= key:
        near, far = root.left, root.right
    else:
        near, far = root.right, root.left

    if knn_search(near, result_set, key):
        return True
    if gap < result_set.worstDist():
        return knn_search(far, result_set, key)
    return False
def radius_search(root: Node, result_set: RadiusNNResultSet, key):
    """Radius search for ``key`` in the BST rooted at ``root``.

    Each visited node is offered to ``result_set`` via
    ``add_point(|node.key - key|, node.value)``; the result set decides what
    to keep. The subtree on the query's side of a node is explored first.
    The opposite subtree is explored only when the node's own distance is
    below ``result_set.worstDist()`` (presumably the search radius — confirm
    in ``RadiusNNResultSet``), because every key there is farther away.

    Returns:
        An early-stop flag propagated from the recursion. No path here ever
        produces True, so the result is always False.
    """
    if root is None:
        return False

    offset = math.fabs(root.key - key)
    result_set.add_point(offset, root.value)

    # Choose which child to descend into first based on the query position.
    if root.key >= key:
        first, second = root.left, root.right
    else:
        first, second = root.right, root.left

    if radius_search(first, result_set, key):
        return True
    if offset < result_set.worstDist():
        return radius_search(second, result_set, key)
    return False
def search_recursive(root, key):
    """Recursively find the node whose key equals ``key``.

    Returns the matching node, or ``None`` when the tree is empty, the key is
    absent, or ``key`` is neither less nor greater than a visited node key
    (e.g. NaN).
    """
    if root is None:
        return None
    if root.key == key:
        return root

    if key < root.key:
        child = root.left
    elif key > root.key:
        child = root.right
    else:
        # Unordered comparison: nowhere to descend.
        return None
    return search_recursive(child, key)
def search_iterative(root, key):
    """Iteratively find the node whose key equals ``key``.

    Returns the matching node, or ``None`` when the tree is empty or the key
    is absent.

    If ``key`` is neither equal to, less than, nor greater than a node's key
    (e.g. ``float('nan')``, for which all three comparisons are False), no
    subtree can contain it. The search then returns ``None`` instead of
    spinning forever on the same node.
    """
    current_node = root
    while current_node is not None:
        if current_node.key == key:
            return current_node
        if key < current_node.key:
            current_node = current_node.left
        elif key > current_node.key:
            current_node = current_node.right
        else:
            # Unordered key: without this the loop never advances.
            return None
    return None
def main():
    """Demo of BST nearest-neighbour queries.

    Builds a tree from a random permutation of ``0..db_size-1``, where each
    key's value is its insertion index. Then prints the results of a k-NN
    query and a radius query around ``query_key``.
    """
    # configuration
    db_size, k, radius, query_key = 100, 5, 2.0, 6

    def report(title, result_set):
        # Shared three-line output layout for both query types.
        print(title)
        print('index - distance')
        print(result_set)

    shuffled = np.random.permutation(db_size).tolist()
    root = None
    for position, key in enumerate(shuffled):
        root = insert(root, key, position)

    knn_result = KNNResultSet(capacity=k)
    knn_search(root, knn_result, query_key)
    report('kNN Search:', knn_result)

    radius_result = RadiusNNResultSet(radius=radius)
    radius_search(root, radius_result, query_key)
    report('Radius NN Search:', radius_result)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()