[Codegen][ROCM] Add a codegen pipeline for supported MFMA variants (iree-org#16258)

The codegen pipeline added here, named LLVMGPUVectorDistribute, uses
the vector distribution patterns to distribute to MFMA types. This
pipeline is still largely a work in progress and is therefore hidden
behind a flag: --iree-codegen-llvmgpu-use-vector-distribution

However, this sets up the skeleton for the pipeline. The major additions
here are:

1. The IREEGPUDialect, which adds attributes for MFMA types and adds some
   convenience templating/interfaces for defining other kinds of mma
   types (e.g. wmma/mma.sync). This should help normalize many of the
   strategies around these tensor core operations without needing to
   update the code in too many places. Details are often not that nice
   though :)

2. The LLVMGPUVectorDistribute pass pipeline for matmul codegen.
   Because vector distribution is currently tuned for matmul cases,
   the pipeline is special-cased to matmuls, but much of the code still
   in the process of landing should allow it to become a much more
   general pipeline in the future. For now, taking baby steps.

To supplement the pass pipeline, this adds the GPUVectorAlloc and
LLVMGPUVectorDistribute passes. The vector allocation pass is
somewhat ad-hoc and could be dropped depending on some decisions made
around pipelining, but is needed for now to get promotion. It's
a simple enough pattern to slot in and out depending on how promotion
happens. The vector distribution pass builds off of the recently added
work on that front and implements the logic for setting the anchors for
distribution. That pass is subject to change as layout attributes
evolve.

Things that do not work well in this first iteration:
1. No pipelining
2. Barrier placement is not smart
3. Distribution is limited to a single warp/subgroup
4. Kernel configuration is very basic; however, without multi-warp
   distribution it is probably not worth building out the logic there
   just yet.
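
As a rough illustration of what single-subgroup distribution to an MFMA
type implies, the sketch below computes how many elements of each operand
one lane would hold for the two intrinsics this commit enables on gfx940.
It assumes the intrinsic names read as MxNxK, that a subgroup has 64 lanes,
and that each tile is split evenly across lanes. `MmaShape`,
`kSubgroupSize`, and the helper functions are hypothetical names for this
example and are not part of the commit.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of one MFMA tile: D(MxN) += A(MxK) * B(KxN).
struct MmaShape {
  int64_t m;
  int64_t n;
  int64_t k;
};

// Assumption: a 64-lane subgroup (wavefront).
constexpr int64_t kSubgroupSize = 64;

// Elements of each operand held by a single lane, assuming the tile is
// divided evenly across all lanes of the subgroup.
constexpr int64_t lhsElementsPerLane(MmaShape s) {
  return s.m * s.k / kSubgroupSize;
}
constexpr int64_t rhsElementsPerLane(MmaShape s) {
  return s.k * s.n / kSubgroupSize;
}
constexpr int64_t accElementsPerLane(MmaShape s) {
  return s.m * s.n / kSubgroupSize;
}

// Shapes read off the intrinsic names F16_16x16x16_F32 and F16_32x32x8_F32.
constexpr MmaShape kF16_16x16x16_F32{16, 16, 16};
constexpr MmaShape kF16_32x32x8_F32{32, 32, 8};

static_assert(lhsElementsPerLane(kF16_16x16x16_F32) == 4, "16*16/64");
static_assert(accElementsPerLane(kF16_16x16x16_F32) == 4, "16*16/64");
static_assert(lhsElementsPerLane(kF16_32x32x8_F32) == 4, "32*8/64");
static_assert(accElementsPerLane(kF16_32x32x8_F32) == 16, "32*32/64");
```

Under these assumptions both variants consume four f16 values per lane for
each input operand, while the 32x32x8 variant keeps four times as many f32
accumulator values per lane as the 16x16x16 one.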
qedawkins authored Feb 7, 2024
1 parent 9517472 commit 831d240
Showing 49 changed files with 2,134 additions and 5 deletions.
4 changes: 4 additions & 0 deletions compiler/plugins/target/ROCM/BUILD.bazel
@@ -21,19 +21,23 @@ iree_compiler_cc_library(
name = "ROCM",
srcs = [
"ROCMTarget.cpp",
"ROCMTargetFeatures.cpp",
"ROCMTargetUtils.cpp",
],
hdrs = [
"ROCMTargetFeatures.h",
"ROCMTargetUtils.h",
],
deps = [
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/LLVMGPU",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/Dialect/HAL/Target:LLVMLinkerUtils",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
"//runtime/src/iree/schemas:rocm_executable_def_c_fbs",
"@llvm-project//llvm:AMDGPUCodeGen",
"@llvm-project//llvm:Analysis",
4 changes: 4 additions & 0 deletions compiler/plugins/target/ROCM/CMakeLists.txt
@@ -21,11 +21,14 @@ iree_cc_library(
NAME
ROCM
HDRS
"ROCMTargetFeatures.h"
"ROCMTargetUtils.h"
SRCS
"ROCMTarget.cpp"
"ROCMTargetFeatures.cpp"
"ROCMTargetUtils.cpp"
DEPS
IREEVectorExtDialect
LLVMAMDGPUCodeGen
LLVMAnalysis
LLVMBitWriter
@@ -52,6 +55,7 @@ iree_cc_library(
MLIRSupport
MLIRTargetLLVMIRExport
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::LLVMGPU
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::Target
11 changes: 11 additions & 0 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -4,12 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "./ROCMTargetFeatures.h"
#include "./ROCMTargetUtils.h"

#include <cstdint>
#include <mutex>

#include "compiler/plugins/target/ROCM/ROCMTargetFeatures.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/LLVMLinkerUtils.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
@@ -139,6 +143,8 @@ class ROCMTargetBackend final : public TargetBackend {
mlir::registerLLVMDialectTranslation(registry);
mlir::registerROCDLDialectTranslation(registry);
registry.insert<IREE::Codegen::IREECodegenDialect>();
registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
registry.insert<IREE::GPU::IREEGPUDialect>();
registry.insert<amdgpu::AMDGPUDialect>();
}

@@ -464,6 +470,11 @@ class ROCMTargetBackend final : public TargetBackend {

addConfig("ukernels", StringAttr::get(context, options.enableROCMUkernels));

ArrayAttr mmaAttrs = getROCMSupportedMmaAttrs(context, options.targetChip);
if (mmaAttrs) {
addConfig("mma_intrinsics", mmaAttrs);
}

auto configAttr = b.getDictionaryAttr(configItems);
return IREE::HAL::ExecutableTargetAttr::get(
context, b.getStringAttr("rocm"), b.getStringAttr("rocm-hsaco-fb"),
32 changes: 32 additions & 0 deletions compiler/plugins/target/ROCM/ROCMTargetFeatures.cpp
@@ -0,0 +1,32 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "./ROCMTargetFeatures.h"

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "llvm/ADT/StringSwitch.h"

namespace mlir::iree_compiler::IREE::HAL {

static ArrayAttr getMfmaArrayAttr(MLIRContext *context,
                                  ArrayRef<IREE::GPU::MFMAIntrinsic> types) {
  SmallVector<Attribute> attrs(types.size(), IREE::GPU::MFMAAttr());
  for (auto [idx, type] : llvm::enumerate(types)) {
    attrs[idx] = IREE::GPU::MFMAAttr::get(context, type);
  }
  return ArrayAttr::get(context, attrs);
}

ArrayAttr getROCMSupportedMmaAttrs(MLIRContext *context, StringRef targetArch) {
  if (targetArch == "gfx940") {
    return getMfmaArrayAttr(context,
                            {IREE::GPU::MFMAIntrinsic::F16_16x16x16_F32,
                             IREE::GPU::MFMAIntrinsic::F16_32x32x8_F32});
  }
  return ArrayAttr();
}

} // namespace mlir::iree_compiler::IREE::HAL
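
The lookup in ROCMTargetFeatures.cpp can be modelled without any MLIR
dependencies. The sketch below is a standalone stand-in, not the code from
this commit: `MfmaIntrinsic` and `supportedMfmaIntrinsics` are hypothetical
names, and an empty `std::vector` plays the role of the null `ArrayAttr`
returned for architectures with no registered intrinsics.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for IREE::GPU::MFMAIntrinsic.
enum class MfmaIntrinsic { F16_16x16x16_F32, F16_32x32x8_F32 };

// Returns the MFMA variants registered for `targetArch`. An empty result
// models the null ArrayAttr case, where the caller adds no "mma_intrinsics"
// entry to the target configuration.
inline std::vector<MfmaIntrinsic>
supportedMfmaIntrinsics(const std::string &targetArch) {
  if (targetArch == "gfx940") {
    return {MfmaIntrinsic::F16_16x16x16_F32, MfmaIntrinsic::F16_32x32x8_F32};
  }
  return {};
}

// Mirrors the `if (mmaAttrs)` guard in ROCMTarget.cpp.
inline bool shouldAddMmaConfig(const std::string &targetArch) {
  return !supportedMfmaIntrinsics(targetArch).empty();
}
```

With this shape, supporting another architecture is a matter of adding
another branch to the lookup, and callers need no changes because they only
test whether the result is empty.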
21 changes: 21 additions & 0 deletions compiler/plugins/target/ROCM/ROCMTargetFeatures.h
@@ -0,0 +1,21 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_PLUGINS_TARGET_ROCM_ROCMTARGETFEATURES_H_
#define IREE_COMPILER_PLUGINS_TARGET_ROCM_ROCMTARGETFEATURES_H_

#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include "llvm/IR/Module.h"
#include "llvm/Target/TargetMachine.h"

namespace mlir::iree_compiler::IREE::HAL {

// Returns the list of supported mma types (mfma/wmma).
ArrayAttr getROCMSupportedMmaAttrs(MLIRContext *context, StringRef targetArch);

} // namespace mlir::iree_compiler::IREE::HAL

#endif // IREE_COMPILER_PLUGINS_TARGET_ROCM_ROCMTARGETFEATURES_H_
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -37,6 +37,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:NVGPUDialect",
@@ -62,6 +63,7 @@ iree_compiler_cc_library(
"GPUTensorTile.cpp",
"GPUTensorTileToSerialLoops.cpp",
"GPUTileReduction.cpp",
"GPUVectorAlloc.cpp",
"GPUVectorDistribution.cpp",
"Passes.cpp",
"VectorReductionToGPU.cpp",
@@ -91,6 +93,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:DestinationStyleOpInterface",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:GPUDialect",
@@ -61,6 +61,7 @@ iree_cc_library(
"GPUTensorTile.cpp"
"GPUTensorTileToSerialLoops.cpp"
"GPUTileReduction.cpp"
"GPUVectorAlloc.cpp"
"GPUVectorDistribution.cpp"
"Passes.cpp"
"VectorReductionToGPU.cpp"
@@ -323,7 +323,9 @@ struct DistributeTransferWriteLayoutAttr final
return failure();
}

- // TODO: Return failure if we need masking.
+ if (writeOp.getMask()) {
+ return failure();
+ }

accessMemory(writeOp, writeOp.getVector(), vectorLayout, rewriter);

153 changes: 153 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
@@ -0,0 +1,153 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/LinalgOpInfo.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/Passes.h"

#define DEBUG_TYPE "iree-codegen-gpu-vector-alloc"

namespace mlir::iree_compiler {

// For optimal performance we always want to copy 128 bits.
static constexpr int copyVectorNumBits = 128;

/// Filter to decide which contraction ops need allocations.
static bool contractOpFilter(Operation *op) {
  auto contractOp = dyn_cast<vector::ContractionOp>(op);
  if (!contractOp) {
    return false;
  }
  SmallVector<unsigned> dims;
  for (auto [idx, type] : llvm::enumerate(contractOp.getIteratorTypesArray())) {
    if (type == vector::IteratorType::parallel) {
      dims.push_back(idx);
    }
  }
  SmallVector<int64_t> shapes;
  contractOp.getIterationBounds(shapes);
  // Don't promote vector*matrix kind of case.
  int numNonUnitParallelLoop = 0;
  for (unsigned parallelDim : dims) {
    if (shapes[parallelDim] != 1) {
      numNonUnitParallelLoop++;
    }
  }
  // TODO: Relax this constraint.
  return numNonUnitParallelLoop > 1 && dims.size() >= 2 && dims.size() <= 3;
}

// Allocates a tensor to copy the vector into a la bufferization.alloc_tensor.
// This allocation is always static as vectors are currently always static
// where this is used.
static FailureOr<Value> allocateTensorForVector(OpBuilder &b, Location loc,
                                                Value vector) {
  VectorType vectorType = llvm::cast<VectorType>(vector.getType());
  if (vectorType.isScalable()) {
    return failure();
  }

  Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());

  RankedTensorType tensorType =
      RankedTensorType::get(vectorType.getShape(), vectorType.getElementType(),
                            sharedMemoryAddrSpace);
  // Vectors are always statically shaped.
  auto allocTensorOp = b.create<bufferization::AllocTensorOp>(
      loc, tensorType, ValueRange{}, Value());
  allocTensorOp.setMemorySpaceAttr(sharedMemoryAddrSpace);

  Value c0 = b.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(vectorType.getRank(), c0);
  SmallVector<bool> inBounds(vectorType.getRank(), true);
  Value copied = b.create<vector::TransferWriteOp>(loc, vector, allocTensorOp,
                                                   indices, inBounds)
                     .getResult();
  // Create a marker for bufferization to keep this tensor in place. This
  // prevents read/write forwarding of the transfers used to do the copy.
  return b
      .create<bufferization::MaterializeInDestinationOp>(copied.getLoc(),
                                                         copied, copied)
      ->getResult(0);
}

static Value readVectorFromTensor(OpBuilder &b, VectorType vectorType,
                                  Value tensor) {
  Value c0 = b.create<arith::ConstantIndexOp>(tensor.getLoc(), 0);
  SmallVector<Value> indices(vectorType.getRank(), c0);
  SmallVector<bool> inBounds(vectorType.getRank(), true);
  return b
      .create<vector::TransferReadOp>(tensor.getLoc(), vectorType, tensor,
                                      indices, inBounds)
      .getResult();
}

namespace {

struct GPUVectorAllocPass : public GPUVectorAllocBase<GPUVectorAllocPass> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<bufferization::BufferizationDialect>();
    registry.insert<gpu::GPUDialect>();
  }
  void runOnOperation() override {
    auto funcOp = getOperation();

    SmallVector<vector::ContractionOp> opsToPromote;
    funcOp.walk([&](vector::ContractionOp op) {
      // Today we only do promotion for certain contractions.
      if (contractOpFilter(op))
        opsToPromote.push_back(op);
    });
    for (vector::ContractionOp contractOp : opsToPromote) {
      OpBuilder builder(contractOp);
      // Promote both of the input operands, excluding the accumulator.
      OpOperand &lhs = contractOp.getLhsMutable();
      FailureOr<Value> lhsRet =
          allocateTensorForVector(builder, contractOp->getLoc(), lhs.get());
      if (failed(lhsRet)) {
        return signalPassFailure();
      }

      OpOperand &rhs = contractOp.getRhsMutable();
      FailureOr<Value> rhsRet =
          allocateTensorForVector(builder, contractOp->getLoc(), rhs.get());
      if (failed(rhsRet)) {
        return signalPassFailure();
      }

      // HACK: Until proper barrier placement is handled later we have to
      // synchronize here.
      builder.create<gpu::BarrierOp>(contractOp->getLoc());

      Value lhsVec =
          readVectorFromTensor(builder, contractOp.getLhsType(), *lhsRet);
      Value rhsVec =
          readVectorFromTensor(builder, contractOp.getRhsType(), *rhsRet);
      lhs.set(lhsVec);
      rhs.set(rhsVec);
    }
  }
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUVectorAlloc() {
  return std::make_unique<GPUVectorAllocPass>();
}

} // namespace mlir::iree_compiler
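
The promotion filter in GPUVectorAlloc.cpp depends only on the iterator
types and iteration bounds of the contraction, so its decision rule can be
checked in isolation. The sketch below restates that rule on plain vectors;
`IteratorKind` and `shouldPromoteContraction` are hypothetical names, and
the size-mismatch guard is an addition for this standalone version rather
than something the pass does.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for vector::IteratorType.
enum class IteratorKind { Parallel, Reduction };

// Restates the rule used by contractOpFilter: promote only when there are
// two or three parallel dimensions and more than one of them has a bound
// other than 1 (which excludes vector*matrix style contractions).
inline bool shouldPromoteContraction(const std::vector<IteratorKind> &iterators,
                                     const std::vector<int64_t> &bounds) {
  // Guard added for this sketch only: inputs must describe the same loops.
  if (iterators.size() != bounds.size()) {
    return false;
  }
  std::size_t numParallel = 0;
  int numNonUnitParallel = 0;
  for (std::size_t i = 0; i < iterators.size(); ++i) {
    if (iterators[i] != IteratorKind::Parallel) {
      continue;
    }
    ++numParallel;
    if (bounds[i] != 1) {
      ++numNonUnitParallel;
    }
  }
  return numNonUnitParallel > 1 && numParallel >= 2 && numParallel <= 3;
}
```

For example, a 16x16x16 matmul with iterators (parallel, parallel,
reduction) passes the filter, while the same iterators with bounds
(1, 16, 16) do not, because only one parallel dimension is non-unit.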
9 changes: 7 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
@@ -100,8 +100,8 @@ createGPUPipeliningPass(bool epiloguePeeling = true, unsigned depth = 1,
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUReduceSharedMemoryBankConflicts(int64_t paddingSizeBits = 128);

- // Creates a pass to tile reduction dimensions and create allocations for some
- // tensor values to use GPU shared memory.
+ // Creates a pass to create allocations for some tensor values to use GPU
+ // shared memory.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUTensorAlloc(GPUPromoteSharedMemPattern promoteSharedMemPattern =
GPUPromoteSharedMemPattern::ContractionOpPattern);
@@ -118,6 +118,11 @@ createGPUTensorTileToSerialLoops();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUTileReductionPass();

// Creates a pass to create allocations for some vector values to use GPU
// shared memory.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUVectorAlloc();

// Distributes vector ops to all threads/warps in a GPU workgroup.
// `getWarpSize` is for deciding the warp size to use; it takes the
// current function containing those vector ops as the argument.
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -103,6 +103,13 @@ def GPUTileReduction :
let constructor = "mlir::iree_compiler::createGPUTileReductionPass()";
}

def GPUVectorAlloc :
InterfacePass<"iree-codegen-gpu-vector-alloc", "mlir::FunctionOpInterface"> {
let summary = "Pass to create allocations for contraction inputs to copy "
"to GPU shared memory";
let constructor = "mlir::iree_compiler::createGPUVectorAlloc()";
}

def VectorReductionToGPU :
InterfacePass<"iree-codegen-vector-reduction-to-gpu", "mlir::FunctionOpInterface"> {
let summary = "Convert vector reduction to GPU ops.";
@@ -28,6 +28,7 @@ iree_lit_test_suite(
"gpu_tensor_tile.mlir",
"gpu_workgroup_swizzle.mlir",
"gpu_tile_reduction.mlir",
"gpu_vector_alloc.mlir",
"gpu_vector_distribution.mlir",
"reduce_bank_conflicts.mlir",
"transform_gpu_distribute_shared_memory.mlir",
@@ -23,6 +23,7 @@ iree_lit_test_suite(
"gpu_tensor_alloc.mlir"
"gpu_tensor_tile.mlir"
"gpu_tile_reduction.mlir"
"gpu_vector_alloc.mlir"
"gpu_vector_distribution.mlir"
"gpu_workgroup_swizzle.mlir"
"reduce_bank_conflicts.mlir"