Skip to content

Commit

Permalink
support uint8
Browse files Browse the repository at this point in the history
  • Loading branch information
scorebot committed Mar 21, 2022
1 parent 0b49b69 commit 9529502
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/main/scala/ai/autodeploy/serving/deploy/OnnxModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,16 @@ class OnnxModel(val session: OrtSession, val env: OrtEnvironment) extends Predic
}


private def convertToTensor(name: String, tensorInfo: TensorInfo, inputValue: Option[Any], inputShape: Option[Seq[Long]] = None): OnnxTensor = inputValue match {
private def convertToTensor(name: String, tensorInfo: TensorInfo, inputValue: Option[Any], inputShape: Option[Seq[_]] = None): OnnxTensor = inputValue match {
case Some(value) => {
import OnnxJavaType._
val expectedShape = tensorInfo.getShape
value match {
case (v, s: Seq[Long]) => {
case (v, s: Seq[_]) => {
convertToTensor(name, tensorInfo, Option(v), Option(s))
}
case buffer: ByteBuffer => {
val shape = inputShape.map(x => x.toArray).getOrElse(expectedShape)
val shape = inputShape.map(x => convertShape(x)).getOrElse(expectedShape)
val convertedShape = if (isDynamicShape(expectedShape)) shape else expectedShape

tensorInfo.`type` match {
Expand All @@ -193,18 +193,21 @@ class OnnxModel(val session: OrtSession, val env: OrtEnvironment) extends Predic
OnnxTensor.createTensor(env, buffer.asLongBuffer(), convertedShape)
}
case BOOL => {
OnnxTensor.createTensor(env, buffer, convertedShape)
???
}
case STRING => {
???
}
case UINT8 => {
OnnxTensor.createTensor(env, buffer, convertedShape, OnnxJavaType.UINT8)
}
case UNKNOWN => {
throw UnknownDataTypeException(name)
}
}
}
case _ => {
val shape = inputShape.map(x => x.toArray).getOrElse(shapeOfValue(value))
val shape = inputShape.map(x => convertShape(x)).getOrElse(shapeOfValue(value))
val count = elementCount(shape)

// The expected shape could contain dynamic axes that take -1
Expand Down Expand Up @@ -248,6 +251,10 @@ class OnnxModel(val session: OrtSession, val env: OrtEnvironment) extends Predic
val data = copyToBuffer[String](intCount, value)
OnnxTensor.createTensor(env, data, convertedShape)
}
case UINT8 => {
val data = copyToBuffer[Byte](intCount, value)
OnnxTensor.createTensor(env, ByteBuffer.wrap(data), convertedShape, OnnxJavaType.UINT8)
}
case UNKNOWN => {
throw UnknownDataTypeException(name)
}
Expand Down Expand Up @@ -290,6 +297,18 @@ class OnnxModel(val session: OrtSession, val env: OrtEnvironment) extends Predic
pos + 1
}
}

private def convertShape(shape: Seq[_]): Array[Long] = {
val result = Array.ofDim[Long](shape.length)
shape.zipWithIndex.foreach(x =>result(x._2) = {
x._1 match {
case element: Long => element
case element: Number => element.longValue()
case _ => x.toString().toLong
}
})
result
}
}


Expand Down

0 comments on commit 9529502

Please sign in to comment.