Skip to content

Commit

Permalink
[SPARK-7402] [ML] JSON SerDe for standard param types
Browse files Browse the repository at this point in the history
This PR implements the JSON SerDe for the following param types: `Boolean`, `Int`, `Long`, `Float`, `Double`, `String`, `Array[Int]`, `Array[Double]`, and `Array[String]`. The implementation of `Float`, `Double`, and `Array[Double]` are specialized to handle `NaN` and `Inf`s. This will be used in pipeline persistence. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#9090 from mengxr/SPARK-7402.
  • Loading branch information
mengxr committed Oct 13, 2015
1 parent c75f058 commit 2b574f5
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 0 deletions.
169 changes: 169 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import scala.annotation.varargs
import scala.collection.mutable
import scala.collection.JavaConverters._

import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.util.Identifiable

Expand Down Expand Up @@ -80,6 +83,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
/** Creates a param pair with the given value (for Scala). */
def ->(value: T): ParamPair[T] = ParamPair(this, value)

/** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */
def jsonEncode(value: T): String = {
value match {
case x: String =>
compact(render(JString(x)))
case _ =>
throw new NotImplementedError(
"The default jsonEncode only supports string. " +
s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.")
}
}

/** Decodes a param value from JSON. */
def jsonDecode(json: String): T = {
parse(json) match {
case JString(x) =>
x.asInstanceOf[T]
case _ =>
throw new NotImplementedError(
"The default jsonDecode only supports string. " +
s"${this.getClass.getName} must override jsonDecode to support its value type.")
}
}

override final def toString: String = s"${parent}__$name"

override final def hashCode: Int = toString.##
Expand Down Expand Up @@ -198,6 +225,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>

/** Creates a param pair with the given value (for Java). */
override def w(value: Double): ParamPair[Double] = super.w(value)

override def jsonEncode(value: Double): String = {
compact(render(DoubleParam.jValueEncode(value)))
}

override def jsonDecode(json: String): Double = {
DoubleParam.jValueDecode(parse(json))
}
}

private[param] object DoubleParam {
/** Encodes a param value into JValue. */
def jValueEncode(value: Double): JValue = {
value match {
case _ if value.isNaN =>
JString("NaN")
case Double.NegativeInfinity =>
JString("-Inf")
case Double.PositiveInfinity =>
JString("Inf")
case _ =>
JDouble(value)
}
}

/** Decodes a param value from JValue. */
def jValueDecode(jValue: JValue): Double = {
jValue match {
case JString("NaN") =>
Double.NaN
case JString("-Inf") =>
Double.NegativeInfinity
case JString("Inf") =>
Double.PositiveInfinity
case JDouble(x) =>
x
case _ =>
throw new IllegalArgumentException(s"Cannot decode $jValue to Double.")
}
}
}

/**
Expand All @@ -218,6 +285,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea

/** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value)

override def jsonEncode(value: Int): String = {
compact(render(JInt(value)))
}

override def jsonDecode(json: String): Int = {
implicit val formats = DefaultFormats
parse(json).extract[Int]
}
}

/**
Expand All @@ -238,6 +314,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo

/** Creates a param pair with the given value (for Java). */
override def w(value: Float): ParamPair[Float] = super.w(value)

override def jsonEncode(value: Float): String = {
compact(render(FloatParam.jValueEncode(value)))
}

override def jsonDecode(json: String): Float = {
FloatParam.jValueDecode(parse(json))
}
}

private object FloatParam {

/** Encodes a param value into JValue. */
def jValueEncode(value: Float): JValue = {
value match {
case _ if value.isNaN =>
JString("NaN")
case Float.NegativeInfinity =>
JString("-Inf")
case Float.PositiveInfinity =>
JString("Inf")
case _ =>
JDouble(value)
}
}

/** Decodes a param value from JValue. */
def jValueDecode(jValue: JValue): Float = {
jValue match {
case JString("NaN") =>
Float.NaN
case JString("-Inf") =>
Float.NegativeInfinity
case JString("Inf") =>
Float.PositiveInfinity
case JDouble(x) =>
x.toFloat
case _ =>
throw new IllegalArgumentException(s"Cannot decode $jValue to Float.")
}
}
}

/**
Expand All @@ -258,6 +375,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool

/** Creates a param pair with the given value (for Java). */
override def w(value: Long): ParamPair[Long] = super.w(value)

override def jsonEncode(value: Long): String = {
compact(render(JInt(value)))
}

override def jsonDecode(json: String): Long = {
implicit val formats = DefaultFormats
parse(json).extract[Long]
}
}

/**
Expand All @@ -272,6 +398,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV

/** Creates a param pair with the given value (for Java). */
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)

override def jsonEncode(value: Boolean): String = {
compact(render(JBool(value)))
}

override def jsonDecode(json: String): Boolean = {
implicit val formats = DefaultFormats
parse(json).extract[Boolean]
}
}

/**
Expand All @@ -287,6 +422,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array

/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)

override def jsonEncode(value: Array[String]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq))
}

override def jsonDecode(json: String): Array[String] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[String]].toArray
}
}

/**
Expand All @@ -303,6 +448,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
w(value.asScala.map(_.asInstanceOf[Double]).toArray)

override def jsonEncode(value: Array[Double]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq.map(DoubleParam.jValueEncode)))
}

override def jsonDecode(json: String): Array[Double] = {
parse(json) match {
case JArray(values) =>
values.map(DoubleParam.jValueDecode).toArray
case _ =>
throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].")
}
}
}

/**
Expand All @@ -319,6 +478,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
w(value.asScala.map(_.asInstanceOf[Int]).toArray)

override def jsonEncode(value: Array[Int]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq))
}

override def jsonDecode(json: String): Array[Int] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[Int]].toArray
}
}

/**
Expand Down
114 changes: 114 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,120 @@ import org.apache.spark.SparkFunSuite

class ParamsSuite extends SparkFunSuite {

test("json encode/decode") {
val dummy = new Params {
override def copy(extra: ParamMap): Params = defaultCopy(extra)

override val uid: String = "dummy"
}

{ // BooleanParam
val param = new BooleanParam(dummy, "name", "doc")
for (value <- Seq(true, false)) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}

{ // IntParam
val param = new IntParam(dummy, "name", "doc")
for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}

{ // LongParam
val param = new LongParam(dummy, "name", "doc")
for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}

{ // FloatParam
val param = new FloatParam(dummy, "name", "doc")
for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f,
Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) {
val json = param.jsonEncode(value)
val decoded = param.jsonDecode(json)
if (value.isNaN) {
assert(decoded.isNaN)
} else {
assert(decoded === value)
}
}
}

{ // DoubleParam
val param = new DoubleParam(dummy, "name", "doc")
for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0,
Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) {
val json = param.jsonEncode(value)
val decoded = param.jsonDecode(json)
if (value.isNaN) {
assert(decoded.isNaN)
} else {
assert(decoded === value)
}
}
}

{ // StringParam
val param = new Param[String](dummy, "name", "doc")
// Currently we do not support null.
for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}

{ // IntArrayParam
val param = new IntArrayParam(dummy, "name", "doc")
val values: Seq[Array[Int]] = Seq(
Array(),
Array(1),
Array(Int.MinValue, 0, Int.MaxValue))
for (value <- values) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}

{ // DoubleArrayParam
val param = new DoubleArrayParam(dummy, "name", "doc")
val values: Seq[Array[Double]] = Seq(
Array(),
Array(1.0),
Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
for (value <- values) {
val json = param.jsonEncode(value)
val decoded = param.jsonDecode(json)
assert(decoded.length === value.length)
decoded.zip(value).foreach { case (actual, expected) =>
if (expected.isNaN) {
assert(actual.isNaN)
} else {
assert(actual === expected)
}
}
}
}

{ // StringArrayParam
val param = new StringArrayParam(dummy, "name", "doc")
val values: Seq[Array[String]] = Seq(
Array(),
Array(""),
Array("", "1", "abc", "quote\"", "newline\n"))
for (value <- values) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}
}

test("param") {
val solver = new TestParams()
val uid = solver.uid
Expand Down

0 comments on commit 2b574f5

Please sign in to comment.