Skip to content

Commit

Permalink
Add support for lowering mhlo.iota to Linalg.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 322799853
Change-Id: I77aa951ebbd707c54af7dd2d6b031b5f22f75178
  • Loading branch information
hanhanW authored and tensorflower-gardener committed Jul 23, 2020
1 parent 9ce9e77 commit 42d68f8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -640,25 +640,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
}
};

class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
template <typename OpTy, bool isLHLO = true>
class IotaConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;

LogicalResult matchAndRewrite(
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
OpTy iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType =
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
if (!resultMemrefType) return failure();
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
if (!resultShapedType) return failure();

auto resultElementType = resultMemrefType.getElementType();
auto resultElementType = resultShapedType.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure();

// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank();
unsigned nloops = resultShapedType.getRank();

rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), ArrayRef<Type>{}, args,
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
0, // args_in
1, // args_out
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
Expand All @@ -669,14 +669,16 @@ class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
nestedBuilder.getIntegerType(
resultElementType.getIntOrFloatBitWidth()));
if (resultElementType.isa<FloatType>()) {
if (resultElementType.template isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
resultElementType);
}
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
});

rewriter.replaceOp(iotaOp, llvm::None);
if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None);
else
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
return success();
}
};
Expand Down Expand Up @@ -768,7 +770,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
IotaConverter<lmhlo::IotaOp>,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<lmhlo::AddOp>,
Expand Down Expand Up @@ -870,36 +872,37 @@ namespace mhlo {

void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
patterns
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]

// -----

// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota() -> tensor<7x10xf32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
return %result : tensor<7x10xf32>
}
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32

0 comments on commit 42d68f8

Please sign in to comment.