Skip to content

Commit ab93340

Browse files
authored
update updated_kmeans.jl
Added speed benchmark tests.
1 parent b47f877 commit ab93340

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

Tutorials/kmeans_implementation_project/updated_kmeans.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
using Pkg
12
# Replace python environment to suit your needs
23
ENV["PYTHON"] = "/home/mysterio/miniconda3/envs/pydata/bin/python"
34
Pkg.build("PyCall") # Build PyCall to suit the specified Python env
45

56
using PyCall
67
using Plots
78
using LinearAlgebra
9+
using Statistics
810
using BenchmarkTools
911
using Distances
1012

@@ -26,12 +28,12 @@ ran_c = randn(ran_k, ran_k)
2628
"""
2729
function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Array, k::Int)
2830
ss = 0
29-
31+
3032
for j = 1:k
3133
group_data = x[findall(labels .== j), :]
3234
group_centroid_matrix = convert(Matrix, centre[j, :]')
3335
group_distance = pairwise(Euclidean(), group_data, group_centroid_matrix, dims=1)
34-
36+
3537
ss += sum(group_distance .^ 2)
3638
end
3739

@@ -56,16 +58,16 @@ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300
5658
distances = zeros(n_row)
5759

5860
J_previous = Inf64
59-
61+
6062
# Update centroids & labels with closest members until convergence
6163
for iter = 1:max_iters
6264
nearest_neighbour = pairwise(Euclidean(), design_matrix, centroids, dims=1)
63-
65+
6466
min_val_idx = findmin.(eachrow(nearest_neighbour))
6567

6668
distances = [x[1] for x in min_val_idx]
6769
labels = [x[2] for x in min_val_idx]
68-
70+
6971
centroids = [ mean( X[findall(labels .== j), : ], dims = 1) for j = 1:k]
7072
centroids = reduce(vcat, centroids)
7173

@@ -75,7 +77,7 @@ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300
7577
# Show progress and terminate if J stopped decreasing.
7678
println("Iteration ", iter, ": Jclust = ", J, ".")
7779
end;
78-
80+
7981
# Final Step 5: Check for convergence
8082
if iter > 1 && abs(J - J_previous) < (tol * J)
8183
# TODO: Calculate the sum of squares
@@ -84,9 +86,9 @@ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300
8486
if verbose
8587
println("Successfully terminated with convergence.")
8688
end
87-
89+
8890
return labels, centroids, sum_squares
89-
91+
9092
elseif iter == max_iters && abs(J - J_previous) > (tol * J)
9193
throw(error("Failed to converge Check data and/or implementation or increase max_iter."))
9294
end;
@@ -100,3 +102,17 @@ end
100102
Kmeans(X, 3)
101103

102104

105+
"""
    cluster_sum_of_squares(data, ks)

Run `Kmeans(data, k, verbose=false)` once for every `k` in `ks` and collect the
last value returned by each run (the sum of squares).

# Arguments
- `data`: design matrix forwarded unchanged to `Kmeans`.
- `ks`: iterable of cluster counts to try (e.g. `2:10`).

# Returns
A tuple `(num, ss)` where `num` is the collected cluster counts and `ss` holds the
matching sum-of-squares values, in the same order.
"""
function cluster_sum_of_squares(data, ks)
    # `Kmeans` returns `(labels, centroids, sum_squares)`; only the last entry is needed.
    ss = [last(Kmeans(data, k, verbose=false)) for k in ks]
    return collect(ks), ss
end

# BenchmarkTools evaluates the benchmarked expression inside a generated function, so
# variables assigned inside an `@btime begin ... end` block never reach global scope and
# would be undefined when plotting. `@btime` does return the value of the expression, so
# the result is captured from it directly. `$X` interpolates the global so the timing
# is not distorted by untyped global access.
num, ss = @btime cluster_sum_of_squares($X, 2:10)

# The x-axis is the number of clusters `k` handed to `Kmeans`, not an iteration count.
plot(num, ss, ylabel="Sum of Squares", xlabel="Number of Clusters (k)",
     title = "Test For Heterogeneity Per Number of Clusters", legend=false)
118+

0 commit comments

Comments
 (0)