Skip to content

Commit

Permalink
[Codegen] Fix bug in IGEMM pass for non conv contractions (iree-org#1…
Browse files Browse the repository at this point in the history
…8838)

Adds back a match condition in the ConvolutionToIGEMM pass that got lost
in code cleanup. Checks that the im2col op producer exists, and adds a
test for the failing case.

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Oct 19, 2024
1 parent 3bd455e commit 556c945
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
break;
}
}
if (!im2colOp) {
return rewriter.notifyMatchFailure(genericOp, "no im2colOp producer.");
}

if (getLoweringConfig(genericOp)) {
return rewriter.notifyMatchFailure(genericOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,23 @@ func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1:
// CHECK-NOT: iree_linalg_ext.im2col
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: lowering_config

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func public @no_conv_contraction(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<128x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x128xf32>) -> tensor<128x128xf32>
%matmul = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%fill : tensor<128x128xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%0 = arith.mulf %in, %in_0 : f32
%1 = arith.addf %0, %out : f32
linalg.yield %1 : f32
} -> tensor<128x128xf32>
return %matmul : tensor<128x128xf32>
}
// CHECK: func.func public @no_conv_contraction
// CHECK: linalg.generic

0 comments on commit 556c945

Please sign in to comment.