Skip to content

Commit

Permalink
support dynamic axes
Browse files Browse the repository at this point in the history
  • Loading branch information
scorebot committed Aug 23, 2020
1 parent bfe0c1a commit d20ff1b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 28 deletions.
46 changes: 28 additions & 18 deletions src/main/scala/ai/autodeploy/serving/deploy/OnnxModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,33 +161,39 @@ class OnnxModel(val session: OrtSession, val env: OrtEnvironment) extends Predic
}


private def convertToTensor(name: String, tensorInfo: TensorInfo, inputValue: Option[Any]): OnnxTensor = inputValue match {
private def convertToTensor(name: String, tensorInfo: TensorInfo, inputValue: Option[Any], inputShape: Option[Seq[Long]] = None): OnnxTensor = inputValue match {
case Some(value) => {
import OnnxJavaType._
val expectedShape = tensorInfo.getShape
value match {
case (v, s: Seq[Long]) => {
convertToTensor(name, tensorInfo, Option(v), Option(s))
}
case buffer: ByteBuffer => {
val shape = inputShape.map(x => x.toArray).getOrElse(expectedShape)
val convertedShape = if (isDynamicShape(expectedShape)) shape else expectedShape

tensorInfo.`type` match {
case FLOAT => {
OnnxTensor.createTensor(env, buffer.asFloatBuffer(), expectedShape)
OnnxTensor.createTensor(env, buffer.asFloatBuffer(), convertedShape)
}
case DOUBLE => {
OnnxTensor.createTensor(env, buffer.asDoubleBuffer(), expectedShape)
OnnxTensor.createTensor(env, buffer.asDoubleBuffer(), convertedShape)
}
case INT8 => {
OnnxTensor.createTensor(env, buffer, expectedShape)
OnnxTensor.createTensor(env, buffer, convertedShape)
}
case INT16 => {
OnnxTensor.createTensor(env, buffer.asShortBuffer(), expectedShape)
OnnxTensor.createTensor(env, buffer.asShortBuffer(), convertedShape)
}
case INT32 => {
OnnxTensor.createTensor(env, buffer.asIntBuffer(), expectedShape)
OnnxTensor.createTensor(env, buffer.asIntBuffer(), convertedShape)
}
case INT64 => {
OnnxTensor.createTensor(env, buffer.asLongBuffer(), expectedShape)
OnnxTensor.createTensor(env, buffer.asLongBuffer(), convertedShape)
}
case BOOL => {
OnnxTensor.createTensor(env, buffer, expectedShape)
OnnxTensor.createTensor(env, buffer, convertedShape)
}
case STRING => {
???
Expand All @@ -198,45 +204,49 @@ class OnnxModel(val session: OrtSession, val env: OrtEnvironment) extends Predic
}
}
case _ => {
val shape = shapeOfValue(value)
val shape = inputShape.map(x => x.toArray).getOrElse(shapeOfValue(value))
val count = elementCount(shape)
if (count != elementCount(expectedShape)) {

// The expected shape could contain dynamic axes that take -1
val expectedCount = elementCount(expectedShape)
if (count % expectedCount != 0) {
throw ShapeMismatchException(shape, expectedShape)
}

val convertedShape = if (isDynamicShape(expectedShape)) shape else expectedShape
val intCount = count.toInt
tensorInfo.`type` match {
case FLOAT => {
val data = copyToBuffer[Float](intCount, value)
OnnxTensor.createTensor(env, FloatBuffer.wrap(data), expectedShape)
OnnxTensor.createTensor(env, FloatBuffer.wrap(data), convertedShape)
}
case DOUBLE => {
val data = copyToBuffer[Double](intCount, value)
OnnxTensor.createTensor(env, DoubleBuffer.wrap(data), expectedShape)
OnnxTensor.createTensor(env, DoubleBuffer.wrap(data), convertedShape)
}
case INT8 => {
val data = copyToBuffer[Byte](intCount, value)
OnnxTensor.createTensor(env, ByteBuffer.wrap(data), expectedShape)
OnnxTensor.createTensor(env, ByteBuffer.wrap(data), convertedShape)
}
case INT16 => {
val data = copyToBuffer[Short](intCount, value)
OnnxTensor.createTensor(env, ShortBuffer.wrap(data), expectedShape)
OnnxTensor.createTensor(env, ShortBuffer.wrap(data), convertedShape)
}
case INT32 => {
val data = copyToBuffer[Int](intCount, value)
OnnxTensor.createTensor(env, IntBuffer.wrap(data), expectedShape)
OnnxTensor.createTensor(env, IntBuffer.wrap(data), convertedShape)
}
case INT64 => {
val data = copyToBuffer[Long](intCount, value)
OnnxTensor.createTensor(env, LongBuffer.wrap(data), expectedShape)
OnnxTensor.createTensor(env, LongBuffer.wrap(data), convertedShape)
}
case BOOL => {
val data = copyToBuffer[Boolean](intCount, value)
OnnxTensor.createTensor(env, OrtUtil.reshape(data, expectedShape))
OnnxTensor.createTensor(env, OrtUtil.reshape(data, convertedShape))
}
case STRING => {
val data = copyToBuffer[String](intCount, value)
OnnxTensor.createTensor(env, data, expectedShape)
OnnxTensor.createTensor(env, data, convertedShape)
}
case UNKNOWN => {
throw UnknownDataTypeException(name)
Expand Down
6 changes: 4 additions & 2 deletions src/main/scala/ai/autodeploy/serving/errors/errors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package ai.autodeploy.serving.errors

import java.util.Arrays

class BaseException(val message: String) extends Exception(message)

case class InvalidModelException(modelType: String, reason: String) extends
Expand All @@ -28,7 +30,7 @@ case class ModelTypeNotSupportedException(modelType: Option[String]) extends
BaseException(modelType.map(x => s"Model type '${x}' not supported") getOrElse s"Unknown model type")

case class ShapeMismatchException(actual: Array[Long], expected: Array[Long]) extends
BaseException(s"Shape mismatch: ${expected} expected but ${actual} got")
BaseException(s"Shape mismatch: ${Arrays.toString(expected)} expected but ${Arrays.toString(actual)} got")

case class MissingValueException(name: String, `type`: String, shape: Array[Long]) extends
BaseException(s"Missing value for '${name}' in the input request")
Expand All @@ -40,4 +42,4 @@ case class OnnxRuntimeLibraryNotFoundError(reason: String) extends
BaseException(s"Onnx Runtime initialization failed: ${reason}")

case class UnknownContentTypeException(contentType: Option[String] = None) extends
BaseException(contentType.map(x => s"Prediction request takes unknown content type: ${x}") getOrElse s"The required header 'Content-Type' not found")
BaseException(contentType.map(x => s"Prediction request takes unknown content type: ${x}") getOrElse s"The required header 'Content-Type' not found")
14 changes: 7 additions & 7 deletions src/main/scala/ai/autodeploy/serving/protobuf/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,21 +208,21 @@ package object protobuf {
if (!v.rawData.isEmpty) {
// When this raw_data field is used to store tensor value, elements MUST
// be stored in as fixed-width, little-endian order.
v.rawData.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN)
(v.rawData.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN), v.dims)
} else {
v.dataType match {
case FLOAT.index | COMPLEX64.index =>
v.floatData
(v.floatData, v.dims)
case INT32.index | INT16.index | INT8.index | UINT16.index | UINT8.index | BOOL.index | FLOAT16.index =>
v.int32Data
(v.int32Data, v.dims)
case STRING.index =>
v.stringData
(v.stringData, v.dims)
case INT64.index =>
v.int64Data
(v.int64Data, v.dims)
case DOUBLE.index | COMPLEX128.index =>
v.doubleData
(v.doubleData, v.dims)
case UINT32.index | UINT64.index =>
v.uint64Data
(v.uint64Data, v.dims)
case _ =>
null
}
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/ai/autodeploy/serving/utils/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ object Utils {
}

def elementCount(shape: Array[Long]): Long = {
shape.foldLeft(1L)((x, y) => x * y)
// filter the dynamic axes that take -1
shape.filter(x => x != -1).foldLeft(1L)((x, y) => x * y)
}

def shapeOfValue(value: Any): Array[Long] = {
Expand All @@ -160,6 +161,8 @@ object Utils {
result.toArray
}

def isDynamicShape(shape: Array[Long]): Boolean = shape.contains(-1L)

@tailrec
def dimensionOfValue(value: Any, result: ArrayBuffer[Long]): Unit = value match {
case seq: Seq[_] => {
Expand Down

0 comments on commit d20ff1b

Please sign in to comment.