Skip to content

Commit

Permalink
fixbug in ftlr
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzwang committed Dec 20, 2018
1 parent 6db4f25 commit 90b872a
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ public static Matrix iftrldelta(Matrix m1, Matrix m2, double alpha) {
}

public static Matrix ftrldelta(Matrix m1, Matrix m2, double alpha) {
return BinaryMatrixExecutor.apply(m1, false, m2, false, new FtrlDelta(true, alpha));
return BinaryMatrixExecutor.apply(m1, false, m2, false, new FtrlDelta(false, alpha));
}

public static Vector iftrldelta(Vector v1, Vector v2, double alpha) {
return BinaryExecutor.apply(v1, v2, new FtrlDelta(true, alpha));
}

public static Vector ftrldelta(Vector v1, Vector v2, double alpha) {
return BinaryExecutor.apply(v1, v2, new FtrlDelta(true, alpha));
return BinaryExecutor.apply(v1, v2, new FtrlDelta(false, alpha));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package com.tencent.angel.ml.math2.matrix.BlasMatrixTest

import com.tencent.angel.ml.math2.MFactory
import com.tencent.angel.ml.math2.matrix.{BlasDoubleMatrix, BlasFloatMatrix}
import com.tencent.angel.ml.math2.ufuncs.Ufuncs
import org.scalatest.FunSuite

class BalsTest extends FunSuite {
val data1 = Array[Double](
1.3, 2.7, 3.2, 5.1,
3.0, 8.0, 9.5, 4.7,
2.6, 8.3, 5.5, 8.9)

val data2 = Array[Double](
1.2, 5.7,
3.1, 4.1,
3.8, 3.5,
8.4, 3.6)

val data1f = Array[Float](
1.3f, 2.7f, 3.2f, 5.1f,
3.0f, 8.0f, 9.5f, 4.7f,
2.6f, 8.3f, 5.5f, 8.9f)

val data2f = Array[Float](
1.2f, 5.7f,
3.1f, 4.1f,
3.8f, 3.5f,
8.4f, 3.6f)

test("Double-NN") {
val mat1 = MFactory.denseDoubleMatrix(3, 4, data1)
val mat2 = MFactory.denseDoubleMatrix(4, 2, data2)
val res = Ufuncs.dot(mat1, false, mat2, false).asInstanceOf[BlasDoubleMatrix]

(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0
(0 until 4).foreach { k =>
tmp += data1(rId * 4 + k) * data2(2 * k + cId)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}

}

test("Double-NT") {
val mat1 = MFactory.denseDoubleMatrix(3, 4, data1)
val mat2 = MFactory.denseDoubleMatrix(2, 4, data2)
val res = Ufuncs.dot(mat1, false, mat2, true).asInstanceOf[BlasDoubleMatrix]


(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0
(0 until 4).foreach { k =>
tmp += data1(rId * 4 + k) * data2(cId * 4 + k)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}
}

test("Double-TN") {
val mat1 = MFactory.denseDoubleMatrix(4, 3, data1)
val mat2 = MFactory.denseDoubleMatrix(4, 2, data2)
val res = Ufuncs.dot(mat1, true, mat2, false).asInstanceOf[BlasDoubleMatrix]

(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0
(0 until 4).foreach { k =>
tmp += data1(k * 3 + rId) * data2(2 * k + cId)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}
}

test("Double-TT") {
val mat1 = MFactory.denseDoubleMatrix(4, 3, data1)
val mat2 = MFactory.denseDoubleMatrix(2, 4, data2)
val res = Ufuncs.dot(mat1, true, mat2, true).asInstanceOf[BlasDoubleMatrix]

(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0
(0 until 4).foreach { k =>
tmp += data1(k * 3 + rId) * data2(cId * 4 + k)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}
}

test("Float-NN") {
val mat1 = MFactory.denseFloatMatrix(3, 4, data1f)
val mat2 = MFactory.denseFloatMatrix(4, 2, data2f)
val res = Ufuncs.dot(mat1, false, mat2, false).asInstanceOf[BlasFloatMatrix]

(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0f
(0 until 4).foreach { k =>
tmp += data1f(rId * 4 + k) * data2f(2 * k + cId)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}

}

test("Float-NT") {
val mat1 = MFactory.denseFloatMatrix(3, 4, data1f)
val mat2 = MFactory.denseFloatMatrix(2, 4, data2f)
val res = Ufuncs.dot(mat1, false, mat2, true).asInstanceOf[BlasFloatMatrix]


(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0f
(0 until 4).foreach { k =>
tmp += data1f(rId * 4 + k) * data2f(cId * 4 + k)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}
}

test("Float-TN") {
val mat1 = MFactory.denseFloatMatrix(4, 3, data1f)
val mat2 = MFactory.denseFloatMatrix(4, 2, data2f)
val res = Ufuncs.dot(mat1, true, mat2, false).asInstanceOf[BlasFloatMatrix]

(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0f
(0 until 4).foreach { k =>
tmp += data1f(k * 3 + rId) * data2f(2 * k + cId)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}
}

test("Float-TT") {
val mat1 = MFactory.denseFloatMatrix(4, 3, data1f)
val mat2 = MFactory.denseFloatMatrix(2, 4, data2f)
val res = Ufuncs.dot(mat1, true, mat2, true).asInstanceOf[BlasFloatMatrix]

(0 until 3).foreach { rId =>
(0 until 2).foreach { cId =>
var tmp = 0.0f
(0 until 4).foreach { k =>
tmp += data1f(k * 3 + rId) * data2f(cId * 4 + k)
}

println(f"$tmp%.2f, ${res.get(rId, cId)}%.2f, ${tmp == res.get(rId, cId)}")
assert(tmp == res.get(rId, cId))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ void update(Vector[] rows, int factor, double[] scalars) {
for (int f = 0; f < factor; f++) {
Vector weight = rows[f];
Vector zModel = rows[f + factor];
Vector nModel = rows[f + 2 * factor];
Vector qModel = rows[f + 2 * factor];
Vector gradient = rows[f + 3 * factor];

Vector delta = OptFuncs.ftrldelta(nModel, gradient, alpha);
Ufuncs.iaxpy2(nModel, gradient, 1);
zModel.iadd(gradient.sub(delta.mul(weight)));
Vector sigma = OptFuncs.ftrldelta(qModel, gradient, alpha);
Ufuncs.iaxpy2(qModel, gradient, 1);
zModel.iadd(gradient.sub(sigma.mul(weight)));

Vector newWeight = Ufuncs.ftrlthreshold(zModel, nModel, alpha, beta, lambda1, lambda2);
Vector newWeight = Ufuncs.ftrlthreshold(zModel, qModel, alpha, beta, lambda1, lambda2);
weight.setStorage(newWeight.getStorage());

gradient.clear();
gradient.imul(0.0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ import com.tencent.angel.ml.math2.vector._
import com.tencent.angel.ml.math2.{MFactory, VFactory}
import it.unimi.dsi.fastutil.ints.IntOpenHashSet
import it.unimi.dsi.fastutil.longs.LongOpenHashSet
import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.{Log, LogFactory}

import scala.util.Sorting.quickSort


class PlaceHolder(val conf: SharedConf) extends Serializable {
val LOG = LogFactory.getLog(classOf[PlaceHolder])
private val LOG: Log = LogFactory.getLog(classOf[PlaceHolder])

def this() = this(SharedConf.get())

Expand Down

0 comments on commit 90b872a

Please sign in to comment.