Skip to content

Commit

Permalink
Diagnose axis out of range for FlattenOp, GatherOp, GatherElementsOp (onnx#1328)
Browse files Browse the repository at this point in the history

Use the Diagnostic class to diagnose the axis attribute for the ONNX Flatten, Gather and GatherElements operators.

Signed-off-by: Ettore Tiotto <[email protected]>
  • Loading branch information
Ettore Tiotto authored Apr 26, 2022
1 parent 5484243 commit 313928c
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 75 deletions.
1 change: 1 addition & 0 deletions src/Dialect/ONNX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_onnx_mlir_library(OMONNXOps
ShapeInference/Conv.cpp
ShapeInference/DepthToSpace.cpp
ShapeInference/Expand.cpp
ShapeInference/Flatten.cpp
ShapeInference/Gather.cpp
ShapeInference/GatherElements.cpp
ShapeInference/Gemm.cpp
Expand Down
125 changes: 62 additions & 63 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2829,69 +2829,33 @@ LogicalResult ONNXSplitV11Op::inferShapes(
//===----------------------------------------------------------------------===//

LogicalResult ONNXFlattenOp::verify() {
  // Cannot verify constraints if the input shape is not yet known.
  if (!hasShapeAndRank(input()))
    return success();

  auto inputType = input().getType().cast<ShapedType>();
  int64_t inputRank = inputType.getRank();
  int64_t axisValue = axis();

  // The axis attribute must be in the range [-r, r], where r = rank(input).
  // Note: unlike Gather/Hardmax, Flatten allows axis == r (everything is
  // flattened into the first output dimension).
  if (axisValue < -inputRank || axisValue > inputRank)
    return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError(
        *this->getOperation(), "axis", axisValue,
        onnx_mlir::Diagnostic::Range<int64_t>(-inputRank, inputRank));

  return success();
}

LogicalResult ONNXFlattenOp::inferShapes(
    std::function<void(mlir::Region &)> doShapeInference) {
  // Cannot infer the output shape if the input shape is not yet known.
  if (!hasShapeAndRank(input()))
    return success();

  // Delegate the shape computation to ONNXFlattenOpShapeHelper; the element
  // type of the output is the element type of the input.
  auto elementType = input().getType().cast<ShapedType>().getElementType();
  return shapeHelperInferShapes<ONNXFlattenOpShapeHelper, ONNXFlattenOp,
      ONNXFlattenOpAdaptor>(*this, elementType);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3210,9 +3174,28 @@ LogicalResult ONNXTileOp::inferShapes(
// Gather
//===----------------------------------------------------------------------===//

LogicalResult ONNXGatherOp::verify() {
  ONNXGatherOpAdaptor operandAdaptor(*this);

  // Checking is only possible once every operand has a known shape and rank.
  for (const Value &operand : operandAdaptor.getOperands())
    if (!hasShapeAndRank(operand))
      return success();

  int64_t dataRank =
      operandAdaptor.data().getType().cast<RankedTensorType>().getRank();
  int64_t axisValue = axis();

  // axis attribute must be in the range [-r,r-1], where r = rank(data).
  bool axisInRange = (axisValue >= -dataRank) && (axisValue < dataRank);
  if (!axisInRange)
    return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError(
        *this->getOperation(), "axis", axisValue,
        onnx_mlir::Diagnostic::Range<int64_t>(-dataRank, dataRank - 1));

  return success();
}

LogicalResult ONNXGatherOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
// Cannot infer the output shape if the operands shape is not yet known.
if (llvm::any_of(this->getOperands(),
[](const Value &op) { return !hasShapeAndRank(op); }))
return success();
Expand Down Expand Up @@ -3823,25 +3806,41 @@ LogicalResult ONNXGreaterOrEqualOp::inferShapes(
ONNXGreaterOrEqualOpAdaptor>(*this, b.getI1Type());
}

//===----------------------------------------------------------------------===//
// ONNXHardmaxOp
//===----------------------------------------------------------------------===//

LogicalResult ONNXHardmaxOp::verify() {
  ONNXHardmaxOpAdaptor operandAdaptor(*this);
  Value input = operandAdaptor.input();

  // Won't be able to do any checking until the input has shape and rank.
  if (!hasShapeAndRank(input))
    return success();

  // axis attribute must be in the range [-r,r-1], where r = rank(input).
  int64_t axisValue = axis();
  int64_t inputRank = input.getType().cast<ShapedType>().getRank();
  if (axisValue < -inputRank || axisValue >= inputRank)
    return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError(
        *this->getOperation(), "axis", axisValue,
        onnx_mlir::Diagnostic::Range<int64_t>(-inputRank, inputRank - 1));

  return success();
}

LogicalResult ONNXHardmaxOp::inferShapes(
    std::function<void(mlir::Region &)> doShapeInference) {
  auto inputType = input().getType().cast<ShapedType>();
  int64_t inputRank = inputType.getRank();
  int64_t axisValue = axis();

  // axis attribute must be in the range [-r,r-1], where r = rank(input).
  // Use '>= inputRank' so the bound agrees with verify() and with the range
  // reported in the diagnostic below (with '>' an axis equal to r would
  // slip through while the message claims the maximum is r-1).
  if (axisValue < -inputRank || axisValue >= inputRank)
    return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError(
        *this->getOperation(), "axis", axisValue,
        onnx_mlir::Diagnostic::Range<int64_t>(-inputRank, inputRank - 1));

  // Hardmax is shape-preserving: the output has the input's type.
  getResult().setType(inputType);

  return success();
}

Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,7 @@ def ONNXGatherOp:ONNX_Op<"Gather",
return {20};
}
}];
let hasVerifier = 1;
}

def ONNXGatherElementsOp:ONNX_Op<"GatherElements",
Expand Down
57 changes: 57 additions & 0 deletions src/Dialect/ONNX/ShapeInference/Flatten.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------- Flatten.cpp - Shape Inference for Flatten Op --------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// This file implements shape inference for the ONNX Flatten Operator.
//
//===----------------------------------------------------------------------===//

#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"

using namespace mlir;

namespace onnx_mlir {

LogicalResult ONNXFlattenOpShapeHelper::computeShape(
    ONNXFlattenOpAdaptor operandAdaptor) {
  // Get info about the input operand.
  Value input = operandAdaptor.input();
  auto inputType = input.getType().cast<ShapedType>();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputType.getRank();
  int64_t axis = op->axis();
  // Flatten allows axis in [-r, r]: axis == r collapses every dimension into
  // the first output dimension. The bound must therefore be '<=' to match
  // ONNXFlattenOp::verify(), which accepts axis == inputRank.
  assert(axis >= -inputRank && axis <= inputRank && "Invalid axis");

  // Negative axis means values are counted from the opposite side.
  if (axis < 0)
    axis += inputRank;

  // The output is always 2-D: dims [0, axis) are multiplied into output
  // dim 0 and dims [axis, rank) into output dim 1. A dynamic (-1) input
  // dimension makes the corresponding output dimension unknown.
  DimsExpr outputDims = {LiteralIndexExpr(1), LiteralIndexExpr(1)};
  for (int64_t i = 0; i < axis; ++i) {
    if (inputShape[i] == -1) {
      outputDims[0] = QuestionmarkIndexExpr();
      break;
    }
    outputDims[0] = outputDims[0] * LiteralIndexExpr(inputShape[i]);
  }

  for (int64_t i = axis; i < inputRank; ++i) {
    if (inputShape[i] == -1) {
      outputDims[1] = QuestionmarkIndexExpr();
      break;
    }
    outputDims[1] = outputDims[1] * LiteralIndexExpr(inputShape[i]);
  }

  dimsForOutput() = outputDims;
  return success();
}

} // namespace onnx_mlir
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ShapeInference/Gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ LogicalResult ONNXGatherOpShapeHelper::computeShape(

int64_t dataRank = dataDims.size();
int64_t axisIndex = op->axis();
assert(axisIndex >= -dataRank && axisIndex < dataRank && "Invalid axisIndex");

// Negative value means counting dimensions from the back.
axisIndex = (axisIndex < 0) ? axisIndex + dataRank : axisIndex;
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ template struct ONNXOpShapeHelper<ONNXConcatOp>;
template struct ONNXOpShapeHelper<ONNXConvOp>;
template struct ONNXOpShapeHelper<ONNXDepthToSpaceOp>;
template struct ONNXOpShapeHelper<ONNXExpandOp>;
template struct ONNXOpShapeHelper<ONNXFlattenOp>;
template struct ONNXOpShapeHelper<ONNXGatherOp>;
template struct ONNXOpShapeHelper<ONNXGatherElementsOp>;
template struct ONNXOpShapeHelper<ONNXGemmOp>;
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ DECLARE_SHAPE_HELPER(ONNXClipOp)
DECLARE_SHAPE_HELPER(ONNXCompressOp)
DECLARE_SHAPE_HELPER(ONNXConcatOp)
DECLARE_SHAPE_HELPER(ONNXDepthToSpaceOp)
DECLARE_SHAPE_HELPER(ONNXFlattenOp)
DECLARE_SHAPE_HELPER(ONNXGatherOp)
DECLARE_SHAPE_HELPER(ONNXGatherElementsOp)
DECLARE_SHAPE_HELPER(ONNXLRNOp)
Expand Down
25 changes: 22 additions & 3 deletions src/Support/Diagnostic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LogicalResult Diagnostic::emitAttributeOutOfRangeError(Operation &op,
const llvm::Twine &attrName, T attrVal, Range<T> validRange) {
static_assert(std::is_arithmetic<T>::value, "Expecting an arithmetic type");

llvm::Twine msg(op.getName().getStringRef() + " ");
Twine msg(op.getName().getStringRef() + " ");
return emitError(op.getLoc(), msg.concat("'" + attrName + "'")
.concat(" value is ")
.concat(std::to_string(attrVal))
Expand All @@ -34,6 +34,23 @@ LogicalResult Diagnostic::emitAttributeOutOfRangeError(Operation &op,
.concat("]"));
}

template <typename T>
LogicalResult Diagnostic::emitInputsMustHaveSameRankError(Operation &op,
    const llvm::Twine &inputName1, T rank1, const llvm::Twine &inputName2,
    T rank2) {
  static_assert(std::is_arithmetic<T>::value, "Expecting an arithmetic type");

  // Prefix the message with the operation name, then report both ranks.
  llvm::Twine prefix(op.getName().getStringRef() + " ");
  return emitError(op.getLoc(),
      prefix.concat("'" + inputName1 + "'")
          .concat(" has rank ")
          .concat(std::to_string(rank1))
          .concat(", '" + inputName2 + "'")
          .concat(" has rank ")
          .concat(std::to_string(rank2))
          .concat(". The two inputs must have the same rank."));
}

LogicalResult Diagnostic::emitOperandHasUnexpectedRankError(Operation &op,
Value &operand, uint64_t operandRank, StringRef expectedRank) {
llvm::Twine msg(op.getName().getStringRef() + ": ");
Expand Down Expand Up @@ -64,7 +81,9 @@ std::string Diagnostic::getName(Value &v) {
}

// Template instantiations - keep at the end of the file.
// Each explicit instantiation must appear exactly once; the int64_t forms
// are the only ones currently used by the ONNX dialect verifiers.
template LogicalResult Diagnostic::emitAttributeOutOfRangeError(
    Operation &, const Twine &, int64_t, Range<int64_t>);
template LogicalResult Diagnostic::emitInputsMustHaveSameRankError(
    Operation &, const Twine &, int64_t, const Twine &, int64_t);

} // namespace onnx_mlir
7 changes: 7 additions & 0 deletions src/Support/Diagnostic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Twine.h"
#include <type_traits>

namespace onnx_mlir {

Expand Down Expand Up @@ -44,6 +45,12 @@ class Diagnostic {
static mlir::LogicalResult emitAttributeOutOfRangeError(mlir::Operation &op,
const llvm::Twine &attrName, T attrVal, Range<T> validRange);

/// Verifies whether 2 inputs have the same rank.
template <typename T>
static mlir::LogicalResult emitInputsMustHaveSameRankError(
mlir::Operation &op, const llvm::Twine &inputName1, T rank1,
const llvm::Twine &inputName2, T rank2);

/// Diagnostic message for operand with unexpected rank.
static mlir::LogicalResult emitOperandHasUnexpectedRankError(
mlir::Operation &op, mlir::Value &operand, uint64_t operandRank,
Expand Down
Loading

0 comments on commit 313928c

Please sign in to comment.