Skip to content

Commit

Permalink
Updating the MHLOToLinalgOnTensors pass to support ml_program. (iree-…
Browse files Browse the repository at this point in the history
…org#10896)

This pass should not be doing signedness conversion like this and I'll
file an issue for cleanup - for now this adds the new module-level
global op so we can convert ml_program+mhlo.

Fixes iree-org#10834.
  • Loading branch information
benvanik authored Oct 25, 2022
1 parent 97c1ada commit cfedd31
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Utils/ConversionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -63,77 +64,6 @@ struct PrimitiveTypeConverter : public TypeConverter {
virtual Type getTargetType(SourceType type) = 0;
};

// Returns |oldAttr| converted to its new type via |typeConverter|, if needed.
static Attribute convertAttribute(Location loc, Attribute oldAttr,
TypeConverter &typeConverter) {
// Type attributes get their nested type converted.
if (auto oldTypeAttr = oldAttr.dyn_cast<TypeAttr>()) {
return TypeAttr::get(typeConverter.convertType(oldTypeAttr.getValue()));
}

// Return the same attribute if it doesn't have a type.
auto typedOldAttr = oldAttr.dyn_cast<TypedAttr>();
if (!typedOldAttr) return oldAttr;

// Convert the attribute type - if it's the same then it's already legal.
auto oldType = typedOldAttr.getType();
auto newType = typeConverter.convertType(oldType);
if (oldType == newType) return typedOldAttr;

if (auto intAttr = typedOldAttr.dyn_cast<IntegerAttr>()) {
APInt value = intAttr.getValue();
if (newType.isSignedInteger()) {
value = value.truncSSat(newType.getIntOrFloatBitWidth());
} else if (newType.isUnsignedInteger()) {
value = value.truncUSat(newType.getIntOrFloatBitWidth());
} else {
value = value.trunc(newType.getIntOrFloatBitWidth());
}
return IntegerAttr::get(newType, value);
} else if (auto floatAttr = typedOldAttr.dyn_cast<FloatAttr>()) {
auto newFloatType = newType.cast<FloatType>();
APFloat value = floatAttr.getValue();
bool losesInfo = false;
value.convert(newFloatType.getFloatSemantics(), APFloat::rmTowardZero,
&losesInfo);
return FloatAttr::get(newType, value);
} else if (auto splatAttr = typedOldAttr.dyn_cast<SplatElementsAttr>()) {
// NOTE: splats are also dense but this way we avoid needing to convert the
// same splat value N times.
return SplatElementsAttr::get(
newType.cast<ShapedType>(),
convertAttribute(loc, splatAttr.getSplatValue<Attribute>(),
typeConverter));
} else if (auto denseAttr = typedOldAttr.dyn_cast<DenseIntElementsAttr>()) {
auto newElementType = newType.cast<ShapedType>().getElementType();
auto newElementBitWidth = newElementType.getIntOrFloatBitWidth();
if (newElementType.isSignedInteger()) {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.truncSSat(newElementBitWidth);
});
} else if (newElementType.isUnsignedInteger()) {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.truncUSat(newElementBitWidth);
});
} else {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.trunc(newElementBitWidth);
});
}
} else if (auto denseAttr = typedOldAttr.dyn_cast<DenseFPElementsAttr>()) {
auto newElementType =
newType.cast<ShapedType>().getElementType().cast<FloatType>();
const auto &newFloatSemantics = newElementType.getFloatSemantics();
return denseAttr.mapValues(newElementType, [&](APFloat src) {
bool losesInfo = false;
src.convert(newFloatSemantics, APFloat::rmTowardZero, &losesInfo);
return src.bitcastToAPInt();
});
}

return oldAttr;
}

// Tries to completely convert a generic Operation.
// This will process attributes, result types, and nested regions.
struct GenericTypeConversionPattern : public ConversionPattern {
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/InputConversion/MHLO/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/InputConversion/Common",
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
"@llvm-project//llvm:Support",
Expand All @@ -83,6 +84,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MLProgramDialect",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ iree_cc_library(
MLIRIR
MLIRLinalgDialect
MLIRLinalgTransforms
MLIRMLProgramDialect
MLIRMathDialect
MLIRMemRefDialect
MLIRMhloUtils
Expand Down Expand Up @@ -95,6 +96,7 @@ iree_cc_library(
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::InputConversion::Common
iree::compiler::Utils
tensorflow::external_mhlo_includes
DEFINES
"IREE_HAVE_MHLO_INPUT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "iree/compiler/InputConversion/MHLO/Rewriters.h"
#include "iree/compiler/Utils/ConversionUtils.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
Expand All @@ -27,6 +28,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -261,6 +263,30 @@ class BuiltinFuncOpPattern : public OpConversionPattern<func::FuncOp> {
}
};

class GlobalOpPattern : public OpConversionPattern<ml_program::GlobalOp> {
public:
using OpConversionPattern<ml_program::GlobalOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ml_program::GlobalOp globalOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldType = globalOp.getType();
auto newType = getTypeConverter()->convertType(oldType);
if (newType == oldType) return failure();
if (!newType) {
return rewriter.notifyMatchFailure(globalOp,
"result type conversion failed");
}
rewriter.updateRootInPlace(globalOp, [&]() {
globalOp.setType(newType);
if (auto oldValue = globalOp.getValueAttr()) {
globalOp.setValueAttr(
convertAttribute(globalOp.getLoc(), oldValue, *getTypeConverter()));
}
});
return success();
}
};

class GenericTypeConvert : public ConversionPattern {
public:
GenericTypeConvert(StringRef rootName, TypeConverter &converter,
Expand Down Expand Up @@ -334,6 +360,9 @@ struct ConvertMHLOToLinalgOnTensorsPass
patterns);
populateMHLOComplexToRealPatterns(context, *typeConverter, patterns);

// TODO(*): expose patterns that do this much better from
// iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp

// Structural patterns (functions, cfg, terminators).
patterns.insert<BuiltinFuncOpPattern>(*typeConverter, context);
patterns.insert<GenericTypeConvert>(func::ReturnOp::getOperationName(),
Expand All @@ -344,6 +373,14 @@ struct ConvertMHLOToLinalgOnTensorsPass
*typeConverter, context);
patterns.insert<GenericTypeConvert>(cf::BranchOp::getOperationName(),
*typeConverter, context);
patterns.insert<GlobalOpPattern>(*typeConverter, context);
patterns.insert<GenericTypeConvert>(
ml_program::GlobalLoadOp::getOperationName(), *typeConverter, context);
patterns.insert<GenericTypeConvert>(
ml_program::GlobalLoadConstOp::getOperationName(), *typeConverter,
context);
patterns.insert<GenericTypeConvert>(
ml_program::GlobalStoreOp::getOperationName(), *typeConverter, context);

ConversionTarget target(getContext());
auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); };
Expand Down Expand Up @@ -375,6 +412,10 @@ struct ConvertMHLOToLinalgOnTensorsPass
}
return true;
});
target.addDynamicallyLegalOp<ml_program::GlobalOp>(
[&](ml_program::GlobalOp op) {
return typeConverter->isLegal(op.getType());
});

// Let the rest fall through.
target.addLegalDialect<BuiltinDialect>();
Expand All @@ -400,7 +441,7 @@ void populateMHLOToLinalgOnTensorsConversionPatterns(
typeConverter, context, PatternBenefit(1000));
}

std::unique_ptr<OperationPass<func::FuncOp>> createMHLOToLinalgOnTensorsPass() {
std::unique_ptr<OperationPass<ModuleOp>> createMHLOToLinalgOnTensorsPass() {
return std::make_unique<ConvertMHLOToLinalgOnTensorsPass>();
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/InputConversion/MHLO/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void buildMHLOInputConversionPassPipeline(OpPassManager &passManager) {
passManager.addNestedPass<func::FuncOp>(
mhlo::createLegalizeShapeComputationsPass());
passManager.addNestedPass<func::FuncOp>(createConvertMHLOToLinalgExtPass());
passManager.addNestedPass<func::FuncOp>(createMHLOToLinalgOnTensorsPass());
passManager.addPass(createMHLOToLinalgOnTensorsPass());
// Ensure conversion completed.
passManager.addPass(createReconcileUnrealizedCastsPass());

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/InputConversion/MHLO/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createFlattenTuplesInCFGPass();
//------------------------------------------------------------------------------

/// Creates XLA-HLO to Linalg on tensors transformation pass.
std::unique_ptr<OperationPass<func::FuncOp>> createMHLOToLinalgOnTensorsPass();
std::unique_ptr<OperationPass<ModuleOp>> createMHLOToLinalgOnTensorsPass();

/// Creates XLA-HLO to LinalgExt pass.
std::unique_ptr<OperationPass<func::FuncOp>> createConvertMHLOToLinalgExtPass();
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/InputConversion/MHLO/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
include "mlir/Pass/PassBase.td"

def ConvertMHLOToLinalgOnTensors :
Pass<"iree-mhlo-to-linalg-on-tensors", "func::FuncOp"> {
Pass<"iree-mhlo-to-linalg-on-tensors", "ModuleOp"> {
let summary = "Convert from XLA-HLO ops to Linalg ops on tensors";
let constructor = "mlir::iree_compiler::MHLO::createMHLOToLinalgOnTensorsPass()";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,15 @@ func.func @concatenate(%arg0: tensor<2x2xi32>, %arg1: tensor<2x4xi32>) -> tensor
// CHECK: %[[T1:.+]] = tensor.insert_slice %[[CST]] into %[[T0]][0, 2] [2, 3] [1, 1]
// CHECK: %[[T2:.+]] = tensor.insert_slice %[[ARG1]] into %[[T1]][0, 5] [2, 4] [1, 1]
// CHECK: return %[[T2]]

// -----

// CHECK: ml_program.global private mutable @variable(dense<0> : tensor<2xi32>) : tensor<2xi32>
ml_program.global private mutable @variable(dense<0> : tensor<2xui32>) : tensor<2xui32>
// CHECK: func.func @global_types() -> tensor<2xi32>
func.func @global_types() -> tensor<2xui32> {
// CHECK-NEXT: %[[VALUE:.+]] = ml_program.global_load @variable : tensor<2xi32>
%0 = ml_program.global_load @variable : tensor<2xui32>
// CHECK: return %[[VALUE]] : tensor<2xi32>
return %0 : tensor<2xui32>
}
70 changes: 70 additions & 0 deletions compiler/src/iree/compiler/Utils/ConversionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,75 @@ LogicalResult verifyAllOperationsAreLegal(Operation *op,
return failure();
}

Attribute convertAttribute(Location loc, Attribute oldAttr,
TypeConverter &typeConverter) {
// Type attributes get their nested type converted.
if (auto oldTypeAttr = oldAttr.dyn_cast<TypeAttr>()) {
return TypeAttr::get(typeConverter.convertType(oldTypeAttr.getValue()));
}

// Return the same attribute if it doesn't have a type.
auto typedOldAttr = oldAttr.dyn_cast<TypedAttr>();
if (!typedOldAttr) return oldAttr;

// Convert the attribute type - if it's the same then it's already legal.
auto oldType = typedOldAttr.getType();
auto newType = typeConverter.convertType(oldType);
if (oldType == newType) return typedOldAttr;

if (auto intAttr = typedOldAttr.dyn_cast<IntegerAttr>()) {
APInt value = intAttr.getValue();
if (newType.isSignedInteger()) {
value = value.truncSSat(newType.getIntOrFloatBitWidth());
} else if (newType.isUnsignedInteger()) {
value = value.truncUSat(newType.getIntOrFloatBitWidth());
} else {
value = value.trunc(newType.getIntOrFloatBitWidth());
}
return IntegerAttr::get(newType, value);
} else if (auto floatAttr = typedOldAttr.dyn_cast<FloatAttr>()) {
auto newFloatType = newType.cast<FloatType>();
APFloat value = floatAttr.getValue();
bool losesInfo = false;
value.convert(newFloatType.getFloatSemantics(), APFloat::rmTowardZero,
&losesInfo);
return FloatAttr::get(newType, value);
} else if (auto splatAttr = typedOldAttr.dyn_cast<SplatElementsAttr>()) {
// NOTE: splats are also dense but this way we avoid needing to convert the
// same splat value N times.
return SplatElementsAttr::get(
newType.cast<ShapedType>(),
convertAttribute(loc, splatAttr.getSplatValue<Attribute>(),
typeConverter));
} else if (auto denseAttr = typedOldAttr.dyn_cast<DenseIntElementsAttr>()) {
auto newElementType = newType.cast<ShapedType>().getElementType();
auto newElementBitWidth = newElementType.getIntOrFloatBitWidth();
if (newElementType.isSignedInteger()) {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.truncSSat(newElementBitWidth);
});
} else if (newElementType.isUnsignedInteger()) {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.truncUSat(newElementBitWidth);
});
} else {
return denseAttr.mapValues(newElementType, [&](APInt src) {
return src.trunc(newElementBitWidth);
});
}
} else if (auto denseAttr = typedOldAttr.dyn_cast<DenseFPElementsAttr>()) {
auto newElementType =
newType.cast<ShapedType>().getElementType().cast<FloatType>();
const auto &newFloatSemantics = newElementType.getFloatSemantics();
return denseAttr.mapValues(newElementType, [&](APFloat src) {
bool losesInfo = false;
src.convert(newFloatSemantics, APFloat::rmTowardZero, &losesInfo);
return src.bitcastToAPInt();
});
}

return oldAttr;
}

} // namespace iree_compiler
} // namespace mlir
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ namespace iree_compiler {
LogicalResult verifyAllOperationsAreLegal(Operation *op,
const ConversionTarget &target);

// Returns |oldAttr| converted to its new type via |typeConverter|, if needed.
Attribute convertAttribute(Location loc, Attribute oldAttr,
TypeConverter &typeConverter);

} // namespace iree_compiler
} // namespace mlir

Expand Down

0 comments on commit cfedd31

Please sign in to comment.