diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index aaa4ad3e4f99..0cc1f284d30f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -143,6 +143,7 @@ iree_compiler_cc_library(
         "TileDispatchUsingForall.cpp",
         "TileDispatchUsingInterface.cpp",
         "TileSizeSelection.cpp",
+        "TileSwizzle.cpp",
         "TypePropagationPass.cpp",
         "UserConfig.cpp",
        "VectorizeMemrefCopy.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index d828e083cf5d..3a94dcd74c20 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -135,6 +135,7 @@ iree_cc_library(
     "TileDispatchUsingForall.cpp"
     "TileDispatchUsingInterface.cpp"
     "TileSizeSelection.cpp"
+    "TileSwizzle.cpp"
     "TypePropagationPass.cpp"
     "UserConfig.cpp"
     "VectorizeMemrefCopy.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index 9b8468b82310..96d75940e56a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -210,10 +210,12 @@ bool isNarrowNResult(EncodingAttr encoding) {
 }
 
 SmallVector<int64_t>
-getExpandedTileShape(SmallVector<SmallVector<int64_t>> expandShape) {
+getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) {
   SmallVector<int64_t> result;
-  for (auto expandShapeDim : expandShape) {
-    result.append(expandShapeDim);
+  for (auto e : expandShape) {
+    for (auto d : e) {
+      result.push_back(d.size);
+    }
   }
   return result;
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index b7d75c9516e5..290502849d4c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -143,7 +143,7 @@ bool isNarrowNResult(IREE::Encoding::EncodingAttr encoding);
 
 /// Concatenates the vectors.
 SmallVector<int64_t>
-getExpandedTileShape(SmallVector<SmallVector<int64_t>> expandShape);
+getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape);
 
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 4ffa632e51a9..1174e81e05a1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -37,84 +37,6 @@ namespace mlir::iree_compiler {
 #define GEN_PASS_DEF_GPUMATERIALIZEDEVICEENCODINGPASS
 #include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
 
-/// Returns the index of the dimension whose flattened size (flattening inner
-/// dimensions into it) matches the given `targetSize`. This is used to compute
-/// interleaving indices.
-///
-/// Example:
-///   Input shape = [16, 8, 4, 4]
-///   Input targetSize = 16
-/// -> Return 2, because the tail of the shape starting at index 2 is [4, 4],
-///   whose product equals targetSize.
-static int64_t getDimIdxForTargetSize(ArrayRef<int64_t> shape,
-                                      int64_t targetSize) {
-  int interleaveAt = 0;
-  int size = 1;
-  for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) {
-    assert(size <= targetSize);
-    assert((targetSize % size) == 0);
-    if (size == targetSize) {
-      break;
-    }
-    size *= shape[interleaveAt];
-  }
-  return interleaveAt;
-}
-
-/// Generates the swizzle for the full data-tiled-mma tile, including all the
-/// relevant unrolling factors.
-static TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
-                              IREE::GPU::MMAFragment fragment) {
-  auto [AType, BType, CType] = mma.getABCElementTypes();
-  int ABits = AType.getIntOrFloatBitWidth();
-  int BBits = BType.getIntOrFloatBitWidth();
-  // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded.
-  const int targetPreferredLoadBitWidth = 128;
-  auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
-  switch (fragment) {
-  case IREE::GPU::MMAFragment::Lhs:
-    // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
-    // Unroll on K with interleaving, then on M.
-    if (mma.getUnrollK() > 1) {
-      unroll(swizzle, 1, mma.getUnrollK());
-      int interleavingIdx = getDimIdxForTargetSize(
-          swizzle.expandShape[1],
-          targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits));
-      interleave(swizzle, 1, interleavingIdx);
-    }
-    if (mma.getUnrollM() > 1) {
-      unroll(swizzle, 0, mma.getUnrollM());
-    }
-    break;
-  case IREE::GPU::MMAFragment::Rhs:
-    // B-matrix (RHS). Since the pack ops already took care of transposing B,
-    // source dimensions are N (index 0) and K (index 1).
-    // Unroll on K with interleaving, then on N.
-    if (mma.getUnrollK() > 1) {
-      unroll(swizzle, 1, mma.getUnrollK());
-      int interleavingIdx = getDimIdxForTargetSize(
-          swizzle.expandShape[1],
-          targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits));
-      interleave(swizzle, 1, interleavingIdx);
-    }
-    if (mma.getUnrollN() > 1) {
-      unroll(swizzle, 0, mma.getUnrollN());
-    }
-    break;
-  case IREE::GPU::MMAFragment::Acc:
-    // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
-    // 1). Unroll on N, then on M.
-    if (mma.getUnrollN() > 1) {
-      unroll(swizzle, 1, mma.getUnrollN());
-    }
-    if (mma.getUnrollM() > 1) {
-      unroll(swizzle, 0, mma.getUnrollM());
-    }
-    break;
-  }
-  return swizzle;
-}
-
 static bool hasIntrinsic(IREE::GPU::TargetAttr target,
                          IREE::GPU::MMAIntrinsic intrinsic) {
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -133,13 +55,16 @@ chooseDataTiledMMAAttr(TypeRange elementTypes, IREE::GPU::TargetAttr target) {
   Type lhs = elementTypes[0];
   Type rhs = elementTypes[1];
   Type out = elementTypes[2];
-  auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollN,
+  auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollMToThreads,
+                   int unrollN, int unrollNToThreads,
                    int unrollK) -> std::optional<DataTiledMMAAttr> {
     if (!hasIntrinsic(target, intrinsic)) {
      return std::nullopt;
     }
     auto candidate = DataTiledMMAAttr::get(
-        ctx, MMAIntrinsicAttr::get(ctx, intrinsic), unrollM, unrollN, unrollK);
+        ctx, MMAIntrinsicAttr::get(ctx, intrinsic), /*unroll_m=*/unrollM,
+        /*unroll_m_to_subgroups=*/unrollMToThreads, /*unroll_n=*/unrollN,
+        /*unroll_n_to_subgroups=*/unrollNToThreads, /*unroll_k=*/unrollK);
     auto [candidateLhs, candidateRhs, candidateOut] =
         candidate.getABCElementTypes();
     if (candidateLhs != lhs || candidateRhs != rhs || candidateOut != out) {
@@ -147,13 +72,13 @@ chooseDataTiledMMAAttr(TypeRange elementTypes, IREE::GPU::TargetAttr target) {
     }
     return candidate;
   };
-  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 8, 4)) {
+  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x4_F32, 8, 1, 2, 4, 4)) {
     return m;
   }
-  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 8, 2)) {
+  if (auto m = match(MMAIntrinsic::MFMA_F32_16x16x16_F16, 8, 1, 2, 4, 2)) {
     return m;
   }
-  if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 8, 2)) {
+  if (auto m = match(MMAIntrinsic::MFMA_I32_16x16x32_I8, 8, 1, 2, 4, 2)) {
     return m;
   }
   // Fallback - no architecture-optimized tile size for this case.
@@ -220,7 +145,7 @@ struct GPUMaterializeDeviceEncodingPass final
 
 SmallVector<ReassociationIndices>
 getReassociationIndices(int outerDims,
-                        SmallVector<SmallVector<int64_t>> expandShape) {
+                        const TileSwizzle::ExpandShapeType &expandShape) {
   SmallVector<ReassociationIndices> result;
   int expandedIdx = 0;
   for (int i = 0; i < outerDims; ++i) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
index b225e691fcea..94335c47dddb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.cpp
@@ -5,7 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 
 namespace mlir::iree_compiler {
 
@@ -13,7 +12,7 @@ namespace mlir::iree_compiler {
 // dimensions to expanded dimensions, returns the index of the first expanded
 // dimension corresponding to the given source dimension index.
 static int64_t
-getExpandedDimFirstIdx(const SmallVector<SmallVector<int64_t>> &expandShape,
+getExpandedDimFirstIdx(const TileSwizzle::ExpandShapeType &expandShape,
                        int64_t srcIndex) {
   int dstIndexFirst = 0;
   for (int i = 0; i < srcIndex; ++i) {
@@ -22,14 +21,17 @@ getExpandedDimFirstIdx(const SmallVector<SmallVector<int64_t>> &expandShape,
   return dstIndexFirst;
 }
 
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor) {
+void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
+            TileSwizzle::Dim::Kind kind) {
   assert(unrollFactor > 1);
   int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
-
+  TileSwizzle::Dim unrollDim;
+  unrollDim.size = unrollFactor;
+  unrollDim.kind = kind;
   // The new unrolling dimension is inserted at the start of the expandShape
   // dimensions group corresponding to srcIndex.
   swizzle.expandShape[srcIndex].insert(swizzle.expandShape[srcIndex].begin(),
-                                       unrollFactor);
+                                       unrollDim);
   // Since we are not interleaving here, generating side-by-side copies of the
   // original layout, the new unrolling dimension is the new outermost
   // dimension. Existing entries get shifted to make room for it.
@@ -97,7 +99,10 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
   // shape expansion for now.
   TileSwizzle swizzle;
   for (auto t : layout.thread) {
-    swizzle.expandShape.push_back({t});
+    TileSwizzle::Dim dim;
+    dim.size = t;
+    dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
+    swizzle.expandShape.push_back({dim});
   }
   // The layout strides decide the initial swizzle.permutation.
   // Some WMMA intrinsics have tstrides=0 values, assert on that as that
@@ -112,9 +117,12 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
   // Deal with any element size greater than 1 by inserting it innermost.
   // Notice that this is similar to the unroll() function, just creating an
   // inner dimension instead of an outer dimension.
-  for (int i = 0; i < layout.element.size(); ++i) {
-    if (layout.element[i] != 1) {
-      swizzle.expandShape[i].push_back(layout.element[i]);
+  for (auto [i, e] : llvm::enumerate(layout.element)) {
+    if (e != 1) {
+      TileSwizzle::Dim dim;
+      dim.size = e;
+      dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
+      swizzle.expandShape[i].push_back(dim);
       int newIndex = getExpandedDimFirstIdx(swizzle.expandShape, i + 1) - 1;
       for (auto &p : swizzle.permutation) {
         p += (p >= newIndex);
@@ -125,13 +133,105 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
   // Deal with any outer size greater than 1 as just a call to unroll.
   // Iterate over dims in reverse order because we are creating a new outermost
   // dimension each time.
-  for (int i = layout.outer.size() - 1; i >= 0; --i) {
-    if (layout.outer[i] != 1) {
-      unroll(swizzle, i, layout.outer[i]);
+  for (auto [i, o] : llvm::enumerate(layout.outer)) {
+    if (o != 1) {
+      // `layout.outer` means additional Internal dimensions, just like
+      // `layout.element`, just swizzled outermost.
+      unroll(swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
     }
   }
 
   return swizzle;
 }
 
+// Returns the index of the dimension whose flattened size (flattening inner
+// dimensions into it) matches the given `targetSize`. This is used to compute
+// interleaving indices.
+//
+// Example:
+//   Input shape = [16, 8, 4, 4]
+//   Input targetSize = 16
+// -> Return 2, because the tail of the shape starting at index 2 is [4, 4],
+//   whose product equals targetSize.
+static int64_t
+getDimIdxForTargetSize(const TileSwizzle::ExpandShapeDimVectorType &shape,
+                       int64_t targetSize) {
+  int interleaveAt = 0;
+  int size = 1;
+  for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) {
+    assert(size <= targetSize);
+    assert((targetSize % size) == 0);
+    if (size == targetSize) {
+      break;
+    }
+    size *= shape[interleaveAt].size;
+  }
+  return interleaveAt;
+}
+
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
+                       IREE::GPU::MMAFragment fragment) {
+  auto [AType, BType, CType] = mma.getABCElementTypes();
+  int ABits = AType.getIntOrFloatBitWidth();
+  int BBits = BType.getIntOrFloatBitWidth();
+  // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded.
+  const int targetPreferredLoadBitWidth = 128;
+  auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
+  using Kind = TileSwizzle::Dim::Kind;
+  switch (fragment) {
+  case IREE::GPU::MMAFragment::Lhs:
+    // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
+    // Unroll on K with interleaving, then on M.
+    if (mma.getUnrollK() > 1) {
+      unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+      int interleavingIdx = getDimIdxForTargetSize(
+          swizzle.expandShape[1],
+          targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits));
+      interleave(swizzle, 1, interleavingIdx);
+    }
+    if (mma.getUnrollM() > 1) {
+      unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollMToSubgroups() > 1) {
+      unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+    }
+    break;
+  case IREE::GPU::MMAFragment::Rhs:
+    // B-matrix (RHS). Since the pack ops already took care of transposing B,
+    // source dimensions are N (index 0) and K (index 1).
+    // Unroll on K with interleaving, then on N.
+    if (mma.getUnrollK() > 1) {
+      unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
+      int interleavingIdx = getDimIdxForTargetSize(
+          swizzle.expandShape[1],
+          targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits));
+      interleave(swizzle, 1, interleavingIdx);
+    }
+    if (mma.getUnrollN() > 1) {
+      unroll(swizzle, 0, mma.getUnrollN(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollNToSubgroups() > 1) {
+      unroll(swizzle, 0, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+    }
+    break;
+  case IREE::GPU::MMAFragment::Acc:
+    // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
+    // 1). Unroll on N, then on M.
+    if (mma.getUnrollN() > 1) {
+      unroll(swizzle, 1, mma.getUnrollN(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollNToSubgroups() > 1) {
+      unroll(swizzle, 1, mma.getUnrollNToSubgroups(), Kind::CrossThread);
+    }
+    if (mma.getUnrollM() > 1) {
+      unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
+    }
+    if (mma.getUnrollMToSubgroups() > 1) {
+      unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
+    }
+    break;
+  }
+  return swizzle;
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
index fc5af79c9485..fc79bf0b6629 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h
@@ -8,6 +8,7 @@
 #define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_GPU_GPUTILESWIZZLEUTILS_H_
 
 #include "iree/compiler/Codegen/Common/TileSwizzle.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 
 namespace mlir::iree_compiler {
 
@@ -17,17 +18,26 @@ namespace mlir::iree_compiler {
 TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
                                 IREE::GPU::MMAFragment fragment);
 
+// Returns the swizzle for the full data-tiled-mma tile, including all the
+// relevant unrolling factors.
+TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
+                       IREE::GPU::MMAFragment fragment);
+
 // Unrolls the dimension given by `srcIndex` by the given `unrollFactor`.
 // This is not interleaving layouts. The layout will consist of multiple copies
 // of the input tile, side by side.
 //
+// The enum parameter `kind` initializes the corresponding member on the newly
+// created TileSwizzle::Dim.
+//
 // Example:
 //   Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
 //   Input srcIndex = 1
 //   Input unrollFactor = 4
 // -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
 //
-void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor);
+void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
+            TileSwizzle::Dim::Kind kind);
 
 // Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
 // move permutation[0], the outer-most dimension (which the unroll() function
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
index ac28e845fd83..8e5927bf16ed 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
@@ -128,11 +128,11 @@ func.func @set_encoding_RHS_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK-SAME: inner_tiles = [128, 16]
 // CHECK-SAME: : tensor<255x513xf32> -> tensor<5x16x128x16xf32>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x8x16x4x4xf32>
+// CHECK-SAME : tensor<5x16x128x16xf32> into tensor<5x16x4x2x16x4x4xf32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x8x16x4x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<5x16x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x16x4x2x16x4x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<5x16x4x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -161,11 +161,11 @@ func.func @set_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK-SAME: inner_tiles = [128, 128]
 // CHECK-SAME: : tensor<255x513xf32> -> tensor<2x5x128x128xf32>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x8x16xf32>
+// CHECK-SAME : tensor<2x5x128x128xf32> into tensor<2x5x8x4x4x4x2x16xf32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -189,11 +189,11 @@ func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 
 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x8x4x16x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x8x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x8x16xf32> into tensor<2x5x128x128xf32>
+// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xf32> into tensor<2x5x128x128xf32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME: outer_dims_perm = [0, 1]
 // CHECK-SAME: inner_dims_pos = [0, 1]
@@ -232,11 +232,11 @@ func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 }
 // CHECK-LABEL: func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<?x?x8x8x4x16x4xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x8x16xf32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<?x?x8x4x2x4x16x4xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?x?x8x4x4x4x2x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<?x?x8x4x4x8x16xf32> into tensor<?x?x128x128xf32>
+// CHECK-SAME: : tensor<?x?x8x4x4x4x2x16xf32> into tensor<?x?x128x128xf32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME: outer_dims_perm = [0, 1]
 // CHECK-SAME: inner_dims_pos = [0, 1]
@@ -295,12 +295,12 @@ func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 // CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x4xf32>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xf32>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x4xf32>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xf32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 8, unroll_k = 4>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 4>
 // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
@@ -365,11 +365,11 @@ func.func @set_encoding_RHS_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK-SAME: inner_tiles = [128, 64]
 // CHECK-SAME: : tensor<255x513xi8> -> tensor<5x4x128x64xi8>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x8x16x2x4x8xi8>
+// CHECK-SAME : tensor<5x4x128x64xi8> into tensor<5x4x4x2x16x2x4x8xi8>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x8x16x2x4x8xi8>)
-// CHECK-SAME: outs({{.*}} : tensor<5x4x8x4x16x2x8xi8>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 4, 6]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<5x4x4x2x16x2x4x8xi8>)
+// CHECK-SAME: outs({{.*}} : tensor<5x4x4x2x4x16x2x8xi8>)
+// CHECK-SAME: permutation = [0, 1, 2, 3, 6, 4, 5, 7]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -398,11 +398,11 @@ func.func @set_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK-SAME: inner_tiles = [128, 128]
 // CHECK-SAME: : tensor<255x513xi32> -> tensor<2x5x128x128xi32>
 // CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x8x16xi32>
+// CHECK-SAME : tensor<2x5x128x128xi32> into tensor<2x5x8x4x4x4x2x16xi32>
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 5, 3, 6, 4]
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 6, 3, 7, 4]
 // CHECK: flow.dispatch.tensor.store %[[TRANSPOSE]]
 
 // -----
@@ -426,11 +426,11 @@ func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 
 // CHECK-LABEL: func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
-// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x8x4x16x4xi32>)
-// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x8x16xi32>)
-// CHECK-SAME: permutation = [0, 1, 2, 4, 6, 3, 5]
+// CHECK-SAME: ins(%{{.+}} : tensor<2x5x8x4x2x4x16x4xi32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x5x8x4x4x4x2x16xi32>)
+// CHECK-SAME: permutation = [0, 1, 2, 5, 7, 3, 4, 6]
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK-SAME: : tensor<2x5x8x4x4x8x16xi32> into tensor<2x5x128x128xi32>
+// CHECK-SAME: : tensor<2x5x8x4x4x4x2x16xi32> into tensor<2x5x128x128xi32>
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[COLLAPSE]]
 // CHECK-SAME: outer_dims_perm = [0, 1]
 // CHECK-SAME: inner_dims_pos = [0, 1]
@@ -490,10 +490,10 @@ func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8() {
 // CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
 // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
 // CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x8x4x16x2x8xi8>
-// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x8x4x16x4xi32>
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x2x4x16x2x8xi8>
+// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x8x4x2x4x16x4xi32>
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 8, unroll_k = 2>
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>
 // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp
new file mode 100644
index 000000000000..7ae46e6592a1
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp
@@ -0,0 +1,46 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/TileSwizzle.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir::iree_compiler {
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              TileSwizzle::Dim::Kind kind) {
+  switch (kind) {
+  case TileSwizzle::Dim::Kind::Internal:
+    return os << "Internal";
+  case TileSwizzle::Dim::Kind::CrossThread:
+    return os << "CrossThread";
+  case TileSwizzle::Dim::Kind::CrossIntrinsic:
+    return os << "CrossIntrinsic";
+  }
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, TileSwizzle::Dim dim) {
+  return os << dim.size << "(" << dim.kind << ")";
+}
+
+static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os,
+           const TileSwizzle::ExpandShapeDimVectorType &expandShapeDimVector) {
+  os << "[";
+  llvm::interleaveComma(expandShapeDimVector, os);
+  return os << "]";
+}
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const TileSwizzle &swizzle) {
+  os << "{expandShape = [";
+  llvm::interleaveComma(swizzle.expandShape, os);
+  os << "], swizzle = [";
+  llvm::interleaveComma(swizzle.permutation, os);
+  os << "]}";
+  return os;
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
index b908ae43fac3..738bb6a43e94 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h
@@ -9,6 +9,7 @@
 #include <cstdint>
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir::iree_compiler {
 
@@ -16,13 +17,50 @@ namespace mlir::iree_compiler {
 // pair of ops performing a change of layout within the tiles. This is used
 // on GPU, where the tiles themselves can have an arbitrary layout.
 struct TileSwizzle {
+  struct Dim {
+    // Describes what varies across this dimension.
+    enum class Kind : int8_t {
+      // This dimension is internal to one intrinsic on one thread. This
+      // is only seen for intrinsic operands that are themselves vectors.
+      // For example, with AMD MFMA, for the MFMA_F32_16x16x4_F32 intrinsic,
+      // the C-matrix operand is a vector of 4 floats already at the level of
+      // one intrinsic on one thread. That dimension of size 4 is 'Internal'.
+      Internal,
+      // This dimension is internal to one intrinsic, but is across threads.
+      // For example, with AMD MFMA, for the MFMA_F32_16x16x4_F32 intrinsic,
+      // the A-matrix tile has shape 16x4, and these two dimensions of size 16
+      // and 4 are 'CrossThread': neither is visible at the single-thread level
+      // (in the intrinsic itself, the A-matrix operand is a single scalar) but
+      // as we move along these dimensions, we are moving over the 64 threads
+      // of the subgroup.
+      //
+      // Another example of cross-thread dimensions is in kernels that are
+      // "unrolled" across subgroups. Such dimensions are cross-subgroup, so in
+      // particular they are cross-thread.
+      CrossThread,
+      // This dimension is across intrinsics, as in, actual instructions in the
+      // generated code. In other words, it is an actual unrolling factor,
+      // resulting in this many more instructions being generated and executed
+      // on each thread/subgroup.
+      CrossIntrinsic
+    };
+
+    Kind kind = Kind::Internal;
+
+    // The size of the dimension.
+    int16_t size = 0;
+  };
+
+  using ExpandShapeDimVectorType = llvm::SmallVector<Dim>;
+  using ExpandShapeType = llvm::SmallVector<ExpandShapeDimVectorType>;
+
   // This vector-of-vectors contains all the information needed to generate
   // a `tensor.expand_shape` creating additional internal dimensions into the
   // tile. For example, expandShape = [[16], [4, 2]] means that the original
   // tile shape [16, 8] gets expanded such that the first dimension 16 is left
   // unchanged, and the second dimension 8 gets split into two internal dims
   // of size 4 and 2.
-  llvm::SmallVector<llvm::SmallVector<int64_t>> expandShape;
+  ExpandShapeType expandShape;
   // This permutation vector applies to the expanded dimensions and is used
   // to generate a `linalg.transpose` changing the layout of the tile. For
   // example, permutation[0] dictates which of the expanded dimensions becomes
@@ -30,6 +68,14 @@ struct TileSwizzle {
   llvm::SmallVector<int64_t> permutation;
 };
 
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              TileSwizzle::Dim::Kind kind);
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, TileSwizzle::Dim dim);
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const TileSwizzle &swizzle);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_TILESWIZZLE_H_
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 07cd27df23e8..b8b88061a206 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -898,7 +898,8 @@ std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes() const {
 std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
   MLIRContext *ctx = getContext();
   auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
-  return {opaqueLayout.mSize * getUnrollM(), opaqueLayout.nSize * getUnrollN(),
+  return {opaqueLayout.mSize * getUnrollM() * getUnrollMToSubgroups(),
+          opaqueLayout.nSize * getUnrollN() * getUnrollNToSubgroups(),
          opaqueLayout.kSize * getUnrollK()};
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index be04d1925a38..d3dc53abced5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -255,9 +255,11 @@ def IREEGPU_DataTiledMMAAttr :
   let parameters = (ins
     "::mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr":$intrinsic,
-    "int64_t":$unroll_m,
-    "int64_t":$unroll_n,
-    "int64_t":$unroll_k
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, on the same thread.">:$unroll_m,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the M dimension, distributed across this many more threads.">:$unroll_m_to_subgroups,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, on the same thread.">:$unroll_n,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the N dimension, distributed across this many more threads.">:$unroll_n_to_subgroups,
+    DefaultValuedParameter<"int64_t", "1", "Unrolling along the K dimension, on the same thread, with interleaved layout.">:$unroll_k
   );
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
index 046a3c88abbc..0ebe947dec15 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir
@@ -29,21 +29,21 @@ module {
 module {
   func.func @test_data_tiled_mfma_f32_16x16x4_f32() attributes {
-      mma_types = #iree_gpu.data_tiled_mma_layout} {
+      mma_types = #iree_gpu.data_tiled_mma_layout} {
     return
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x4_f32
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout
 
 module {
   func.func @test_data_tiled_mfma_f32_16x16x16_f16() attributes {
-      mma_types = #iree_gpu.data_tiled_mma_layout} {
+      mma_types = #iree_gpu.data_tiled_mma_layout} {
     return
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_f32_16x16x16_f16
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout
 
 module {
   func.func @test_data_tiled_mfma_i32_16x16x32_i8() attributes {
@@ -52,7 +52,7 @@ module {
   }
 }
 // CHECK-LABEL: func @test_data_tiled_mfma_i32_16x16x32_i8
-// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout
+// CHECK-SAME: mma_types = #iree_gpu.data_tiled_mma_layout
 
 module {
   func.func @test_any_lowering_config() attributes {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
index b174f8f9eb50..0aa922e14850 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_ops.mlir
@@ -227,7 +227,7 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor, %rh
   %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
     indexing_maps = #contraction_accesses,
     iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
-    kind = #iree_gpu.data_tiled_mma_layout
+    kind = #iree_gpu.data_tiled_mma_layout
   } : tensor, tensor into tensor
   return %0 : tensor
 }
@@ -240,7 +240,7 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor, %rh
 // CHECK: iree_gpu.multi_mma %arg0, %arg1, %arg2
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout
 // CHECK-SAME: : tensor, tensor into tensor
 
 // -----
@@ -270,6 +270,34 @@ func.func @data_tiled_2x2x4_tensor_multi_mma(%lhs: tensor, %
 // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout
 // CHECK-SAME: : tensor, tensor into tensor
 
+// -----
+
+#contraction_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+func.func @data_tiled_2x2x4_tensor_multi_mma(%lhs: tensor, %rhs: tensor, %acc: tensor) -> tensor {
+  %0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
+    indexing_maps = #contraction_accesses,
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.data_tiled_mma_layout
+  } : tensor, tensor into tensor
+  return %0 : tensor
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @data_tiled_2x2x4_tensor_multi_mma
+// CHECK: iree_gpu.multi_mma %arg0, %arg1, %arg2
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>]
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout
+// CHECK-SAME: : tensor, tensor into tensor
+
+
 // -----
 
 func.func @tensor_barrier(%input: tensor) -> tensor {