Commit 855d12a (parent 4575c56)

[SPARK-5539][MLLIB] LDA guide

This is the LDA user guide from jkbradley, with Java and Scala code examples.

Author: Xiangrui Meng <[email protected]>
Author: Joseph K. Bradley <[email protected]>

Closes apache#4465 from mengxr/lda-guide and squashes the following commits:

6dcb7d1 [Xiangrui Meng] update java example in the user guide
76169ff [Xiangrui Meng] update java example
36c3ae2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into lda-guide
c2a1efe [Joseph K. Bradley] Added LDA programming guide, plus Java example (which is in the guide and probably should be removed).

File tree (3 files changed: +215 −1 lines):

* data/mllib/sample_lda_data.txt
* docs/mllib-clustering.md
* examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java

data/mllib/sample_lda_data.txt (+12)

@@ -0,0 +1,12 @@
+1 2 6 0 2 3 1 1 0 0 3
+1 3 0 1 3 0 0 2 0 0 1
+1 4 1 0 0 4 9 0 1 2 0
+2 1 0 3 0 0 5 0 2 3 9
+3 1 1 9 3 0 2 0 0 1 3
+4 2 0 3 4 5 1 1 1 4 0
+2 1 0 3 0 0 5 0 2 2 9
+1 1 1 9 2 1 2 0 0 1 3
+4 4 0 3 4 2 1 3 0 0 0
+2 8 2 0 3 0 2 0 2 7 2
+1 1 1 9 0 2 2 0 0 3 3
+4 1 0 0 4 5 1 3 0 1 0

docs/mllib-clustering.md (+128 −1)

@@ -55,7 +55,7 @@ has the following parameters:
 
 Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm:
 
-* accepts a [Graph](https://spark.apache.org/docs/0.9.2/api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points.
+* accepts a [Graph](api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points.
 * calculates the principal eigenvalue and eigenvector
 * Clusters each of the input points according to their principal eigenvector component value
 
@@ -71,6 +71,35 @@ Example outputs for a dataset inspired by the paper - but with five clusters instead of three
 <!-- Images are downsized intentionally to improve quality on retina displays -->
 </p>
 
+### Latent Dirichlet Allocation (LDA)
+
+[Latent Dirichlet Allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation)
+is a topic model which infers topics from a collection of text documents.
+LDA can be thought of as a clustering algorithm as follows:
+
+* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset.
+* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts.
+* Rather than estimating a clustering using a traditional distance, LDA uses a function based
+on a statistical model of how text documents are generated.
+
+LDA takes in a collection of documents as vectors of word counts.
+It learns a clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm)
+on the likelihood function. After fitting on the documents, LDA provides:
+
+* Topics: Inferred topics, each of which is a probability distribution over terms (words).
+* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics.
+
+LDA takes the following parameters:
+
+* `k`: Number of topics (i.e., cluster centers)
+* `maxIterations`: Limit on the number of iterations of EM used for learning
+* `docConcentration`: Hyperparameter for the prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions.
+* `topicConcentration`: Hyperparameter for the prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions.
+* `checkpointInterval`: If checkpointing is used (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, checkpointing can help reduce shuffle file sizes on disk and help with failure recovery.
+
+*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet
+support prediction on new documents, and it does not have a Python API. These will be added in the future.
+
 ### Examples
 
 #### k-means
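To make the LDA parameter list in the hunk above concrete, here is a minimal configuration sketch (not part of this commit); the setter names mirror the documented parameters, and the chosen values are illustrative assumptions rather than recommendations.

{% highlight scala %}
import org.apache.spark.mllib.clustering.LDA

// Illustrative settings only; see the parameter descriptions above.
val lda = new LDA()
  .setK(10)                    // number of topics (cluster centers)
  .setMaxIterations(20)        // cap on the number of EM iterations
  .setDocConcentration(1.1)    // must be > 1; larger => smoother doc-topic distributions
  .setTopicConcentration(1.1)  // must be > 1; larger => smoother topic-term distributions
  .setCheckpointInterval(10)   // checkpoint every 10 iterations, if checkpointing is enabled
{% endhighlight %}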
@@ -293,6 +322,104 @@ for i in range(2):
 
 </div>
 
+#### Latent Dirichlet Allocation (LDA) Example
+
+In the following example, we load word count vectors representing a corpus of documents.
+We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA)
+to infer three topics from the documents. The number of desired clusters is passed
+to the algorithm. We then output the topics, represented as probability distributions over words.
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.linalg.Vectors
+
+// Load and parse the data
+val data = sc.textFile("data/mllib/sample_lda_data.txt")
+val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
+// Index documents with unique IDs
+val corpus = parsedData.zipWithIndex.map(_.swap).cache()
+
+// Cluster the documents into three topics using LDA
+val ldaModel = new LDA().setK(3).run(corpus)
+
+// Output topics. Each is a distribution over words (matching word count vectors)
+println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):")
+val topics = ldaModel.topicsMatrix
+for (topic <- Range(0, 3)) {
+  print("Topic " + topic + ":")
+  for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)) }
+  println()
+}
+{% endhighlight %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.DistributedLDAModel;
+import org.apache.spark.mllib.clustering.LDA;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.SparkConf;
+
+public class JavaLDAExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("LDA Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/sample_lda_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Vector> parsedData = data.map(
+        new Function<String, Vector>() {
+          public Vector call(String s) {
+            String[] sarray = s.trim().split(" ");
+            double[] values = new double[sarray.length];
+            for (int i = 0; i < sarray.length; i++)
+              values[i] = Double.parseDouble(sarray[i]);
+            return Vectors.dense(values);
+          }
+        }
+    );
+    // Index documents with unique IDs
+    JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(
+        new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
+          public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
+            return doc_id.swap();
+          }
+        }
+    ));
+    corpus.cache();
+
+    // Cluster the documents into three topics using LDA
+    DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
+
+    // Output topics. Each is a distribution over words (matching word count vectors)
+    System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
+        + " words):");
+    Matrix topics = ldaModel.topicsMatrix();
+    for (int topic = 0; topic < 3; topic++) {
+      System.out.print("Topic " + topic + ":");
+      for (int word = 0; word < ldaModel.vocabSize(); word++) {
+        System.out.print(" " + topics.apply(word, topic));
+      }
+      System.out.println();
+    }
+  }
+}
+{% endhighlight %}
+</div>
+
+</div>
+
 In order to run the above application, follow the instructions
 provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
 section of the Spark
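To run the Scala example above as a self-contained application, a build definition along the lines of the quick-start guide is needed. A minimal `build.sbt` sketch follows; the Scala and Spark versions shown are assumptions and should match your deployment.

{% highlight scala %}
// Minimal sbt build definition (illustrative; not part of this commit).
name := "LDA Example"

version := "1.0"

scalaVersion := "2.10.4"

// spark-mllib transitively provides spark-core.
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.3.0"
{% endhighlight %}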
examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java (+75)

@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.DistributedLDAModel;
+import org.apache.spark.mllib.clustering.LDA;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.SparkConf;
+
+public class JavaLDAExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("LDA Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/sample_lda_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Vector> parsedData = data.map(
+        new Function<String, Vector>() {
+          public Vector call(String s) {
+            String[] sarray = s.trim().split(" ");
+            double[] values = new double[sarray.length];
+            for (int i = 0; i < sarray.length; i++)
+              values[i] = Double.parseDouble(sarray[i]);
+            return Vectors.dense(values);
+          }
+        }
+    );
+    // Index documents with unique IDs
+    JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(
+        new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
+          public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
+            return doc_id.swap();
+          }
+        }
+    ));
+    corpus.cache();
+
+    // Cluster the documents into three topics using LDA
+    DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
+
+    // Output topics. Each is a distribution over words (matching word count vectors)
+    System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
+        + " words):");
+    Matrix topics = ldaModel.topicsMatrix();
+    for (int topic = 0; topic < 3; topic++) {
+      System.out.print("Topic " + topic + ":");
+      for (int word = 0; word < ldaModel.vocabSize(); word++) {
+        System.out.print(" " + topics.apply(word, topic));
+      }
+      System.out.println();
+    }
+  }
+}
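Both examples print the full `topicsMatrix`, which is fine for the 11-word sample vocabulary but unwieldy for real corpora. As an alternative sketch (not shown in this commit), `describeTopics` summarizes each topic by its top-weighted terms; assuming the `ldaModel` trained above:

{% highlight scala %}
// Summarize each topic by its five highest-weighted terms.
// describeTopics returns, for each topic, parallel arrays of term indices
// and term weights, sorted in descending order of weight.
val topTopics = ldaModel.describeTopics(5)
topTopics.zipWithIndex.foreach { case ((termIndices, termWeights), topic) =>
  println("Topic " + topic + ":")
  termIndices.zip(termWeights).foreach { case (term, weight) =>
    println("  term " + term + " -> " + weight)
  }
}
{% endhighlight %}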
