@@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
35
35
}
36
36
37
37
test(" Word2Vec" ) {
38
- val sqlContext = new SQLContext (sc)
38
+
39
+ val sqlContext = this .sqlContext
39
40
import sqlContext .implicits ._
40
41
41
42
val sentence = " a b " * 100 + " a c " * 10
@@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
77
78
78
79
test(" getVectors" ) {
79
80
80
- val sqlContext = new SQLContext (sc)
81
+ val sqlContext = this .sqlContext
81
82
import sqlContext .implicits ._
82
83
83
84
val sentence = " a b " * 100 + " a c " * 10
@@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
118
119
119
120
test(" findSynonyms" ) {
120
121
121
- val sqlContext = new SQLContext (sc)
122
+ val sqlContext = this .sqlContext
122
123
import sqlContext .implicits ._
123
124
124
125
val sentence = " a b " * 100 + " a c " * 10
@@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
141
142
expectedSimilarity.zip(similarity).map {
142
143
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5 )
143
144
}
145
+ }
146
+
147
test("window size") {
  // Verifies that the window size actually influences training: two models are fit on the
  // same corpus with the same seed and vector size, differing only in window size, and the
  // similarity reported for the lowest-ranked synonym of "a" must differ between them.

  val sqlContext = this.sqlContext
  import sqlContext.implicits._

  val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
  val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
  val docDF = doc.zip(doc).toDF("text", "alsotext")

  // Baseline model trained with a small context window.
  val model = new Word2Vec()
    .setVectorSize(3)
    .setWindowSize(2)
    .setInputCol("text")
    .setOutputCol("result")
    .setSeed(42L)
    .fit(docDF)

  // Only the similarity scores are compared, so the synonym words are discarded.
  val similarity = model.findSynonyms("a", 6).map {
    case Row(_: String, sim: Double) => sim
  }.collect()

  // Increase the window size
  val biggerModel = new Word2Vec()
    .setVectorSize(3)
    .setInputCol("text")
    .setOutputCol("result")
    .setSeed(42L)
    .setWindowSize(10)
    .fit(docDF)

  // Query the model trained with the larger window; querying `model` again would compare
  // the baseline against itself and make the check below meaningless.
  val similarityLarger = biggerModel.findSynonyms("a", 6).map {
    case Row(_: String, sim: Double) => sim
  }.collect()

  // The similarity score should be very different with the larger window.
  // The subtraction is parenthesized so that this is a relative difference, (x - y) / x,
  // rather than x - (y / x).
  assert(math.abs((similarity(5) - similarityLarger(5)) / similarity(5)) > 1E-5)
}
146
183
147
184
test(" Word2Vec read/write" ) {
0 commit comments