Skip to content

Commit

Permalink
Lowering to KRNL Equal Operation (onnx#2401)
Browse files Browse the repository at this point in the history
* Testing Equal Op version 19
---------

Co-authored-by: Megan Hampton <[email protected]>
  • Loading branch information
hamptonm1 and MegoHam21 authored Aug 4, 2023
1 parent cd493ae commit 1bbc6f8
Show file tree
Hide file tree
Showing 11 changed files with 837 additions and 623 deletions.
182 changes: 91 additions & 91 deletions docs/Dialects/krnl.md

Large diffs are not rendered by default.

1,048 changes: 547 additions & 501 deletions docs/Dialects/onnx.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-cpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 19. Limitatio
| **DynamicQuantizeLinear** |11 | | |
| **Einsum** |12 |Limited to the types supported by ReduceSum and MatMul (which we decompose to in most cases) which exclude integers with width < 32. | |
| **Elu** |6 | | |
| **Equal** |13 | | |
| **Equal** |19 | | |
| **Erf** |13 | | |
| **Exp** |13 | | |
| **Expand** |13 | | |
Expand Down
2 changes: 1 addition & 1 deletion src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ op_dialect_version_map_["Dropout"] = {13};
op_dialect_version_map_["DynamicQuantizeLinear"] = {11};
op_dialect_version_map_["Einsum"] = {12};
op_dialect_version_map_["Elu"] = {6};
op_dialect_version_map_["Equal"] = {13};
op_dialect_version_map_["Equal"] = {19};
op_dialect_version_map_["Erf"] = {13};
op_dialect_version_map_["Exp"] = {13};
op_dialect_version_map_["Expand"] = {13};
Expand Down
39 changes: 29 additions & 10 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ static void CheckIfCustomScalarOpIsSupported(Type elementType) {
if (actualElementType.isa<mlir::IntegerType>()) {
if constexpr (std::is_same<ScalarIOp<Op>, CustomScalarOp>::value)
return;
llvm_unreachable("this op does not supports custom scalar for integers");
llvm_unreachable("this op does not support custom scalar for integers");
}
if (actualElementType.isa<mlir::FloatType>()) {
if constexpr (std::is_same<ScalarFOp<Op>, CustomScalarOp>::value)
return;
llvm_unreachable("this op does not supports custom scalar for floats");
llvm_unreachable("this op does not support custom scalar for floats");
}
}

Expand Down Expand Up @@ -393,6 +393,7 @@ Value emitScalarOpFor<ONNXCastOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {

CheckIfCustomScalarOpIsSupported<ONNXCastOp>(elementType);
// TODO: currently don't support String to * or * to String
MultiDialectBuilder<MathBuilder> create(rewriter, loc);
return create.math.cast(elementType, scalarOperands[0]);
Expand Down Expand Up @@ -958,7 +959,7 @@ Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLessOp
// Scalar binary ops for lowering ONNXLessOp
//===----------------------------------------------------------------------===//
template <>
struct ScalarOp<ONNXLessOp> {
Expand All @@ -978,7 +979,7 @@ Value emitScalarOpFor<ONNXLessOp>(ConversionPatternRewriter &rewriter,
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLessOrEqualOp
// Scalar binary ops for lowering ONNXLessOrEqualOp
//===----------------------------------------------------------------------===//
template <>
struct ScalarOp<ONNXLessOrEqualOp> {
Expand All @@ -998,7 +999,7 @@ Value emitScalarOpFor<ONNXLessOrEqualOp>(ConversionPatternRewriter &rewriter,
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXGreaterOp
// Scalar binary ops for lowering ONNXGreaterOp
//===----------------------------------------------------------------------===//
template <>
struct ScalarOp<ONNXGreaterOp> {
Expand All @@ -1018,7 +1019,7 @@ Value emitScalarOpFor<ONNXGreaterOp>(ConversionPatternRewriter &rewriter,
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXGreaterOrEqualOp
// Scalar binary ops for lowering ONNXGreaterOrEqualOp
//===----------------------------------------------------------------------===//
template <>
struct ScalarOp<ONNXGreaterOrEqualOp> {
Expand All @@ -1038,7 +1039,7 @@ Value emitScalarOpFor<ONNXGreaterOrEqualOp>(ConversionPatternRewriter &rewriter,
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXEqualOp
// Scalar binary ops for lowering ONNXEqualOp
//===----------------------------------------------------------------------===//
template <>
struct ScalarOp<ONNXEqualOp> {
Expand All @@ -1050,11 +1051,29 @@ template <>
Value emitScalarOpFor<ONNXEqualOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {

CheckIfCustomScalarOpIsSupported<ONNXEqualOp>(elementType);
Value results;
Value lhs = scalarOperands[0];
Value rhs = scalarOperands[1];
MultiDialectBuilder<MathBuilder> create(rewriter, loc);
return create.math.eq(lhs, rhs);
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
Type inputElemType = getElementType(lhs.getType());

// If the two input values are a string then we want to use the krnlstrncmp.
// However, if the input values are a float or an int we can simply use the
// equal function.
if (inputElemType.isa<krnl::StringType>()) {
Value strlenRes = create.krnl.strlen(lhs);
Value strncmpRes = create.krnl.strncmp(lhs, rhs, strlenRes);
// Confirm the strncmp is indeed valid. strncmp returns a value of 0 if the
// strings are equal. So we need to verify the returned results is equal to
// 0.
Value zeroVal = create.math.constant(strncmpRes.getType(), 0);
results = create.math.eq(strncmpRes, zeroVal);
} else {
results = create.math.eq(lhs, rhs);
}
return results;
}

//===----------------------------------------------------------------------===//
Expand All @@ -1078,7 +1097,7 @@ Value emitScalarOpFor<ONNXNotOp>(ConversionPatternRewriter &rewriter,
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXModOp
// Scalar binary ops for lowering ONNXModOp
//===----------------------------------------------------------------------===//
template <>
struct ScalarOp<ONNXModOp> {
Expand Down
4 changes: 2 additions & 2 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -2052,8 +2052,8 @@ def ONNXEqualOp:ONNX_Op<"Equal",

This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md).
}];
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$A,
AnyTypeOf<[TensorOf<[I1]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$B);
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>, TensorOf<[StringType]>]>:$A,
AnyTypeOf<[TensorOf<[I1]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>, TensorOf<[StringType]>]>:$B);
let results = (outs TensorOf<[I1]>:$C);
let builders = [
OpBuilder<(ins "Value":$A, "Value":$B), [{
Expand Down
Loading

0 comments on commit 1bbc6f8

Please sign in to comment.