Skip to content

Commit c3527a3

Browse files
holdenk authored and pwendell committed
SPARK-1310: Start adding k-fold cross validation to MLLib [adds kFold to MLUtils & fixes bug in BernoulliSampler]
Author: Holden Karau <[email protected]> Closes apache#18 from holdenk/addkfoldcrossvalidation and squashes the following commits: 208db9b [Holden Karau] Fix a bad space e84f2fc [Holden Karau] Fix the test, we should be looking at the second element instead 6ddbf05 [Holden Karau] swap training and validation order 7157ae9 [Holden Karau] CR feedback 90896c7 [Holden Karau] New line 150889c [Holden Karau] Fix up error messages in the MLUtilsSuite 2cb90b3 [Holden Karau] Fix the names in kFold c702a96 [Holden Karau] Fix imports in MLUtils e187e35 [Holden Karau] Move { up to same line as whenExecuting(random) in RandomSamplerSuite.scala c5b723f [Holden Karau] clean up 7ebe4d5 [Holden Karau] CR feedback, remove unecessary learners (came back during merge mistake) and insert an empty line bb5fa56 [Holden Karau] extra line sadness 163c5b1 [Holden Karau] code review feedback 1.to -> 1 to and folds -> numFolds 5a33f1d [Holden Karau] Code review follow up. e8741a7 [Holden Karau] CR feedback b78804e [Holden Karau] Remove cross validation [TODO in another pull request] 91eae64 [Holden Karau] Consolidate things in mlutils 264502a [Holden Karau] Add a test for the bug that was found with BernoulliSampler not copying the complement param dd0b737 [Holden Karau] Wrap long lines (oops) c0b7fa4 [Holden Karau] Switch FoldedRDD to use BernoulliSampler and PartitionwiseSampledRDD 08f8e4d [Holden Karau] Fix BernoulliSampler to respect complement a751ec6 [Holden Karau] Add k-fold cross validation to MLLib
1 parent 9edd887 commit c3527a3

File tree

4 files changed

+82
-9
lines changed

4 files changed

+82
-9
lines changed

core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
6969
}
7070
}
7171

72-
override def clone = new BernoulliSampler[T](lb, ub)
72+
/**
73+
* Return a sampler with is the complement of the range specified of the current sampler.
74+
*/
75+
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
76+
77+
override def clone = new BernoulliSampler[T](lb, ub, complement)
7378
}
7479

7580
/**

core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala

+16-8
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,31 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
4141
random.nextDouble().andReturn(x)
4242
}
4343
}
44-
whenExecuting(random)
45-
{
44+
whenExecuting(random) {
4645
val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
4746
assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
4847
}
4948
}
5049

50+
test("BernoulliSamplerWithRangeInverse") {
51+
expecting {
52+
for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
53+
random.nextDouble().andReturn(x)
54+
}
55+
}
56+
whenExecuting(random) {
57+
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
58+
assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
59+
}
60+
}
61+
5162
test("BernoulliSamplerWithRatio") {
5263
expecting {
5364
for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
5465
random.nextDouble().andReturn(x)
5566
}
5667
}
57-
whenExecuting(random)
58-
{
68+
whenExecuting(random) {
5969
val sampler = new BernoulliSampler[Int](0.35)(random)
6070
assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
6171
}
@@ -67,8 +77,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
6777
random.nextDouble().andReturn(x)
6878
}
6979
}
70-
whenExecuting(random)
71-
{
80+
whenExecuting(random) {
7281
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
7382
assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
7483
}
@@ -78,8 +87,7 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
7887
expecting {
7988
random.setSeed(10L)
8089
}
81-
whenExecuting(random)
82-
{
90+
whenExecuting(random) {
8391
val sampler = new BernoulliSampler[Int](0.2)(random)
8492
sampler.setSeed(10L)
8593
}

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

+21
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@
1717

1818
package org.apache.spark.mllib.util
1919

20+
import scala.reflect.ClassTag
21+
2022
import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
2123

2224
import org.apache.spark.annotation.Experimental
2325
import org.apache.spark.SparkContext
2426
import org.apache.spark.rdd.RDD
27+
import org.apache.spark.rdd.PartitionwiseSampledRDD
28+
import org.apache.spark.SparkContext._
29+
import org.apache.spark.util.random.BernoulliSampler
2530
import org.apache.spark.mllib.regression.LabeledPoint
2631
import org.apache.spark.mllib.linalg.Vectors
2732

@@ -157,6 +162,22 @@ object MLUtils {
157162
dataStr.saveAsTextFile(dir)
158163
}
159164

165+
/**
166+
* Return a k element array of pairs of RDDs with the first element of each pair
167+
* containing the training data, a complement of the validation data and the second
168+
* element, the validation data, containing a unique 1/kth of the data. Where k=numFolds.
169+
*/
170+
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
171+
val numFoldsF = numFolds.toFloat
172+
(1 to numFolds).map { fold =>
173+
val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
174+
complement = false)
175+
val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
176+
val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
177+
(training, validation)
178+
}.toArray
179+
}
180+
160181
/**
161182
* Returns the squared Euclidean distance between two vectors. The following formula will be used
162183
* if it does not introduce too much numerical error:

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

+39
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ package org.apache.spark.mllib.util
1919

2020
import java.io.File
2121

22+
import scala.math
23+
import scala.util.Random
24+
2225
import org.scalatest.FunSuite
2326

2427
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
@@ -93,4 +96,40 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
9396
case t: Throwable =>
9497
}
9598
}
99+
100+
test("kFold") {
101+
val data = sc.parallelize(1 to 100, 2)
102+
val collectedData = data.collect().sorted
103+
val twoFoldedRdd = MLUtils.kFold(data, 2, 1)
104+
assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
105+
assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
106+
for (folds <- 2 to 10) {
107+
for (seed <- 1 to 5) {
108+
val foldedRdds = MLUtils.kFold(data, folds, seed)
109+
assert(foldedRdds.size === folds)
110+
foldedRdds.map { case (training, validation) =>
111+
val result = validation.union(training).collect().sorted
112+
val validationSize = validation.collect().size.toFloat
113+
assert(validationSize > 0, "empty validation data")
114+
val p = 1 / folds.toFloat
115+
// Within 3 standard deviations of the mean
116+
val range = 3 * math.sqrt(100 * p * (1 - p))
117+
val expected = 100 * p
118+
val lowerBound = expected - range
119+
val upperBound = expected + range
120+
assert(validationSize > lowerBound,
121+
s"Validation data ($validationSize) smaller than expected ($lowerBound)" )
122+
assert(validationSize < upperBound,
123+
s"Validation data ($validationSize) larger than expected ($upperBound)" )
124+
assert(training.collect().size > 0, "empty training data")
125+
assert(result === collectedData,
126+
"Each training+validation set combined should contain all of the data.")
127+
}
128+
// K fold cross validation should only have each element in the validation set exactly once
129+
assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted ===
130+
data.collect().sorted)
131+
}
132+
}
133+
}
134+
96135
}

0 commit comments

Comments
 (0)