Skip to content

Commit

Permalink
kmeans, to be continued
Browse files Browse the repository at this point in the history
  • Loading branch information
ShaharMS committed Oct 13, 2024
1 parent 35b72d5 commit 9b1de4c
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
- **Added `ColorChannel.hx` enum**
- **Added `Color.fromFloat()`**

## `vision.tools`

- **Fixed `ArrayTools` `min`/`max` methods**

# 2.0.0

## Breaking!
Expand Down
46 changes: 46 additions & 0 deletions src/vision/algorithms/KMeans.hx
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package vision.algorithms;

import vision.exceptions.Unimplemented;
using vision.tools.MathTools;
using vision.tools.ArrayTools;
class KMeans {
public static function generateClustersUsingConvergence<T>(values:Array<T>, clusterAmount:Int, distanceFunction:(T, T) -> Float, conversionFunction:T -> Float):Array<Array<T>> {
var clusterCenters = pickElementsAtRandom(values, clusterAmount, true);

// We don't use clusterAmount in case where the image doesnt have enough distinct colors to satisfy
// the requested amount
var clusters = [for (i in 0...clusterCenters.length) new Array<T>()];

var converged = false;
while (!converged) {
for (i in 0... clusters.length) clusters[i] = [];

for (value in values) {
var distancesToClusterCenters = [for (j in 0...clusterCenters.length) distanceFunction(value, clusterCenters[j])];
var smallestDistanceIndex = distancesToClusterCenters.indexOf(distancesToClusterCenters.min());
clusters[smallestDistanceIndex].push(value);
}

var newClusterCenters = [for (array in clusters) array.average(conversionFunction)];

}

return clusters;
}


public static function pickElementsAtRandom<T>(values:Array<T>, amount:Int, distinct:Bool = false):Array<T> {
if (!distinct) return [for (i in 0...amount) values[(Math.random() * values.length).round()]];

var result:Array<T> = [];
while (result.length < amount && values.length != 0) {
var value = values[(Math.random() * values.length).round()];
if (result.contains(value)) {
values.remove(value);
continue;
}
result.push(value);
}
return result;
}
}
92 changes: 76 additions & 16 deletions src/vision/tools/ArrayTools.hx
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package vision.tools;

import haxe.ds.ArraySort;
import vision.algorithms.Radix;

using vision.tools.MathTools;

import vision.tools.MathTools.*;

class ArrayTools {

/**
Takes a 2D array and flattens it to a regular, 1D array.
@param array
Expand Down Expand Up @@ -41,7 +42,7 @@ class ArrayTools {
@param predicate A function that takes an element and returns true if the element should be used as a delimiter.
@return An array of one higer dimension.
**/
overload extern inline public static function raise<T>(array:Array<T>, predicateOpensArray:Bool, predicate:T -> Bool):Array<Array<T>> {
overload extern inline public static function raise<T>(array:Array<T>, predicateOpensArray:Bool, predicate:T->Bool):Array<Array<T>> {
var raised:Array<Array<T>> = [];
var temp:Array<T> = [];

Expand All @@ -57,51 +58,110 @@ class ArrayTools {
if (temp.length > 0) raised.push(temp);
return raised;
}
public static inline function min<T:Int, #if !cs Uint, #end Int64, Float>(values:Array<T>):T {
var min:T = values[0];

public overload extern static inline function min<T:Int, #if !cs UInt, #end Int64>(values:Array<T>):T {
var min = values[0];
for (i in 0...values.length) {
if (values[i] < min)
if (values[i] < min) min = values[i];
}
return min;
}

public overload extern static inline function min(values:Array<Float>):Float {
var min = values[0];
for (i in 0...values.length) {
if (values[i] < min) min = values[i];
}
return min;
}

public overload extern static inline function min<T>(values:Array<T>, valueFunction:T->Float):T {
var min = values[0];
var minValue = valueFunction(min);
for (i in 0...values.length) {
var currentValue = valueFunction(values[i]);
if (currentValue < minValue) {
min = values[i];
minValue = currentValue;
}
}

return min;
}

public static inline function max<T:Int, #if !cs Uint, #end Int64, Float>(values:Array<T>):T {
var max:T = values[0];
public overload extern static inline function max<T:Int, #if !cs UInt, #end Int64>(values:Array<T>):T {
var max = values[0];
for (i in 0...values.length) {
if (values[i] > max) max = values[i];
}
return max;
}

public overload extern static inline function max(values:Array<Float>):Float {
var max = values[0];
for (i in 0...values.length) {
if (values[i] > max) max = values[i];
}
return max;
}

public overload extern static inline function max<T>(values:Array<T>, valueFunction:T->Float):T {
var max = values[0];
var maxValue = valueFunction(max);
for (i in 0...values.length) {
if (values[i] > max)
var currentValue = valueFunction(values[i]);
if (currentValue > maxValue) {
max = values[i];
maxValue = currentValue;
}
}

return max;
}

public static inline function average<T:Int, #if !cs Uint, #end Int64, Float>(values:Array<T>):StdTypes.Float {
public overload extern static inline function average<T:Int, #if !cs UInt, #end Int64>(values:Array<T>):Float {
var sum = 0.;
for (v in values) {
sum += v;
}
return sum / values.length;
}


public overload extern static inline function average(values:Array<Float>):Float {
var sum = 0.;
for (v in values) {
sum += v;
}
return sum / values.length;
}

public overload extern static inline function average<T>(values:Array<T>, valueFunction:T->Float):Float {
var sum = 0.;
for (v in values) {
sum += valueFunction(v);
}
return sum / values.length;
}

/**
Gets the median of the given values. For large arrays, Radix sort is used to boost performance (5000 elements or above)
Gets the median of the given values. For large arrays, Radix sort is used to boost performance (5000 elements or above)
**/
extern overload public static inline function median<T:Int, #if !cs UInt, #end Int64>(values:Array<T>):T {
if (values.length > 5000) {
return Radix.sort(values.copy())[floor(values.length / 2)];
}
var s = values.copy();
ArraySort.sort(s , (a, b) -> a - b);
ArraySort.sort(s, (a, b) -> a - b);
return s[floor(values.length / 2)];
}

/**
Gets the median of the given values.
Gets the median of the given values.
**/
extern overload public static inline function median(values:Array<Float>) {
var s = values.copy();
ArraySort.sort(s , (a, b) -> Std.int(a - b));
ArraySort.sort(s, (a, b) -> Std.int(a - b));
return s[floor(values.length / 2)];
}
}

}

0 comments on commit 9b1de4c

Please sign in to comment.