Skip to content

Commit

Permalink
Add global average pool, global max pool ops & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EmergentOrder committed Apr 19, 2021
1 parent 1099e0e commit f61c4e2
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
3 changes: 3 additions & 0 deletions ONNXScala/src/main/scala/ndscala/ONNXScalaOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ given NDArrayOps[Tensor] with {
extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, S1 <: Dimension #: Dimension #: SNil, PadsBefore <: None.type | Dimension #: Dimension #: SNil, PadsAfter <: None.type | Dimension #: Dimension #: SNil] (arr: Tensor[DType, (Tt,Td,S)]) def averagePool(kernelShape: S1, padsBefore: PadsBefore = None, padsAfter: PadsAfter = None)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter]], s1: ShapeOf[S1]): Tensor[DType, (Tt,Td,PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter])] = onnx.AveragePoolV10("avgpool", X = arr, padsBefore = padsBefore, padsAfter = padsAfter, kernel_shape = kernelShape)
extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, S1 <: Dimension #: Dimension #: SNil, PadsBefore <: None.type | Dimension #: Dimension #: SNil, PadsAfter <: None.type | Dimension #: Dimension #: SNil] (arr: Tensor[DType, (Tt,Td,S)]) def maxPool(kernelShape: S1, padsBefore: PadsBefore = None, padsAfter: PadsAfter = None)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter]], s1: ShapeOf[S1]): Tensor[DType, (Tt,Td,PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter])] = onnx.MaxPoolV10("maxpool", X = arr, padsBefore = padsBefore, padsAfter = padsAfter, kernel_shape = kernelShape)

extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil] (arr: Tensor[DType, (Tt,Td,S)]) def globalAveragePool()(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[N #: C #: 1 #: 1 #: SNil]): Tensor[DType, (Tt,Td,N #: C #: 1 #: 1 #: SNil)] = onnx.GlobalAveragePoolV1("globalavgpool", X = arr)
extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil] (arr: Tensor[DType, (Tt,Td,S)]) def globalMaxPool()(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[N #: C #: 1 #: 1 #: SNil]): Tensor[DType, (Tt,Td,N #: C #: 1 #: 1 #: SNil)] = onnx.GlobalMaxPoolV1("globalmaxpool", X = arr)

extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def reciprocal()(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S]): Tensor[DType, (Tt,Td,S)] = onnx.ReciprocalV6("reciprocal", arr)


Expand Down
16 changes: 16 additions & 0 deletions ONNXScala/src/test/scala/ndscala/ONNXScalaNDArraySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,14 @@ type TD = "TensorShapeDenotation" ##: TSNil
doAssert((arr.lrn(size=3)) ==== expectedResult)
}

"Tensor" should "global average pool" in {
//NCHW tensor, 3 channels, 1 pixel
val arr = Tensor(Array(-1.0f, 0.0f, 1.0f, 2.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 4 #: SNil)
val expectedResult = Tensor(Array(0.5f),"TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 1 #: SNil)
val result: Tensor[Float, Tuple3["TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 1 #: SNil]] = arr.globalAveragePool()
doAssert((result) ==== expectedResult)
}

"Tensor" should "average pool" in {
//NCHW tensor, 3 channels, 1 pixel
val arr = Tensor(Array(-1.0f, 0.0f, 1.0f, 2.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 4 #: SNil)
Expand All @@ -293,6 +301,14 @@ type TD = "TensorShapeDenotation" ##: TSNil
doAssert((result) ==== expectedResult)
}

"Tensor" should "global max pool" in {
//NCHW tensor, 3 channels, 1 pixel
val arr = Tensor(Array(-1.0f, 0.0f, 1.0f, 2.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 4 #: SNil)
val expectedResult = Tensor(Array(2.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 1 #: SNil)
val result: Tensor[Float, Tuple3["TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 1 #: SNil]] = arr.globalMaxPool()
doAssert((result) ==== expectedResult)
}

"Tensor" should "max pool" in {
//NCHW tensor, 3 channels, 1 pixel
val arr = Tensor(Array(-1.0f, 0.0f, 1.0f, 2.0f),"TensorTypeDenotation", "TensorShapeDenotation" ##: "TensorShapeDenotation" ##: TSNil, 1 #: 1 #: 1 #: 4 #: SNil)
Expand Down
5 changes: 4 additions & 1 deletion core/src/main/scala/ndscala/NDArrayOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ trait NDArrayOps[SomeNDArray[_ <: AllSupported, _ <: Axes]] {

extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, S1 <: Dimension #: Dimension #: SNil, PadsBefore <: None.type | Dimension #: Dimension #: SNil, PadsAfter <: None.type | Dimension #: Dimension #: SNil] (arr: SomeNDArray[DType, (Tt,Td,S)]) def averagePool(kernelShape: S1, padsBefore: PadsBefore = None, padsAfter: PadsAfter = None)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter]], s1: ShapeOf[S1]): SomeNDArray[DType, (Tt,Td,PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter])]
extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Dimension #: Dimension #: Dimension #: Dimension #: SNil, S1 <: Dimension #: Dimension #: SNil, PadsBefore <: None.type | Dimension #: Dimension #: SNil, PadsAfter <: None.type | Dimension #: Dimension #: SNil] (arr: SomeNDArray[DType, (Tt,Td,S)]) def maxPool(kernelShape: S1, padsBefore: PadsBefore = None, padsAfter: PadsAfter = None)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter]], s1: ShapeOf[S1]): SomeNDArray[DType, (Tt,Td,PaddedShape[PoolShape[S,S1], PadsBefore, PadsAfter])]


extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil] (arr: SomeNDArray[DType, (Tt,Td,S)]) def globalAveragePool()(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[N #: C #: 1 #: 1 #: SNil]): SomeNDArray[DType, (Tt,Td,N #: C #: 1 #: 1 #: SNil)]
extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, N <: Dimension, C <: Dimension, H <: Dimension, W <: Dimension, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: N #: C #: H #: W #: SNil] (arr: SomeNDArray[DType, (Tt,Td,S)]) def globalMaxPool()(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[N #: C #: 1 #: 1 #: SNil]): SomeNDArray[DType, (Tt,Td,N #: C #: 1 #: 1 #: SNil)]

extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: SomeNDArray[DType, (Tt,Td,S)]) def inverse()(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S]): SomeNDArray[DType, (Tt,Td,S)]
extension[DType <: FloatSupported : ClassTag: Numeric : IsFloatSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: SomeNDArray[DType, (Tt,Td,S)]) def constant()(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S]): SomeNDArray[DType, (Tt,Td,S)]

Expand Down

0 comments on commit f61c4e2

Please sign in to comment.