Skip to content

Commit 180fd3e

Browse files
BryanCutler authored and srowen committed
[SPARK-16421][EXAMPLES][ML] Improve ML Example Outputs
## What changes were proposed in this pull request?

Improve example outputs to better reflect the functionality that is being presented. This mostly consisted of modifying what was printed at the end of the example, such as calling show() with truncate=False, but sometimes required minor tweaks in the example data to get relevant output. Explicitly set parameters when they are used as part of the example. Fixed Java examples that failed to run because of using old-style MLlib Vectors or problems with the schema. Synced examples between different APIs.

## How was this patch tested?

Ran each example for Scala, Python, and Java and made sure output was legible on a terminal of width 100.

Author: Bryan Cutler <[email protected]>

Closes apache#14308 from BryanCutler/ml-examples-improve-output-SPARK-16260.
1 parent 2460f03 commit 180fd3e

File tree

85 files changed

+427
-2757
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

85 files changed

+427
-2757
lines changed

data/mllib/lr-data/random.data

-1,000
This file was deleted.

data/mllib/lr_data.txt

-1,000
This file was deleted.

data/mllib/sample_tree_data.csv

-569
This file was deleted.

examples/src/main/java/org/apache/spark/examples/JavaPageRank.java

+5
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,11 @@
4545
*
4646
* This is an example implementation for learning how to use Spark. For more conventional use,
4747
* please refer to org.apache.spark.graphx.lib.PageRank
48+
*
49+
* Example Usage:
50+
* <pre>
51+
* bin/run-example JavaPageRank data/mllib/pagerank_data.txt 10
52+
* </pre>
4853
*/
4954
public final class JavaPageRank {
5055
private static final Pattern SPACES = Pattern.compile("\\s+");

examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java

+3-2
Original file line number | Diff line number | Diff line change
@@ -71,8 +71,9 @@ public static void main(String[] args) {
7171
AFTSurvivalRegressionModel model = aft.fit(training);
7272

7373
// Print the coefficients, intercept and scale parameter for AFT survival regression
74-
System.out.println("Coefficients: " + model.coefficients() + " Intercept: "
75-
+ model.intercept() + " Scale: " + model.scale());
74+
System.out.println("Coefficients: " + model.coefficients());
75+
System.out.println("Intercept: " + model.intercept());
76+
System.out.println("Scale: " + model.scale());
7677
model.transform(training).show(false);
7778
// $example off$
7879

examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java

+6-5
Original file line number | Diff line number | Diff line change
@@ -51,17 +51,18 @@ public static void main(String[] args) {
5151
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
5252
});
5353
Dataset<Row> continuousDataFrame = spark.createDataFrame(data, schema);
54+
5455
Binarizer binarizer = new Binarizer()
5556
.setInputCol("feature")
5657
.setOutputCol("binarized_feature")
5758
.setThreshold(0.5);
59+
5860
Dataset<Row> binarizedDataFrame = binarizer.transform(continuousDataFrame);
59-
Dataset<Row> binarizedFeatures = binarizedDataFrame.select("binarized_feature");
60-
for (Row r : binarizedFeatures.collectAsList()) {
61-
Double binarized_value = r.getDouble(0);
62-
System.out.println(binarized_value);
63-
}
61+
62+
System.out.println("Binarizer output with Threshold = " + binarizer.getThreshold());
63+
binarizedDataFrame.show();
6464
// $example off$
65+
6566
spark.stop();
6667
}
6768
}

examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java

+6-1
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,12 @@ public static void main(String[] args) {
4444
double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};
4545

4646
List<Row> data = Arrays.asList(
47+
RowFactory.create(-999.9),
4748
RowFactory.create(-0.5),
4849
RowFactory.create(-0.3),
4950
RowFactory.create(0.0),
50-
RowFactory.create(0.2)
51+
RowFactory.create(0.2),
52+
RowFactory.create(999.9)
5153
);
5254
StructType schema = new StructType(new StructField[]{
5355
new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
@@ -61,8 +63,11 @@ public static void main(String[] args) {
6163

6264
// Transform original data into its bucket index.
6365
Dataset<Row> bucketedData = bucketizer.transform(dataFrame);
66+
67+
System.out.println("Bucketizer output with " + (bucketizer.getSplits().length-1) + " buckets");
6468
bucketedData.show();
6569
// $example off$
70+
6671
spark.stop();
6772
}
6873
}

examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java

+4
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,11 @@ public static void main(String[] args) {
6363
.setOutputCol("selectedFeatures");
6464

6565
Dataset<Row> result = selector.fit(df).transform(df);
66+
67+
System.out.println("ChiSqSelector output with top " + selector.getNumTopFeatures()
68+
+ " features selected");
6669
result.show();
70+
6771
// $example off$
6872
spark.stop();
6973
}

examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ public static void main(String[] args) {
6161
.setInputCol("text")
6262
.setOutputCol("feature");
6363

64-
cvModel.transform(df).show();
64+
cvModel.transform(df).show(false);
6565
// $example off$
6666

6767
spark.stop();

examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,17 @@ public static void main(String[] args) {
5151
new StructField("features", new VectorUDT(), false, Metadata.empty()),
5252
});
5353
Dataset<Row> df = spark.createDataFrame(data, schema);
54+
5455
DCT dct = new DCT()
5556
.setInputCol("features")
5657
.setOutputCol("featuresDCT")
5758
.setInverse(false);
59+
5860
Dataset<Row> dctDf = dct.transform(df);
59-
dctDf.select("featuresDCT").show(3);
61+
62+
dctDf.select("featuresDCT").show(false);
6063
// $example off$
64+
6165
spark.stop();
6266
}
6367
}

examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ public static void main(String[] args) {
5454

5555
// Output the parameters of the mixture model
5656
for (int i = 0; i < model.getK(); i++) {
57-
System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",
58-
model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov());
57+
System.out.printf("Gaussian %d:\nweight=%f\nmu=%s\nsigma=\n%s\n\n",
58+
i, model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov());
5959
}
6060
// $example off$
6161

examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java

+14-1
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
import java.util.Arrays;
2525
import java.util.List;
2626

27+
import org.apache.spark.ml.attribute.Attribute;
2728
import org.apache.spark.ml.feature.IndexToString;
2829
import org.apache.spark.ml.feature.StringIndexer;
2930
import org.apache.spark.ml.feature.StringIndexerModel;
@@ -63,11 +64,23 @@ public static void main(String[] args) {
6364
.fit(df);
6465
Dataset<Row> indexed = indexer.transform(df);
6566

67+
System.out.println("Transformed string column '" + indexer.getInputCol() + "' " +
68+
"to indexed column '" + indexer.getOutputCol() + "'");
69+
indexed.show();
70+
71+
StructField inputColSchema = indexed.schema().apply(indexer.getOutputCol());
72+
System.out.println("StringIndexer will store labels in output column metadata: " +
73+
Attribute.fromStructField(inputColSchema).toString() + "\n");
74+
6675
IndexToString converter = new IndexToString()
6776
.setInputCol("categoryIndex")
6877
.setOutputCol("originalCategory");
6978
Dataset<Row> converted = converter.transform(indexed);
70-
converted.select("id", "originalCategory").show();
79+
80+
System.out.println("Transformed indexed column '" + converter.getInputCol() + "' back to " +
81+
"original string column '" + converter.getOutputCol() + "' using labels in metadata");
82+
converted.select("id", "categoryIndex", "originalCategory").show();
83+
7184
// $example off$
7285
spark.stop();
7386
}

examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java

+2-2
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,8 @@ public static void main(String[] args) {
5050
IsotonicRegression ir = new IsotonicRegression();
5151
IsotonicRegressionModel model = ir.fit(dataset);
5252

53-
System.out.println("Boundaries in increasing order: " + model.boundaries());
54-
System.out.println("Predictions associated with the boundaries: " + model.predictions());
53+
System.out.println("Boundaries in increasing order: " + model.boundaries() + "\n");
54+
System.out.println("Predictions associated with the boundaries: " + model.predictions() + "\n");
5555

5656
// Makes predictions.
5757
model.transform(dataset).show();

examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java

+23-5
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,20 @@
1818
package org.apache.spark.examples.ml;
1919

2020
// $example on$
21+
import java.util.Arrays;
22+
import java.util.List;
23+
2124
import org.apache.spark.ml.feature.MaxAbsScaler;
2225
import org.apache.spark.ml.feature.MaxAbsScalerModel;
26+
import org.apache.spark.ml.linalg.Vectors;
27+
import org.apache.spark.ml.linalg.VectorUDT;
2328
import org.apache.spark.sql.Dataset;
2429
import org.apache.spark.sql.Row;
30+
import org.apache.spark.sql.RowFactory;
31+
import org.apache.spark.sql.types.DataTypes;
32+
import org.apache.spark.sql.types.Metadata;
33+
import org.apache.spark.sql.types.StructField;
34+
import org.apache.spark.sql.types.StructType;
2535
// $example off$
2636
import org.apache.spark.sql.SparkSession;
2737

@@ -34,10 +44,17 @@ public static void main(String[] args) {
3444
.getOrCreate();
3545

3646
// $example on$
37-
Dataset<Row> dataFrame = spark
38-
.read()
39-
.format("libsvm")
40-
.load("data/mllib/sample_libsvm_data.txt");
47+
List<Row> data = Arrays.asList(
48+
RowFactory.create(0, Vectors.dense(1.0, 0.1, -8.0)),
49+
RowFactory.create(1, Vectors.dense(2.0, 1.0, -4.0)),
50+
RowFactory.create(2, Vectors.dense(4.0, 10.0, 8.0))
51+
);
52+
StructType schema = new StructType(new StructField[]{
53+
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
54+
new StructField("features", new VectorUDT(), false, Metadata.empty())
55+
});
56+
Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
57+
4158
MaxAbsScaler scaler = new MaxAbsScaler()
4259
.setInputCol("features")
4360
.setOutputCol("scaledFeatures");
@@ -47,8 +64,9 @@ public static void main(String[] args) {
4764

4865
// rescale each feature to range [-1, 1].
4966
Dataset<Row> scaledData = scalerModel.transform(dataFrame);
50-
scaledData.show();
67+
scaledData.select("features", "scaledFeatures").show();
5168
// $example off$
69+
5270
spark.stop();
5371
}
5472

examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java

+25-5
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,20 @@
2020
import org.apache.spark.sql.SparkSession;
2121

2222
// $example on$
23+
import java.util.Arrays;
24+
import java.util.List;
25+
2326
import org.apache.spark.ml.feature.MinMaxScaler;
2427
import org.apache.spark.ml.feature.MinMaxScalerModel;
28+
import org.apache.spark.ml.linalg.Vectors;
29+
import org.apache.spark.ml.linalg.VectorUDT;
2530
import org.apache.spark.sql.Dataset;
2631
import org.apache.spark.sql.Row;
32+
import org.apache.spark.sql.RowFactory;
33+
import org.apache.spark.sql.types.DataTypes;
34+
import org.apache.spark.sql.types.Metadata;
35+
import org.apache.spark.sql.types.StructField;
36+
import org.apache.spark.sql.types.StructType;
2737
// $example off$
2838

2939
public class JavaMinMaxScalerExample {
@@ -34,10 +44,17 @@ public static void main(String[] args) {
3444
.getOrCreate();
3545

3646
// $example on$
37-
Dataset<Row> dataFrame = spark
38-
.read()
39-
.format("libsvm")
40-
.load("data/mllib/sample_libsvm_data.txt");
47+
List<Row> data = Arrays.asList(
48+
RowFactory.create(0, Vectors.dense(1.0, 0.1, -1.0)),
49+
RowFactory.create(1, Vectors.dense(2.0, 1.1, 1.0)),
50+
RowFactory.create(2, Vectors.dense(3.0, 10.1, 3.0))
51+
);
52+
StructType schema = new StructType(new StructField[]{
53+
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
54+
new StructField("features", new VectorUDT(), false, Metadata.empty())
55+
});
56+
Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
57+
4158
MinMaxScaler scaler = new MinMaxScaler()
4259
.setInputCol("features")
4360
.setOutputCol("scaledFeatures");
@@ -47,8 +64,11 @@ public static void main(String[] args) {
4764

4865
// rescale each feature to range [min, max].
4966
Dataset<Row> scaledData = scalerModel.transform(dataFrame);
50-
scaledData.show();
67+
System.out.println("Features scaled to range: [" + scaler.getMin() + ", "
68+
+ scaler.getMax() + "]");
69+
scaledData.select("features", "scaledFeatures").show();
5170
// $example off$
71+
5272
spark.stop();
5373
}
5474
}

examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java

+7-1
Original file line number | Diff line number | Diff line change
@@ -41,28 +41,34 @@ public static void main(String[] args) {
4141
// Load training data
4242
String path = "data/mllib/sample_multiclass_classification_data.txt";
4343
Dataset<Row> dataFrame = spark.read().format("libsvm").load(path);
44+
4445
// Split the data into train and test
4546
Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
4647
Dataset<Row> train = splits[0];
4748
Dataset<Row> test = splits[1];
49+
4850
// specify layers for the neural network:
4951
// input layer of size 4 (features), two intermediate of size 5 and 4
5052
// and output of size 3 (classes)
5153
int[] layers = new int[] {4, 5, 4, 3};
54+
5255
// create the trainer and set its parameters
5356
MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
5457
.setLayers(layers)
5558
.setBlockSize(128)
5659
.setSeed(1234L)
5760
.setMaxIter(100);
61+
5862
// train the model
5963
MultilayerPerceptronClassificationModel model = trainer.fit(train);
64+
6065
// compute accuracy on the test set
6166
Dataset<Row> result = model.transform(test);
6267
Dataset<Row> predictionAndLabels = result.select("prediction", "label");
6368
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
6469
.setMetricName("accuracy");
65-
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
70+
71+
System.out.println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels));
6672
// $example off$
6773

6874
spark.stop();

examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java

+7-11
Original file line number | Diff line number | Diff line change
@@ -42,29 +42,25 @@ public static void main(String[] args) {
4242

4343
// $example on$
4444
List<Row> data = Arrays.asList(
45-
RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")),
46-
RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")),
47-
RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat"))
45+
RowFactory.create(0, Arrays.asList("Hi", "I", "heard", "about", "Spark")),
46+
RowFactory.create(1, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")),
47+
RowFactory.create(2, Arrays.asList("Logistic", "regression", "models", "are", "neat"))
4848
);
4949

5050
StructType schema = new StructType(new StructField[]{
51-
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
51+
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
5252
new StructField(
5353
"words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
5454
});
5555

5656
Dataset<Row> wordDataFrame = spark.createDataFrame(data, schema);
5757

58-
NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams");
58+
NGram ngramTransformer = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams");
5959

6060
Dataset<Row> ngramDataFrame = ngramTransformer.transform(wordDataFrame);
61-
62-
for (Row r : ngramDataFrame.select("ngrams", "label").takeAsList(3)) {
63-
java.util.List<String> ngrams = r.getList(0);
64-
for (String ngram : ngrams) System.out.print(ngram + " --- ");
65-
System.out.println();
66-
}
61+
ngramDataFrame.select("ngrams").show(false);
6762
// $example off$
63+
6864
spark.stop();
6965
}
7066
}

examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java

+10-3
Original file line number | Diff line number | Diff line change
@@ -48,14 +48,21 @@ public static void main(String[] args) {
4848

4949
// create the trainer and set its parameters
5050
NaiveBayes nb = new NaiveBayes();
51+
5152
// train the model
5253
NaiveBayesModel model = nb.fit(train);
54+
55+
// Select example rows to display.
56+
Dataset<Row> predictions = model.transform(test);
57+
predictions.show();
58+
5359
// compute accuracy on the test set
54-
Dataset<Row> result = model.transform(test);
55-
Dataset<Row> predictionAndLabels = result.select("prediction", "label");
5660
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
61+
.setLabelCol("label")
62+
.setPredictionCol("prediction")
5763
.setMetricName("accuracy");
58-
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
64+
double accuracy = evaluator.evaluate(predictions);
65+
System.out.println("Test set accuracy = " + accuracy);
5966
// $example off$
6067

6168
spark.stop();

0 commit comments

Comments
 (0)