Skip to content

Commit

Permalink
Add an option to turn off data validation, test it.
Browse files Browse the repository at this point in the history
Also changes addIntercept to default to true, making it consistent
with the validateData option.
  • Loading branch information
shivaram committed Aug 26, 2013
1 parent c874625 commit dc06b52
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ class LogisticRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
with Serializable {

Expand All @@ -71,7 +70,7 @@ class LogisticRegressionWithSGD private (
/**
* Construct a LogisticRegression object with default parameters
*/
def this() = this(1.0, 100, 0.0, 1.0, true)
def this() = this(1.0, 100, 0.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
Expand Down Expand Up @@ -108,7 +107,7 @@ object LogisticRegressionWithSGD {
initialWeights: Array[Double])
: LogisticRegressionModel =
{
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input, initialWeights)
}

Expand All @@ -131,7 +130,7 @@ object LogisticRegressionWithSGD {
miniBatchFraction: Double)
: LogisticRegressionModel =
{
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input)
}

Expand Down
9 changes: 4 additions & 5 deletions mllib/src/main/scala/spark/mllib/classification/SVM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ class SVMWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {

val gradient = new HingeGradient()
Expand All @@ -71,7 +70,7 @@ class SVMWithSGD private (
/**
* Construct a SVM object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0, true)
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
new SVMModel(weights, intercept)
Expand Down Expand Up @@ -107,7 +106,7 @@ object SVMWithSGD {
initialWeights: Array[Double])
: SVMModel =
{
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}

Expand All @@ -131,7 +130,7 @@ object SVMWithSGD {
miniBatchFraction: Double)
: SVMModel =
{
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]

val optimizer: Optimizer

protected var addIntercept: Boolean = true

protected var validateData: Boolean = true

/**
* Create a model given the weights and intercept
*/
protected def createModel(weights: Array[Double], intercept: Double): M

protected var addIntercept: Boolean

/**
* Set if the algorithm should add an intercept. Default true.
*/
Expand All @@ -102,6 +104,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}

/**
 * Enable or disable validation of the input data before training.
 * Validation is enabled by default.
 *
 * @param validateData whether `run` should check the input RDD with the
 *                     registered validators before optimizing
 * @return this instance, to allow chained builder-style configuration
 */
def setValidateData(validateData: Boolean): this.type = {
  this.validateData = validateData; this
}

/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
Expand All @@ -119,7 +129,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {

// Check the data properties before running the optimizer
if (!validators.forall(func => func(input))) {
if (validateData && !validators.forall(func => func(input))) {
throw new SparkException("Input validation failed.")
}

Expand Down
9 changes: 4 additions & 5 deletions mllib/src/main/scala/spark/mllib/regression/Lasso.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ class LassoWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {

Expand All @@ -63,7 +62,7 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0, true)
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
new LassoModel(weights, intercept)
Expand Down Expand Up @@ -98,7 +97,7 @@ object LassoWithSGD {
initialWeights: Array[Double])
: LassoModel =
{
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}

Expand All @@ -121,7 +120,7 @@ object LassoWithSGD {
miniBatchFraction: Double)
: LassoModel =
{
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
intercept[spark.SparkException] {
val model = SVMWithSGD.train(testRDDInvalid, 100)
}

// Turning off data validation should not throw an exception
val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}

0 comments on commit dc06b52

Please sign in to comment.