Skip to content

Commit

Permalink
[ISSUE-346] Fix L2 loss problem for comp model types
Browse files Browse the repository at this point in the history
  • Loading branch information
paynie committed May 18, 2018
1 parent 8fc799c commit 22df095
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ public TIntDoubleVector elemUpdate(IntDoubleElemUpdater updater, ElemUpdateParam
ElementUpdateOp
op = new ElementUpdateOp(vectors, 0, splitNum, updater, param);
MatrixOpExecutors.execute(op);
op.join();
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ public PartitionKey[] getPartKeys() {
public TLongDoubleVector elemUpdate(LongDoubleElemUpdater updater, ElemUpdateParam param) {
ElementUpdateOp op = new ElementUpdateOp(vectors, 0, splitNum, updater, param);
MatrixOpExecutors.execute(op);
op.join();
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ object GradientDescent {
if (loss.isL2Reg) {
// for l2
L2Loss(loss, w, grad)

wM.increment(grad.timesBy(-1.0 * lr).asInstanceOf[TDoubleVector])
wM.syncClock()
} else if (loss.isL1Reg) {
Expand Down Expand Up @@ -192,7 +191,11 @@ object GradientDescent {

class L2Updater extends IntDoubleElemUpdater {
override def action(index: Int, value: Double, param: ElemUpdateParam): Double = {
value + param.asInstanceOf[L2UpdateParam].getW.get(index) * param.asInstanceOf[L2UpdateParam].getLossRegParam
if(Math.abs(value) > 10e-7) {
value + param.asInstanceOf[L2UpdateParam].getW.get(index) * param.asInstanceOf[L2UpdateParam].getLossRegParam
} else {
value
}
}
}
compSparse.elemUpdate(new L2Updater, new L2UpdateParam(loss.getRegParam, w))
Expand Down Expand Up @@ -221,7 +224,11 @@ object GradientDescent {

class L2Updater extends LongDoubleElemUpdater {
override def action(index: Long, value: Double, param: ElemUpdateParam): Double = {
value + param.asInstanceOf[L2UpdateParam].getW.get(index) * param.asInstanceOf[L2UpdateParam].getLossRegParam
if(Math.abs(value) > 10e-7) {
value + param.asInstanceOf[L2UpdateParam].getW.get(index) * param.asInstanceOf[L2UpdateParam].getLossRegParam
} else {
value
}
}
}
compLong.elemUpdate(new L2Updater, new L2UpdateParam(loss.getRegParam, w))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ case class LibSVMDataParser(splitor: String, featRange: Long, negY: Boolean, has
val len = splits.length

val x = rowType match {
case RowType.T_DOUBLE_DENSE | RowType.T_DOUBLE_SPARSE =>
case RowType.T_DOUBLE_DENSE | RowType.T_DOUBLE_SPARSE | RowType.T_DOUBLE_SPARSE_COMPONENT=>
val keys: Array[Int] = new Array[Int](len)
val vals: Array[Double] = new Array[Double](len)

Expand Down

0 comments on commit 22df095

Please sign in to comment.