Skip to content

Commit b9e0c93

Browse files
mengxr authored and pwendell committed
[SPARK-1434] [MLLIB] change labelParser from anonymous function to trait
This is a patch to address @mateiz's comment in apache#245.

MLUtils#loadLibSVMData uses an anonymous function for the label parser. Java users won't like it. So I make a trait for LabelParser and provide two implementations: binary and multiclass.

Author: Xiangrui Meng <[email protected]>

Closes apache#345 from mengxr/label-parser and squashes the following commits:

ac44409 [Xiangrui Meng] use singleton objects for label parsers
3b1a7c6 [Xiangrui Meng] add tests for label parsers
c2e571c [Xiangrui Meng] rename LabelParser.apply to LabelParser.parse; use extends for singleton
11c94e0 [Xiangrui Meng] add return types
7f8eb36 [Xiangrui Meng] change labelParser from anonymous function to trait
1 parent ce8ec54 commit b9e0c93

File tree

4 files changed

+97
-25
lines changed

4 files changed

+97
-25
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.util
19+
20+
/** Serializable trait for label parsers, which turn a label string into a double label. */
21+
trait LabelParser extends Serializable {
22+
/** Parses a string label into a double label; each implementation defines the mapping. */
23+
def parse(labelString: String): Double
24+
}
25+
26+
/**
27+
* Label parser for binary labels, which outputs 1.0 (positive) if the parsed value is greater than
28+
* 0.5, or 0.0 (negative) otherwise, so it handles both +1/-1 labeling and +1/0 labeling.
29+
*/
30+
object BinaryLabelParser extends LabelParser {
31+
/** Returns this singleton object, typed as LabelParser. */
32+
def getInstance(): LabelParser = this
33+
34+
/**
35+
* Converts labelString with toDouble and returns 1.0 (positive) if the result is greater than 0.5,
36+
* or 0.0 (negative) otherwise; a non-numeric string makes toDouble throw NumberFormatException.
37+
*/
38+
override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
39+
}
40+
41+
/**
42+
* Label parser for multiclass labels, which converts the input label string to a double as-is.
43+
*/
44+
object MulticlassLabelParser extends LabelParser {
45+
/** Returns this singleton object, typed as LabelParser. */
46+
def getInstance(): LabelParser = this
47+
48+
override def parse(labelString: String): Double = labelString.toDouble // value kept unchanged; throws NumberFormatException if not numeric
49+
}

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

+5-23
Original file line number | Diff line number | Diff line change
@@ -38,17 +38,6 @@ object MLUtils {
3838
eps
3939
}
4040

41-
/**
42-
* Multiclass label parser, which parses a string into double.
43-
*/
44-
val multiclassLabelParser: String => Double = _.toDouble
45-
46-
/**
47-
* Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5,
48-
* or 0.0 (negative) otherwise.
49-
*/
50-
val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0
51-
5241
/**
5342
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
5443
* The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
@@ -69,7 +58,7 @@ object MLUtils {
6958
def loadLibSVMData(
7059
sc: SparkContext,
7160
path: String,
72-
labelParser: String => Double,
61+
labelParser: LabelParser,
7362
numFeatures: Int,
7463
minSplits: Int): RDD[LabeledPoint] = {
7564
val parsed = sc.textFile(path, minSplits)
@@ -89,7 +78,7 @@ object MLUtils {
8978
}.reduce(math.max)
9079
}
9180
parsed.map { items =>
92-
val label = labelParser(items.head)
81+
val label = labelParser.parse(items.head)
9382
val (indices, values) = items.tail.map { item =>
9483
val indexAndValue = item.split(':')
9584
val index = indexAndValue(0).toInt - 1
@@ -107,14 +96,7 @@ object MLUtils {
10796
* with number of features determined automatically and the default number of partitions.
10897
*/
10998
def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
110-
loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits)
111-
112-
/**
113-
* Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
114-
* with number of features specified explicitly and the default number of partitions.
115-
*/
116-
def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] =
117-
loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits)
99+
loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinSplits)
118100

119101
/**
120102
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
@@ -124,7 +106,7 @@ object MLUtils {
124106
def loadLibSVMData(
125107
sc: SparkContext,
126108
path: String,
127-
labelParser: String => Double): RDD[LabeledPoint] =
109+
labelParser: LabelParser): RDD[LabeledPoint] =
128110
loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)
129111

130112
/**
@@ -135,7 +117,7 @@ object MLUtils {
135117
def loadLibSVMData(
136118
sc: SparkContext,
137119
path: String,
138-
labelParser: String => Double,
120+
labelParser: LabelParser,
139121
numFeatures: Int): RDD[LabeledPoint] =
140122
loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
141123

Original file line number | Diff line number | Diff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.util
19+
20+
import org.scalatest.FunSuite
21+
22+
class LabelParsersSuite extends FunSuite { // exercises parse() on both label parser singletons
23+
test("binary label parser") {
24+
for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) { // object and its accessor
25+
assert(parser.parse("+1") === 1.0)
26+
assert(parser.parse("1") === 1.0)
27+
assert(parser.parse("0") === 0.0)
28+
assert(parser.parse("-1") === 0.0) // -1 is not greater than 0.5, so it maps to 0.0
29+
}
30+
}
31+
32+
test("multiclass label parser") {
33+
for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) { // object and its accessor
34+
assert(parser.parse("0") === 0.0) // uses === like every other assertion in this suite
35+
assert(parser.parse("+1") === 1.0)
36+
assert(parser.parse("1") === 1.0)
37+
assert(parser.parse("2") === 2.0)
38+
assert(parser.parse("3") === 3.0)
39+
}
40+
}
41+
}

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

+2-2
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
8080
Files.write(lines, file, Charsets.US_ASCII)
8181
val path = tempDir.toURI.toString
8282

83-
val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
83+
val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
8484
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
8585

8686
for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
@@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
9393
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
9494
}
9595

96-
val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
96+
val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect()
9797
assert(multiclassPoints.length === 3)
9898
assert(multiclassPoints(0).label === 1.0)
9999
assert(multiclassPoints(1).label === -1.0)

0 commit comments

Comments
 (0)