Skip to content

Commit

Permalink
faster is knn from
Browse files Browse the repository at this point in the history
  • Loading branch information
abstractqqq committed Jul 12, 2024
1 parent d8c4108 commit 60c874b
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions python/polars_ds/num.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,26 +423,16 @@ def is_knn_from(
raise ValueError("Dimension does not match.")

if dist == "l1":
return (
pl.sum_horizontal(
(e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(pt, oth)
).rank(method="min")
<= k
)
dist = pl.sum_horizontal((e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(pt, oth))
return dist <= dist.bottom_k(k=k).max()
elif dist == "l2":
return (
pl.sum_horizontal(
(e - pl.lit(xi, dtype=pl.Float64)).pow(2) for xi, e in zip(pt, oth)
).rank(method="min")
<= k
dist = pl.sum_horizontal(
(e - pl.lit(xi, dtype=pl.Float64)).pow(2) for xi, e in zip(pt, oth)
)
return dist <= dist.bottom_k(k=k).max()
elif dist == "inf":
return (
pl.max_horizontal(
(e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(pt, oth)
).rank(method="min")
<= k
)
dist = pl.max_horizontal((e - pl.lit(xi, dtype=pl.Float64)).abs() for xi, e in zip(pt, oth))
return dist <= dist.bottom_k(k=k).max()
elif dist == "cosine":
x_list = list(pt)
x_norm = sum(z * z for z in x_list)
Expand All @@ -451,7 +441,7 @@ def is_knn_from(
1.0
- pl.sum_horizontal(xi * e for xi, e in zip(x_list, oth)) / (x_norm * oth_norm).sqrt()
)
return dist.rank(method="min") <= k
return dist <= dist.bottom_k(k=k).max()
elif dist in ("h", "haversine"):
pt_as_list = list(pt)
if (len(pt_as_list) != 2) or (len(oth) < 2):
Expand All @@ -463,7 +453,7 @@ def is_knn_from(
y_lat = pl.lit(pt_as_list[0], dtype=pl.Float64)
y_long = pl.lit(pt_as_list[1], dtype=pl.Float64)
dist = haversine(oth[0], oth[1], y_lat, y_long)
return dist.rank(method="min") <= k
return dist <= dist.bottom_k(k=k).max()
else:
raise ValueError(f"Unknown distance function: {dist}")

Expand Down

0 comments on commit 60c874b

Please sign in to comment.