Skip to content

Commit 7b38d9e

Browse files
authored
Update updated_kmeans.jl
Adds an optimized working implementation of k-means++ (and random) initialization; currently being tested and benchmarked.
1 parent 9ced44a commit 7b38d9e

File tree

1 file changed

+60
-21
lines changed

1 file changed

+60
-21
lines changed

Tutorials/kmeans_implementation_project/updated_kmeans.jl

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,60 @@ using PyCall
77
using Plots
88
using LinearAlgebra
99
using Statistics
10+
using StatsBase
1011
using BenchmarkTools
1112
using Distances
1213

1314

14-
1515
# import sklearn datasets
1616
data = pyimport("sklearn.datasets")
17-
1817
X, y = data.make_blobs(n_samples=1000000, n_features=3, centers=3, cluster_std=0.9, random_state=80)
1918

2019

21-
ran_k = 3
22-
ran_x = randn(100, ran_k)
23-
ran_l = rand(1:ran_k, 100)
24-
ran_c = randn(ran_k, ran_k)
20+
"""
    smart_init(X, k; init="k-means++") -> (centroids, n_row, n_col)

Select `k` initial centroids from the rows of the design matrix `X`.

With `init == "k-means++"` the centroids are chosen by the k-means++
scheme: the first centroid is a uniformly random row of `X`, and each
subsequent centroid is sampled with probability proportional to each
point's squared distance to its nearest already-chosen centroid. Any
other value of `init` falls back to `k` uniformly random rows (which may
contain duplicates).

Returns the `k × n_col` centroid matrix together with the number of rows
and columns of `X`.
"""
function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
    n_row, n_col = size(X)

    if init == "k-means++"
        # Randomly select the first centroid from the data (X).
        centroids = zeros(k, n_col)
        rand_idx = rand(1:n_row)
        centroids[1, :] = X[rand_idx, :]

        # Distances from the first chosen centroid to all data points,
        # flattened to a vector of length n_row.
        first_centroid_matrix = convert(Matrix, centroids[1, :]')
        distances = vec(pairwise(Euclidean(), X, first_centroid_matrix, dims = 1))

        for i = 2:k
            # The probability for each data point to be chosen as the next
            # centroid is directly proportional to its squared distance
            # from the nearest centroid chosen so far.
            prob = distances .^ 2
            r_idx = sample(1:n_row, ProbabilityWeights(prob))
            centroids[i, :] = X[r_idx, :]

            # After placing the last centroid there is nothing left to do.
            # BUG FIX: the original broke at i == (k - 1), which exited
            # before the k-th centroid was assigned, leaving the final
            # centroid row as all zeros.
            if i == k
                break
            end

            # Refresh each point's distance as the minimum distance to any
            # centroid chosen so far.
            current_centroid_matrix = convert(Matrix, centroids[i, :]')
            new_distances = vec(pairwise(Euclidean(), X, current_centroid_matrix, dims = 1))

            # BUG FIX: element-wise minimum. The original called
            # `minimum([distances, new_distances])`, which attempts to
            # order two vectors and throws a MethodError (`isless` is not
            # defined for arrays).
            distances = min.(distances, new_distances)
        end
    else
        # Fallback: k uniformly random rows of X as starting centroids.
        rand_indices = rand(1:n_row, k)
        centroids = X[rand_indices, :]
    end

    return centroids, n_row, n_col
end
2564

2665

2766
"""
@@ -41,19 +80,13 @@ function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Arr
4180
end
4281

4382

44-
sum_of_squares(ran_x, ran_l, ran_c, ran_k)
45-
46-
47-
48-
4983
"""
5084
"""
51-
function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300, tol=1e-5, verbose::Bool=true)
85+
function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; k_init::String="k-means++",
86+
max_iters::Int64=300, tol=1e-4, verbose::Bool=true)
87+
88+
centroids, n_row, n_col = smart_init(design_matrix, k, init=k_init)
5289

53-
# randomly get centroids for each group
54-
n_row, n_col = size(design_matrix)
55-
rand_indices = rand(1:n_row, k)
56-
centroids = design_matrix[rand_indices, :]
5790
labels = rand(1:k, n_row)
5891
distances = zeros(n_row)
5992

@@ -80,7 +113,7 @@ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300
80113

81114
# Final Step 5: Check for convergence
82115
if iter > 1 && abs(J - J_previous) < (tol * J)
83-
# TODO: Calculate the sum of squares
116+
84117
sum_squares = sum_of_squares(design_matrix, labels, centroids, k)
85118
# Terminate algorithm with the assumption that K-means has converged
86119
if verbose
@@ -99,14 +132,11 @@ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300
99132
end
100133

101134

102-
Kmeans(X, 3)
103-
104-
105135
@btime begin
106136
num = []
107137
ss = []
108138
for i = 2:10
109-
l, c, s = Kmeans(X, i, verbose=false)
139+
l, c, s = Kmeans(X, i, k_init="k-means++", verbose=false)
110140
push!(num, i)
111141
push!(ss, s)
112142
end
@@ -116,3 +146,12 @@ end
116146
plot(num, ss, ylabel="Sum of Squares", xlabel="Number of Iterations",
117147
title = "Test For Heterogeneity Per Iteration", legend=false)
118148

149+
150+
# Benchmark kernel: run one full K-means sweep over k = 2:10 on the
# given design matrix.
# BUG FIX: the original ignored its argument `x` and always clustered the
# global `X`, so benchmarking any other dataset silently measured `X`.
function test_speed(x)
    for i = 2:10
        l, c, s = Kmeans(x, i, k_init="k-means++", verbose=false)
    end
end
155+
156+
# Benchmark the full k = 2:10 sweep: at most 7 samples within a 300 s budget.
r = @benchmark test_speed(X) samples=7 seconds=300
157+

0 commit comments

Comments
 (0)