Skip to content

Commit

Permalink
[SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication
Browse files Browse the repository at this point in the history

Fixes the problem and verifies the fix with a test suite.
Also adds an optional parameter, nullable (Boolean), to SchemaUtils.appendColumn
and deduplicates the SchemaUtils.appendColumn functions.

Author: Grzegorz Chilkiewicz <[email protected]>

Closes apache#10741 from grzegorz-chilkiewicz/master.
  • Loading branch information
grzegorz-chilkiewicz authored and jkbradley committed Feb 2, 2016
1 parent 358300c commit b1835d7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String)
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
val outputFields = schema.fields :+
StructField($(outputCol), inputType, schema($(inputCol)).nullable)
StructType(outputFields)
SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
}

override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,10 @@ private[spark] object SchemaUtils {
// NOTE: this is a rendered diff hunk, so lines removed by the commit and lines added by it are
// interleaved below. The trailing markers are derived from the hunk header (-71,12 +71,10).
// After the commit, appendColumn accepts an optional `nullable` flag (default false), returns
// the schema unchanged when the column name is empty, and otherwise delegates to the
// appendColumn(schema, StructField) overload.
// NOTE(review): that overload is not visible in this hunk; presumably it performs the
// "already exists" check that the removed lines did inline - confirm.
def appendColumn(
schema: StructType,
colName: String,
dataType: DataType): StructType = { // removed by this commit
dataType: DataType, // added by this commit
nullable: Boolean = false): StructType = { // added by this commit
if (colName.isEmpty) return schema
val fieldNames = schema.fieldNames // removed by this commit
require(!fieldNames.contains(colName), s"Column $colName already exists.") // removed by this commit
val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) // removed by this commit
StructType(outputFields) // removed by this commit
appendColumn(schema, StructField(colName, dataType, nullable)) // added by this commit
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,19 @@ class StopWordsRemoverSuite
.setCaseSensitive(true)
testDefaultReadWrite(t)
}

// Regression check: the requested output column name is already used by a column of the
// input DataFrame, so running the remover must fail with an IllegalArgumentException that
// carries the "already exists" message.
test("StopWordsRemover output column already exists") {
  val clashingCol = "expected"
  val remover = new StopWordsRemover().setInputCol("raw").setOutputCol(clashingCol)
  // Single row; its second column is named exactly like the remover's output column.
  val rows = Seq((Seq("The", "the", "swift"), Seq("swift")))
  val dataSet = sqlContext.createDataFrame(rows).toDF("raw", clashingCol)

  val expectedMessage = s"requirement failed: Column $clashingCol already exists."
  val thrown = intercept[IllegalArgumentException] {
    testStopWordsRemover(remover, dataSet)
  }
  assert(thrown.getMessage == expectedMessage)
}
}

0 comments on commit b1835d7

Please sign in to comment.