Skip to content

Commit

Permalink
Add tests for reduction ops
Browse files Browse the repository at this point in the history
  • Loading branch information
EmergentOrder committed Apr 12, 2021
1 parent 39b737d commit 017c2b7
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion ONNXScala/src/test/scala/ndscala/ONNXScalaNDArraySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ type TD = "TensorShapeDenotation" ##: TSNil
assertTypeError("arr.reshape[TT, TD, 3 #: 1 #: SNil]()")
}

"Tensor" should "reduceSum wih keepdims on" in {
"Tensor" should "reduceSum with keepdims on" in {
val arr = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult = Tensor(Array(126f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert((arr.reduceSum[TT, 1 ::: INil, true]) ==== expectedResult)
Expand All @@ -140,6 +140,48 @@ type TD = "TensorShapeDenotation" ##: TSNil
doAssert((arr.reduceSum[TT, 1 ::: INil, false]) ==== expectedResult)
}

"Tensor" should "reduceLogSum with keepdims on" in {
val arr = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult = Tensor(Array(4.836282f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert((arr.reduceLogSum[TT, 1 ::: INil, true]) ==== expectedResult)
}

"Tensor" should "reduceMax with keepdims on" in {
val arr = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult = Tensor(Array(84.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert((arr.reduceMax[TT, 1 ::: INil, true]) ==== expectedResult)
}

"Tensor" should "reduceMin with keepdims on" in {
val arr = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult = Tensor(Array(42.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert((arr.reduceMin[TT, 1 ::: INil, true]) ==== expectedResult)
}

"Tensor" should "reduceProd with keepdims on" in {
val arr = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult = Tensor(Array(3528.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert((arr.reduceProd[TT, 1 ::: INil, true]) ==== expectedResult)
}

"Tensor" should "reduceSumSquare with keepdims on" in {
val arr = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult = Tensor(Array(8820.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert((arr.reduceSumSquare[TT, 1 ::: INil, true]) ==== expectedResult)
}

"Tensor" should "argmax" in {
val arr: Tensor[Float, (TT, "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)] = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult: Tensor[Long, (TT, "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)] = Tensor(Array(1),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert(arr.argMax[TT, 1 ::: INil, true] ==== expectedResult)
}

"Tensor" should "argmin" in {
val arr: Tensor[Float, (TT, "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)] = Tensor(Array(42f, 84f),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 2 #: SNil)
val expectedResult: Tensor[Long, (TT, "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)] = Tensor(Array(0),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: SNil)
doAssert(arr.argMin[TT, 1 ::: INil, true] ==== expectedResult)
}

"Tensor" should "transpose" in {
val arr = Tensor(Array(1, 2, 3, 4),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 2 #: 2 #: SNil)
doAssert(arr.transpose ==== Tensor(Array(1, 3, 2, 4),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 2 #: 2 #: SNil))
Expand Down

0 comments on commit 017c2b7

Please sign in to comment.