Skip to content

Commit

Permalink
Check dependent values of StoreOp in EliminateEmptyTensorsPass (iree-…
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry Wu authored Jan 3, 2023
1 parent 39294ea commit ceb9cfd
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 27 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"dead_alloc.mlir",
"decompose_linalg_generic.mlir",
"distribute_gpu_shared_memory.mlir",
"eliminate_empty_tensors.mlir",
"erase_hal_descriptor_type.mlir",
"flatten_memref_subspan.mlir",
"fold_affine_min_in_distributed_loops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_lit_test_suite(
"dead_alloc.mlir"
"decompose_linalg_generic.mlir"
"distribute_gpu_shared_memory.mlir"
"eliminate_empty_tensors.mlir"
"erase_hal_descriptor_type.mlir"
"flatten_memref_subspan.mlir"
"fold_affine_min_in_distributed_loops.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-eliminate-empty-tensors)" %s | FileCheck %s

// -----
func.func @eliminate_empty_tensors_with_store_op() {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c128 = arith.constant 128 : index
%0 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<128x384xf32>>
%1 = tensor.empty() : tensor<32x384xf32>
scf.for %arg0 = %c0 to %c128 step %c32 {
%2 = scf.for %arg1 = %c0 to %c32 step %c8 iter_args(%arg2 = %1) -> (tensor<32x384xf32>) {
scf.yield %arg2 : tensor<32x384xf32>
}
flow.dispatch.tensor.store %2, %0, offsets = [%arg0, 0], sizes = [32, 384], strides = [1, 1] : tensor<32x384xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x384xf32>>
}
return
}

// CHECK-LABEL: @eliminate_empty_tensors_with_store_op
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C8:.+]] = arith.constant 8 : index
// CHECK: %[[C32:.+]] = arith.constant 32 : index
// CHECK: %[[C128:.+]] = arith.constant 128 : index
// CHECK: %[[SPAN:.+]] = hal.interface.binding.subspan
// CHECK: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C128]] step %[[C32]]
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SPAN]], offsets = [%[[ARG0]], 0]
// CHECK: %[[RES:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C8]] iter_args(%{{.+}} = %[[LOAD]])
// CHECK: flow.dispatch.tensor.store %[[RES]], %[[SPAN]]
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,20 @@ LogicalResult storeTensorOpAnchoredEmptyTensorEliminationStep(
return eliminateEmptyTensors(
rewriter, op, state,
/*anchorMatchFunc=*/
[&](OpOperand &operand, SmallVector<Value> &) {
return isa<IREE::Flow::DispatchTensorStoreOp>(operand.getOwner());
[&](OpOperand &operand, SmallVector<Value> &neededValues) {
auto storeOp =
dyn_cast<IREE::Flow::DispatchTensorStoreOp>(operand.getOwner());
if (!storeOp) return false;
neededValues.push_back(storeOp.getTarget());
neededValues.append(storeOp.getTargetDims().begin(),
storeOp.getTargetDims().end());
neededValues.append(storeOp.getOffsets().begin(),
storeOp.getOffsets().end());
neededValues.append(storeOp.getSizes().begin(),
storeOp.getSizes().end());
neededValues.append(storeOp.getStrides().begin(),
storeOp.getStrides().end());
return true;
},
/*rewriteFunc=*/
[](OpBuilder &b, Location loc, OpOperand &operand) {
Expand Down Expand Up @@ -424,31 +436,29 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
IREE::Flow::DispatchTensorStoreOp::attachInterface<
DispatchTensorStoreOpInterface>(*ctx);
});
registry.addExtension(
+[](MLIRContext *ctx, IREE::LinalgExt::IREELinalgExtDialect *dialect) {
IREE::LinalgExt::FftOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::FftOp>>(*ctx);
IREE::LinalgExt::PackOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::PackOp>>(*ctx);
IREE::LinalgExt::UnPackOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::UnPackOp>>(*ctx);
IREE::LinalgExt::ReverseOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ReverseOp>>(*ctx);
IREE::LinalgExt::ScanOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ScanOp>>(*ctx);
IREE::LinalgExt::ScatterOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ScatterOp>>(*ctx);
IREE::LinalgExt::SortOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::SortOp>>(*ctx);
IREE::LinalgExt::TopkOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::TopkOp>>(*ctx);
IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::WinogradInputTransformOp>>(
*ctx);
IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::WinogradOutputTransformOp>>(
*ctx);
});
registry.addExtension(+[](MLIRContext *ctx,
IREE::LinalgExt::IREELinalgExtDialect *dialect) {
IREE::LinalgExt::FftOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::FftOp>>(*ctx);
IREE::LinalgExt::PackOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::PackOp>>(*ctx);
IREE::LinalgExt::UnPackOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::UnPackOp>>(*ctx);
IREE::LinalgExt::ReverseOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ReverseOp>>(*ctx);
IREE::LinalgExt::ScanOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ScanOp>>(*ctx);
IREE::LinalgExt::ScatterOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ScatterOp>>(*ctx);
IREE::LinalgExt::SortOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::SortOp>>(*ctx);
IREE::LinalgExt::TopkOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::TopkOp>>(*ctx);
IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::WinogradInputTransformOp>>(*ctx);
IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::WinogradOutputTransformOp>>(*ctx);
});
}

} // namespace iree_compiler
Expand Down

0 comments on commit ceb9cfd

Please sign in to comment.