Skip to content

Commit

Permalink
[SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern m…
Browse files Browse the repository at this point in the history
…ining - PrefixSpan

## What changes were proposed in this pull request?

Change `PrefixSpan` into a class with param setter/getters.
This addresses the issues mentioned here:
apache#20973 (comment)

## How was this patch tested?

UT.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: WeichenXu <[email protected]>

Closes apache#21393 from WeichenXu123/fix_prefix_span.
  • Loading branch information
WeichenXu123 authored and mengxr committed May 23, 2018
1 parent a40ffc6 commit df12506
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 36 deletions.
127 changes: 99 additions & 28 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.spark.ml.fpm

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
Expand All @@ -29,68 +31,137 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
* This class is not yet an Estimator/Transformer; use the `findFrequentSequentialPatterns`
* method to run the PrefixSpan algorithm.
*
* @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
* (Wikipedia)</a>
*/
@Since("2.4.0")
@Experimental
object PrefixSpan {
final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {

@Since("2.4.0")
def this() = this(Identifiable.randomUID("prefixSpan"))

/**
* Param for the minimal support level (default: `0.1`).
* Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
* identified as frequent sequential patterns.
* Values are validated with `ParamValidators.gtEq(0.0)`.
* @group param
*/
@Since("2.4.0")
val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
"sequential pattern. Sequential pattern that appears more than " +
// The fragment below must end with a space (not a period) so the concatenated help text
// reads "... (minSupport * size-of-the-dataset) times will be output."
"(minSupport * size-of-the-dataset) " +
"times will be output.", ParamValidators.gtEq(0.0))

/**
* Returns the current value of the `minSupport` param.
* @group getParam
*/
@Since("2.4.0")
def getMinSupport: Double = $(minSupport)

/**
* Sets the `minSupport` param to `value` and returns this instance so calls can be chained.
* @group setParam
*/
@Since("2.4.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)

/**
* Param for the maximal pattern length (default: `10`).
* Values are validated with `ParamValidators.gt(0)`.
* @group param
*/
@Since("2.4.0")
val maxPatternLength = new IntParam(this, "maxPatternLength",
"The maximal length of the sequential pattern.",
ParamValidators.gt(0))

/**
* Returns the current value of the `maxPatternLength` param.
* @group getParam
*/
@Since("2.4.0")
def getMaxPatternLength: Int = $(maxPatternLength)

/**
* Sets the `maxPatternLength` param to `value` and returns this instance so calls can be
* chained.
* @group setParam
*/
@Since("2.4.0")
def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

/**
* Param for the maximum number of items (including delimiters used in the internal storage
* format) allowed in a projected database before local processing (default: `32000000`).
* If a projected database exceeds this size, another iteration of distributed prefix growth
* is run.
* Values are validated with `ParamValidators.gt(0)`.
* @group param
*/
@Since("2.4.0")
val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the internal storage format) " +
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run.",
ParamValidators.gt(0))

/**
* Returns the current value of the `maxLocalProjDBSize` param.
* @group getParam
*/
@Since("2.4.0")
def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)

/**
* Sets the `maxLocalProjDBSize` param to `value` and returns this instance so calls can be
* chained.
* @group setParam
*/
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

/**
* Param for the name of the sequence column in the dataset (default: `"sequence"`).
* Rows with nulls in this column are ignored.
* @group param
*/
@Since("2.4.0")
val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
"dataset, rows with nulls in this column are ignored.")

/**
* Returns the current value of the `sequenceCol` param.
* @group getParam
*/
@Since("2.4.0")
def getSequenceCol: String = $(sequenceCol)

/**
* Sets the `sequenceCol` param to `value` and returns this instance so calls can be chained.
* @group setParam
*/
@Since("2.4.0")
def setSequenceCol(value: String): this.type = set(sequenceCol, value)

// Default values for every param declared above; keep these in sync with the
// "(default: ...)" notes in the corresponding scaladoc.
setDefault(
minSupport -> 0.1,
maxPatternLength -> 10,
// Explicit Long literal: maxLocalProjDBSize is a LongParam.
maxLocalProjDBSize -> 32000000L,
sequenceCol -> "sequence")

/**
* :: Experimental ::
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
*
* @param dataset A dataset or a dataframe containing a sequence column which is
* {{{Seq[Seq[_]]}}} type
* @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
* are ignored
* @param minSupport the minimal support level of the sequential pattern, any pattern that
* appears more than (minSupport * size-of-the-dataset) times will be output
* (recommended value: `0.1`).
* @param maxPatternLength the maximal length of the sequential pattern
* (recommended value: `10`).
* @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
* internal storage format) allowed in a projected database before
* local processing. If a projected database exceeds this size, another
* iteration of distributed prefix growth is run
* (recommended value: `32000000`).
* @return A `DataFrame` that contains columns of sequence and corresponding frequency.
* The schema of it will be:
* - `sequence: Seq[Seq[T]]` (T is the item type)
* - `freq: Long`
*/
@Since("2.4.0")
def findFrequentSequentialPatterns(
dataset: Dataset[_],
sequenceCol: String,
minSupport: Double,
maxPatternLength: Int,
maxLocalProjDBSize: Long): DataFrame = {

val inputType = dataset.schema(sequenceCol).dataType
def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
val sequenceColParam = $(sequenceCol)
val inputType = dataset.schema(sequenceColParam).dataType
require(inputType.isInstanceOf[ArrayType] &&
inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
s"The input column must be ArrayType and the array element type must also be ArrayType, " +
s"but got $inputType.")


val data = dataset.select(sequenceCol)
val sequences = data.where(col(sequenceCol).isNotNull).rdd
val data = dataset.select(sequenceColParam)
val sequences = data.where(col(sequenceColParam).isNotNull).rdd
.map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)

val mllibPrefixSpan = new mllibPrefixSpan()
.setMinSupport(minSupport)
.setMaxPatternLength(maxPatternLength)
.setMaxLocalProjDBSize(maxLocalProjDBSize)
.setMinSupport($(minSupport))
.setMaxPatternLength($(maxPatternLength))
.setMaxLocalProjDBSize($(maxLocalProjDBSize))

val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
val schema = StructType(Seq(
StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false),
StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)

freqSequences
}

/**
* Returns a copy of this `PrefixSpan` instance, delegating to `defaultCopy` with the
* supplied `extra` param map.
*
* @param extra additional param values handed to `defaultCopy`
*/
@Since("2.4.0")
override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)

}
28 changes: 20 additions & 8 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest {

test("PrefixSpan projections with multiple partial starts") {
val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence",
minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(1.0)
.setMaxPatternLength(2)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(smallDataset)
.as[(Seq[Seq[Int]], Long)].collect()
val expected = Array(
(Seq(Seq(1)), 1L),
Expand Down Expand Up @@ -90,17 +93,23 @@ class PrefixSpanSuite extends MLTest {

test("PrefixSpan Integer type, variable-size itemsets") {
val df = smallTestData.toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
}

test("PrefixSpan input row with nulls") {
val df = (smallTestData :+ null).toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[Int]], Long)].collect()

compareResults[Int](smallTestDataExpectedResult, result)
Expand All @@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest {
val df = smallTestData
.map(seq => seq.map(itemSet => itemSet.map(intToString)))
.toDF("sequence")
val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
val result = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
.setMaxLocalProjDBSize(32000000)
.findFrequentSequentialPatterns(df)
.as[(Seq[Seq[String]], Long)].collect()

val expected = smallTestDataExpectedResult.map { case (seq, freq) =>
Expand Down

0 comments on commit df12506

Please sign in to comment.