Skip to content

Commit

Permalink
- improved the node splitting method
Browse files Browse the repository at this point in the history
- Updated the runtimes comparison versus sklearn
- fixed little typos in the readme
- fixed typechecks in the trees
  • Loading branch information
rdbs-oss committed Mar 31, 2021
1 parent f1f4b17 commit a30aa66
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 28 deletions.
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,3 @@ The (minor) slow down observed against sklearn implementation is probably relate
</p>

Query times are roughly identical. However, my implementation does not seem to scale quite as well as scikit-learn's: a minor slowdown can be observed for extremely large datasets (million-ish data points).

### Warnings
Because input data needs to be typed, the dimensionality of the process is fixed in advance. This BallTree implementation cannot work on data with three or more dimensions (although it is a one-liner fix).
Binary file modified assets/images/building_timings.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified assets/images/query_timings.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 5 additions & 2 deletions lsnms/balltree.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __init__(self, data, leaf_size=16, indices=None):
# Stores the data
self.data = data

if len(self.data) == 0:
raise ValueError('Empty data')

# Stores indices of each data point
if indices is None:
self.indices = np.arange(len(data))
Expand Down Expand Up @@ -139,8 +142,8 @@ def query_radius(self, X, max_radius):
"""
if X.ndim > 1:
raise ValueError("query_radius only works on single query point.")
if len(X) != 2:
raise ValueError("Query point must be two-dimensional")
if X.shape[-1] != self.dimensionality:
raise ValueError("Tree and query dimensionality do not match")
# Initialize empty list of int64
# Needs to be typed
buffer = [0][:0]
Expand Down
7 changes: 5 additions & 2 deletions lsnms/kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def __init__(self, data, leaf_size=16, axis=0, indices=None):
self.axis = axis
self.dimensionality = data.shape[-1]

if len(self.data) == 0:
raise ValueError('Empty data')

# Stores indices of each data point
if indices is None:
self.indices = np.arange(len(data))
Expand Down Expand Up @@ -145,8 +148,8 @@ def query_radius(self, X, max_radius):
"""
if X.ndim > 1:
raise ValueError("query_radius only works on single query point.")
if len(X) != 2:
raise ValueError("Query point must be two-dimensional")
if X.shape[-1] != self.dimensionality:
raise ValueError("Tree and query dimensionality do not match")
# Initialize empty list of int64
# Needs to be typed
buffer = [0][:0]
Expand Down
111 changes: 98 additions & 13 deletions lsnms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ def max_spread_axis(data):
def split_along_axis(data, axis):
"""
Splits the data along axis in two datasets of equal size.
Note that this could probably be optimized further, by implementing the median algorithm from
scratch.
This method uses an adapted re-implementation of `np.argpartition`
Parameters
----------
Expand All @@ -157,17 +156,7 @@ def split_along_axis(data, axis):
Tuple[np.array]
Left data point indices, right data point indices
"""
indices = np.arange(len(data))
cap = np.median(data[:, axis])
mask = data[:, axis] <= cap
n_left = mask.sum()
# Account for the case where all positions along this axis are equal: split in the middle
if n_left == len(data) or n_left == 0:
left = indices[: len(indices) // 2]
right = indices[len(indices) // 2 :]
else:
left = indices[mask]
right = indices[np.logical_not(mask)]
left, right = median_argsplit(data[:, axis])
return left, right


Expand Down Expand Up @@ -217,3 +206,99 @@ def englobing_box(data):
bounds.insert(j, data[:, j].min())
bounds.insert(2 * j + 1, data[:, j].max())
return np.array(bounds)


@njit
def _partition(A, low, high, indices):
    """
    Partition ``A[low:high + 1]`` in place around a median-of-three pivot,
    mirroring every swap on ``indices`` so the applied permutation can be
    recovered afterwards (an "argpartition").

    This is straight from numba master:
    https://github.com/numba/numba/blob/b5bd9c618e20985acb0b300d52d57595ef6f5442/numba/np/arraymath.py#L1155
    I modified it so the swaps operate on the indices as well, because I need a argpartition

    Parameters
    ----------
    A : np.array
        One-dimensional values array; partially reordered in place.
    low : int
        First index (inclusive) of the range to partition.
    high : int
        Last index (inclusive) of the range to partition.
    indices : np.array
        Index array swapped in lockstep with ``A``; modified in place.

    Returns
    -------
    int
        Final position of the pivot: every element before it is smaller
        than the pivot, every element at or after it is not smaller.
    """
    mid = (low + high) >> 1
    # NOTE: the pattern of swaps below for the pivot choice and the
    # partitioning gives good results (i.e. regular O(n log n))
    # on sorted, reverse-sorted, and uniform arrays. Subtle changes
    # risk breaking this property.
    # Use median of three {low, middle, high} as the pivot
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
        indices[low], indices[mid] = indices[mid], indices[low]
    if A[high] < A[mid]:
        A[high], A[mid] = A[mid], A[high]
        indices[high], indices[mid] = indices[mid], indices[high]
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
        indices[low], indices[mid] = indices[mid], indices[low]
    pivot = A[mid]

    # Park the pivot at the end of the range while partitioning.
    A[high], A[mid] = A[mid], A[high]
    indices[high], indices[mid] = indices[mid], indices[high]
    i = low
    j = high - 1
    while True:
        # Advance `i` past elements already on the correct (left) side.
        while i < high and A[i] < pivot:
            i += 1
        # Retreat `j` past elements already on the correct (right) side.
        while j >= low and pivot < A[j]:
            j -= 1
        if i >= j:
            break
        A[i], A[j] = A[j], A[i]
        indices[i], indices[j] = indices[j], indices[i]
        i += 1
        j -= 1
    # Put the pivot back in its final place (all items before `i`
    # are smaller than the pivot, all items at/after `i` are larger)
    A[i], A[high] = A[high], A[i]
    indices[i], indices[high] = indices[high], indices[i]

    return i


@njit
def _select(arry, k, low, high):
    """
    Quickselect: repeatedly partition ``arry[low:high + 1]`` until the
    pivot lands at rank ``k``, tracking the permutation applied.

    This is straight from numba master:
    https://github.com/numba/numba/blob/b5bd9c618e20985acb0b300d52d57595ef6f5442/numba/np/arraymath.py#L1155
    Select the k'th smallest element in array[low:high + 1].

    Parameters
    ----------
    arry : np.array
        One-dimensional values array; partially reordered in place.
    k : int
        Zero-based rank of the element to select.
    low : int
        First index (inclusive) of the search range.
    high : int
        Last index (inclusive) of the search range.

    Returns
    -------
    Tuple[np.array, int]
        The permutation indices (``indices[:k]`` point at the k smallest
        values of the original array) and the pivot's final position,
        which equals ``k`` on return.
    """
    # Track the permutation over the whole array, even though only
    # [low, high] is partitioned.
    indices = np.arange(len(arry))
    i = _partition(arry, low, high, indices)
    # Narrow the search range toward rank k, one partition at a time.
    while i != k:
        if i < k:
            low = i + 1
            i = _partition(arry, low, high, indices)
        else:
            high = i - 1
            i = _partition(arry, low, high, indices)
    return indices, i


@njit
def median_argsplit(arry):
    """
    Partition the indices of `arry` around its median value.

    Two index arrays are returned: the first points at values lying below
    the pivot value (usually the median), the second at values above it.
    This is approximately three times faster than computing the median
    first and then masking the array twice to recover both index sets.

    Parameters
    ----------
    arry : np.array
        One dimensional values array

    Returns
    -------
    Tuple[np.array]
        Indices of values below median, indices of values above median
    """
    n = len(arry)
    pivot_rank = n >> 1
    # Quickselect reorders its input in place, so work on a flat copy and
    # leave the caller's array untouched.
    scratch = arry.flatten()
    order, _ = _select(scratch, pivot_rank, 0, n - 1)
    return order[:pivot_rank], order[pivot_rank:]
10 changes: 6 additions & 4 deletions tests/timings_balltree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ def test_tree_query_timing():
ns = np.arange(1000, 200000, 10000)
ts = []
naive_ts = []
leaf_size = 64
repeats = 100
for n in ns:
data = np.random.uniform(0, 1000, (n, 2))
sk_tree = skBT(data, leaf_size=16)
tree = BallTree(data, leaf_size=16)
sk_tree = skBT(data, leaf_size=leaf_size)
tree = BallTree(data, leaf_size=int(leaf_size * 0.67))
_ = tree.query_radius(data[0], 200.0)
timer = Timer(lambda: tree.query_radius(data[0], 100.0))
ts.append(timer.timeit(number=repeats) / repeats * 1000)
Expand All @@ -42,13 +43,14 @@ def test_tree_building_timing():

ns = np.arange(1000, 300000, 25000)
ts = []
leaf_size = 64
naive_ts = []
for n in ns:
data = np.random.uniform(0, n, (n, 2))
_ = BallTree(data, 16)
timer = Timer(lambda: BallTree(data, 16))
timer = Timer(lambda: BallTree(data, leaf_size))
ts.append(timer.timeit(number=5) / 5)
naive_timer = Timer(lambda: skBT(data, 16))
naive_timer = Timer(lambda: skBT(data, int(leaf_size * 0.67)))
naive_ts.append(naive_timer.timeit(5) / 5)

with plt.xkcd():
Expand Down
10 changes: 6 additions & 4 deletions tests/timings_kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ def test_tree_query_timing():

ns = np.arange(1000, 200000, 10000)
ts = []
leaf_size = 64
naive_ts = []
repeats = 100
for n in ns:
data = np.random.uniform(0, 1000, (n, 2))
sk_tree = skKDT(data, leaf_size=16)
tree = KDTree(data, leaf_size=16)
sk_tree = skKDT(data, leaf_size=int(leaf_size * 0.67))
tree = KDTree(data, leaf_size=leaf_size)
_ = tree.query_radius(data[0], 200.0)
timer = Timer(lambda: tree.query_radius(data[0], 100.0))
ts.append(timer.timeit(number=repeats) / repeats * 1000)
Expand All @@ -130,14 +131,15 @@ def test_tree_query_timing():

def test_tree_building_timing():
ns = np.arange(1000, 300000, 25000)
leaf_size = 64
ts = []
naive_ts = []
for n in ns:
data = np.random.uniform(0, n, (n, 2))
_ = KDTree(data, 16)
timer = Timer(lambda: KDTree(data, 16))
timer = Timer(lambda: KDTree(data, leaf_size))
ts.append(timer.timeit(number=5) / 5)
naive_timer = Timer(lambda: skKDT(data, 16))
naive_timer = Timer(lambda: skKDT(data, int(leaf_size * 0.67)))
naive_ts.append(naive_timer.timeit(5) / 5)

with plt.xkcd():
Expand Down

0 comments on commit a30aa66

Please sign in to comment.