@@ -7,21 +7,60 @@ using PyCall
using Plots
using LinearAlgebra
using Statistics
+ using StatsBase
using BenchmarkTools
using Distances


-
# import sklearn datasets
data = pyimport("sklearn.datasets")
-
X, y = data.make_blobs(n_samples=1000000, n_features=3, centers=3, cluster_std=0.9, random_state=80)


- ran_k = 3
- ran_x = randn(100, ran_k)
- ran_l = rand(1:ran_k, 100)
- ran_c = randn(ran_k, ran_k)
+ """
21
+ """
22
+ function smart_init (X:: Array{Float64, 2} , k:: Int ; init:: String = " k-means++" )
23
+ n_row, n_col = size (X)
24
+
25
+ if init == " k-means++"
26
+ # randonmly select the first centroid from the data (X)
27
+ centroids = zeros (k, n_col)
28
+ rand_idx = rand (1 : n_row)
29
+ centroids[1 , :] = X[rand_idx, :]
30
+
31
+ # compute distances from the first centroid chosen to all the other data points
32
+ first_centroid_matrix = convert (Matrix, centroids[1 , :]' )
33
+ # flattened vector (n_row,)
34
+ distances = vec (pairwise (Euclidean (), X, first_centroid_matrix, dims = 1 ))
35
+
36
+ for i = 2 : k
37
+ # choose the next centroid, the probability for each data point to be chosen
38
+ # is directly proportional to its squared distance from the nearest centroid
39
+ prob = distances .^ 2
40
+ r_idx = sample (1 : n_row, ProbabilityWeights (prob))
41
+ centroids[i, :] = X[r_idx, :]
42
+
43
+ if i == (k- 1 )
44
+ break
45
+ end
46
+
47
+ # compute distances from the centroids to all data points
48
+ # and update the squared distance as the minimum distance to all centroid
49
+ current_centroid_matrix = convert (Matrix, centroids[i, :]' )
50
+ new_distances = vec (pairwise (Euclidean (), X, current_centroid_matrix, dims = 1 ))
51
+
52
+ distances = minimum ([distances, new_distances])
53
+
54
+ end
55
+
56
+ else
57
+ rand_indices = rand (1 : n_row, k)
58
+ centroids = X[rand_indices, :]
59
+
60
+ end
61
+
62
+ return centroids, n_row, n_col
63
+ end
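# The seeding rule above in isolation: each candidate point is drawn with
# probability proportional to its squared distance from the nearest centroid
# chosen so far. A minimal sketch on a made-up 10x3 matrix (`toy`, `c1`, `d`,
# `w` are illustrative names, not part of the code above):
toy = randn(10, 3)
c1 = convert(Matrix, toy[rand(1:10), :]')          # first centroid, 1x3
d = vec(pairwise(Euclidean(), toy, c1, dims=1))    # distance of each row to c1
w = ProbabilityWeights(d .^ 2)                     # weight proportional to squared distance
next_idx = sample(1:10, w)                         # rows far from c1 are favoured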
"""
@@ -41,19 +80,13 @@ function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Arr
end


- sum_of_squares(ran_x, ran_l, ran_c, ran_k)
-
-
-
-
"""
    Kmeans(design_matrix, k; k_init="k-means++", max_iters=300, tol=1e-4, verbose=true)

Cluster the rows of `design_matrix` into `k` groups and return the labels,
centroids and final sum of squares.
"""
- function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300, tol=1e-5, verbose::Bool=true)
+ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; k_init::String="k-means++",
+                 max_iters::Int64=300, tol=1e-4, verbose::Bool=true)
+
+     centroids, n_row, n_col = smart_init(design_matrix, k, init=k_init)

-     # randomly get centroids for each group
-     n_row, n_col = size(design_matrix)
-     rand_indices = rand(1:n_row, k)
-     centroids = design_matrix[rand_indices, :]

    labels = rand(1:k, n_row)
    distances = zeros(n_row)
@@ -80,7 +113,7 @@ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300

        # Final Step 5: Check for convergence
        if iter > 1 && abs(J - J_previous) < (tol * J)
-             # TODO: Calculate the sum of squares
+
            sum_squares = sum_of_squares(design_matrix, labels, centroids, k)
            # Terminate algorithm with the assumption that K-means has converged
            if verbose
@@ -99,14 +132,11 @@ function Kmeans(design_matrix::Array{Float64, 2}, k::Int64; max_iters::Int64=300
end


- Kmeans(X, 3)
-
-
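# A minimal usage sketch of the function above, assuming the three return
# values (labels, centroids, sum of squares) implied by the destructuring in
# the loop below; the `demo_*` names are illustrative. Iteration stops once
# abs(J - J_previous) < tol * J, i.e. when the relative improvement of the
# total within-cluster distance J falls below `tol` (1e-4 by default).
demo_labels, demo_centroids, demo_ss = Kmeans(X, 3, k_init="k-means++", verbose=false)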
@btime begin
    num = []
    ss = []
    for i = 2:10
-         l, c, s = Kmeans(X, i, verbose=false)
+         l, c, s = Kmeans(X, i, k_init="k-means++", verbose=false)
        push!(num, i)
        push!(ss, s)
    end

plot(num, ss, ylabel="Sum of Squares", xlabel="Number of Clusters",
     title="Test For Heterogeneity Per Iteration", legend=false)
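# One rough way to read the elbow from the values collected above: compute how
# much each additional cluster reduces the sum of squares and pick the k after
# which the improvement levels off (heuristic sketch reusing `num` and `ss`).
drops = ss[1:end-1] .- ss[2:end]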

+
+ function test_speed(x)
+     for i = 2:10
+         l, c, s = Kmeans(x, i, k_init="k-means++", verbose=false)
+     end
+ end
+
+ r = @benchmark test_speed(X) samples=7 seconds=300
+
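# `r` is a BenchmarkTools.Trial; the usual summaries to report are the minimum
# and median runs (sketch, reusing `r` from the line above).
display(minimum(r))
display(median(r))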