Skip to content

Commit 22b9a87

Browse files
holdenk authored and srowen committed
[SPARK-10299][ML] word2vec should allow users to specify the window size
Currently word2vec has the window size hard-coded at 5, but some users may want a different size (for example, when training on n-gram input or similar). The user request comes from http://stackoverflow.com/questions/32231975/spark-word2vec-window-size . Author: Holden Karau <[email protected]> Author: Holden Karau <[email protected]> Closes apache#8513 from holdenk/SPARK-10299-word2vec-should-allow-users-to-specify-the-window-size.
1 parent 6e1c55e commit 22b9a87

File tree

3 files changed

+65
-4
lines changed

3 files changed

+65
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

+15
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params
4949
/** @group getParam */
5050
def getVectorSize: Int = $(vectorSize)
5151

52+
/**
 * The window size (context words from [-window, window]).
 * Must be positive: a zero or negative window describes no context range.
 * Default: 5
 * @group expertParam
 */
final val windowSize = new IntParam(
  this, "windowSize", "the window size (context words from [-window, window])",
  // Reject non-positive sizes when the param is set instead of failing later during training.
  org.apache.spark.ml.param.ParamValidators.gt(0))
setDefault(windowSize -> 5)
59+
60+
/** @group expertGetParam */
61+
def getWindowSize: Int = $(windowSize)
62+
5263
/**
5364
* Number of partitions for sentences of words.
5465
* Default: 1
@@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
106117
/** @group setParam */
107118
def setVectorSize(value: Int): this.type = set(vectorSize, value)
108119

120+
/** @group expertSetParam */
121+
def setWindowSize(value: Int): this.type = set(windowSize, value)
122+
109123
/** @group setParam */
110124
def setStepSize(value: Double): this.type = set(stepSize, value)
111125

@@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
131145
.setNumPartitions($(numPartitions))
132146
.setSeed($(seed))
133147
.setVectorSize($(vectorSize))
148+
.setWindowSize($(windowSize))
134149
.fit(input)
135150
copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
136151
}

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

+10-1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ class Word2Vec extends Serializable with Logging {
125125
this
126126
}
127127

128+
/**
 * Sets the window size: the number of context words considered on each side of a word,
 * i.e. positions [-window, window] (default: 5).
 *
 * @param window the context window size; must be positive
 * @return this instance, so setter calls can be chained
 * @throws IllegalArgumentException if `window` is not positive
 */
@Since("1.6.0")
def setWindowSize(window: Int): this.type = {
  // Fail fast on a meaningless window rather than surfacing the problem during training.
  require(window > 0, s"Window of words must be positive but got ${window}")
  this.window = window
  this
}
136+
128137
/**
129138
* Sets minCount, the minimum number of times a token must appear to be included in the word2vec
130139
* model's vocabulary (default: 5).
@@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging {
141150
private val MAX_SENTENCE_LENGTH = 1000
142151

143152
/** context words from [-window, window] */
144-
private val window = 5
153+
private var window = 5
145154

146155
private var trainWordsCount = 0
147156
private var vocabSize = 0

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

+40-3
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
3535
}
3636

3737
test("Word2Vec") {
38-
val sqlContext = new SQLContext(sc)
38+
39+
val sqlContext = this.sqlContext
3940
import sqlContext.implicits._
4041

4142
val sentence = "a b " * 100 + "a c " * 10
@@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
7778

7879
test("getVectors") {
7980

80-
val sqlContext = new SQLContext(sc)
81+
val sqlContext = this.sqlContext
8182
import sqlContext.implicits._
8283

8384
val sentence = "a b " * 100 + "a c " * 10
@@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
118119

119120
test("findSynonyms") {
120121

121-
val sqlContext = new SQLContext(sc)
122+
val sqlContext = this.sqlContext
122123
import sqlContext.implicits._
123124

124125
val sentence = "a b " * 100 + "a c " * 10
@@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
141142
expectedSimilarity.zip(similarity).map {
142143
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
143144
}
145+
}
146+
147+
test("window size") {

  // Reuse the suite's SQLContext; its implicits provide `toDF` on the zipped RDD below.
  val sqlContext = this.sqlContext
  import sqlContext.implicits._

  val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
  val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
  val docDF = doc.zip(doc).toDF("text", "alsotext")

  // Baseline model trained with a small context window.
  val model = new Word2Vec()
    .setVectorSize(3)
    .setWindowSize(2)
    .setInputCol("text")
    .setOutputCol("result")
    .setSeed(42L)
    .fit(docDF)

  val (_, similarity) = model.findSynonyms("a", 6).map {
    case Row(w: String, sim: Double) => (w, sim)
  }.collect().unzip

  // Increase the window size; every other setting (including the seed) is unchanged,
  // so any difference in the results is attributable to the window.
  val biggerModel = new Word2Vec()
    .setVectorSize(3)
    .setInputCol("text")
    .setOutputCol("result")
    .setSeed(42L)
    .setWindowSize(10)
    .fit(docDF)

  // Query the model trained with the larger window. Querying `model` again would compare
  // the baseline with itself and make the assertion below meaningless.
  val (_, similarityLarger) = biggerModel.findSynonyms("a", 6).map {
    case Row(w: String, sim: Double) => (w, sim)
  }.collect().unzip

  // The similarity score should be very different with the larger window.
  // Parenthesise the difference so this is a true relative difference, (a - b) / a.
  assert(math.abs((similarity(5) - similarityLarger(5)) / similarity(5)) > 1E-5)
}
146183

147184
test("Word2Vec read/write") {

0 commit comments

Comments
 (0)