Skip to content

Commit

Permalink
[GPU] Adding support for opt pass plugins during AMDGPU executable se…
Browse files Browse the repository at this point in the history
…rialization (iree-org#18347)

This commit adds the --iree-hip-pass-plugin-path flag, which allows loading
LLVM pass plugins during executable code generation and serialization.

This is interesting for adding instrumentation via external passes
(e.g., https://github.com/CRobeck/instrument-amdgpu-kernels)

I am creating this for two reasons. First, to see if there is interest
for this. Second, to get help on an error I have. Currently, I am having
some issues with my tests where there is a segfault during dlopen. If
anyone has some clue what may be happening, that would be awesome.

---------

Signed-off-by: Jose M Monsalve Diaz <[email protected]>
  • Loading branch information
josemonsalve2 authored Oct 16, 2024
1 parent 206b60c commit 8568efa
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 2 deletions.
30 changes: 28 additions & 2 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -61,6 +62,11 @@ struct ROCmOptions {
std::string enableROCMUkernels = "none";
bool legacySync = true;

/// List of LLVM opt pass plugins to be loaded during GPU code
/// generation. The plugins are paths to dynamic libraries that
/// are added to the LLVM pass manager.
SmallVector<std::string> passPlugins;

void bindOptions(OptionsBinder &binder) {
using namespace llvm;
static cl::OptionCategory category("HIP HAL Target");
Expand Down Expand Up @@ -95,6 +101,13 @@ struct ROCmOptions {
binder.opt<bool>("iree-hip-legacy-sync", legacySync, cl::cat(category),
cl::desc("Enables 'legacy-sync' mode, which is required "
"for inline execution."));
binder.list<std::string>(
"iree-hip-pass-plugin-path", passPlugins,
cl::desc("LLVM pass plugins are out of tree libraries that implement "
"LLVM opt passes. The library paths passed in this flag are "
"to be passed to the target backend compiler during HIP "
"executable serialization"),
cl::ZeroOrMore, cl::cat(category));
}

LogicalResult verify(mlir::Builder &builder) const {
Expand Down Expand Up @@ -272,7 +285,8 @@ class ROCMTargetBackend final : public TargetBackend {
// ones). Inspired by code section in
// https://github.com/iree-org/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp
static void optimizeModule(llvm::Module &module,
llvm::TargetMachine &targetMachine) {
llvm::TargetMachine &targetMachine,
ArrayRef<std::string> passPlugins) {
llvm::LoopAnalysisManager lam;
llvm::FunctionAnalysisManager fam;
llvm::CGSCCAnalysisManager cgam;
Expand All @@ -296,6 +310,18 @@ class ROCMTargetBackend final : public TargetBackend {
pb.registerLoopAnalyses(lam);
pb.crossRegisterProxies(lam, fam, cgam, mam);

for (const std::string &pluginFileName : passPlugins) {
llvm::Expected<llvm::PassPlugin> pp =
llvm::PassPlugin::Load(pluginFileName);
if (pp) {
pp->registerPassBuilderCallbacks(pb);
} else {
std::string error = "unable to load plugin " + pluginFileName + ": " +
llvm::toString(pp.takeError());
llvm::report_fatal_error(error.c_str());
}
}

llvm::OptimizationLevel ol = llvm::OptimizationLevel::O2;

mpm.addPass(llvm::VerifierPass());
Expand Down Expand Up @@ -522,7 +548,7 @@ class ROCMTargetBackend final : public TargetBackend {
}

// Run LLVM optimization passes.
optimizeModule(*llvmModule, *targetMachine);
optimizeModule(*llvmModule, *targetMachine, options.passPlugins);
if (!serOptions.dumpIntermediatesPath.empty()) {
dumpModuleToPath(serOptions.dumpIntermediatesPath,
serOptions.dumpBaseName, variantOp.getName(),
Expand Down
28 changes: 28 additions & 0 deletions compiler/plugins/target/ROCM/test/opt_pass_plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Builds GPUHello.cpp as a shared library so that it can be loaded at runtime
# as an LLVM pass plugin by the lit test declared at the end of this file.
iree_cc_library(
NAME
GPUHello
SRCS
"GPUHello.cpp"
DEPS
iree::compiler::API::Impl
SHARED
)

# NOTE: this is only required because we want this sample to run on all
# platforms without needing to change the library name (libfoo.so/foo.dll).
# The RUN line in gpu_hello.mlir refers to the library as `libGPUHello`
# followed by the platform's dynamic-library extension, so the "lib" prefix
# and the output name are pinned here.
set_target_properties(iree_compiler_plugins_target_ROCM_test_opt_pass_plugin_GPUHello
PROPERTIES
WINDOWS_EXPORT_ALL_SYMBOLS ON
PREFIX "lib"
OUTPUT_NAME "GPUHello"
)

# Lit suite for the plugin test; gpu_hello.mlir is run with iree-opt and its
# output is verified with FileCheck.
iree_lit_test_suite(
NAME
lit
SRCS
"gpu_hello.mlir"
TOOLS
FileCheck
iree-opt
)
82 changes: 82 additions & 0 deletions compiler/plugins/target/ROCM/test/opt_pass_plugin/GPUHello.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

namespace {

/// Minimal module pass used to test loading of out-of-tree LLVM pass plugins.
///
/// The pass does its work in `runOnModule` and reports to the pass manager
/// whether any analyses must be invalidated as a result.
struct GpuHello final : llvm::PassInfoMixin<GpuHello> {
  /// New-pass-manager entry point.
  ///
  /// Returns `PreservedAnalyses::none()` only when `runOnModule` reports that
  /// it changed the module; otherwise every analysis is still valid and
  /// `PreservedAnalyses::all()` is returned.
  llvm::PreservedAnalyses run(llvm::Module &module,
                              llvm::ModuleAnalysisManager &) {
    const bool modifiedCodeGen = runOnModule(module);
    // Analyses only need to be invalidated when the IR was actually modified.
    if (modifiedCodeGen)
      return llvm::PreservedAnalyses::none();

    return llvm::PreservedAnalyses::all();
  }

  /// Runs the pass over `module`; returns true if the module was modified.
  bool runOnModule(llvm::Module &module);

  // We set `isRequired` to true to keep this pass from being skipped
  // if it has the optnone LLVM attribute.
  static bool isRequired() { return true; }
};

/// Prints a greeting to stderr for the first instruction of the first GPU
/// kernel found in `module`.
///
/// A function counts as a kernel when it has a body and uses either the
/// AMDGPU_KERNEL or the PTX_Kernel calling convention. The message carries the
/// function name and, when the instruction has a debug location, the file,
/// line and column as well. The scan stops after the first message.
///
/// Always returns false: the module is never modified.
bool GpuHello::runOnModule(llvm::Module &module) {
  for (llvm::Function &fn : module) {
    // Only functions with a body can contain an instruction to report.
    if (fn.isIntrinsic() || fn.isDeclaration())
      continue;

    const auto callingConv = fn.getCallingConv();
    const bool isKernel = callingConv == llvm::CallingConv::AMDGPU_KERNEL ||
                          callingConv == llvm::CallingConv::PTX_Kernel;
    if (!isKernel)
      continue;

    for (llvm::BasicBlock &block : fn) {
      // A block without instructions has nothing to report; try the next one.
      if (block.empty())
        continue;

      llvm::Instruction &firstInst = block.front();
      llvm::DILocation *loc = firstInst.getDebugLoc();
      const std::string sourceInfo =
          loc ? llvm::formatv("{0}\t{1}:{2}:{3}", fn.getName(),
                              loc->getFilename(), loc->getLine(),
                              loc->getColumn())
                    .str()
              : fn.getName().str();

      llvm::errs() << "Hello From First Instruction of GPU Kernel: "
                   << sourceInfo << "\n";
      // Done after the first report; the IR was left untouched.
      return false;
    }
  }
  return false;
}

} // end anonymous namespace

/// Builds the plugin descriptor for this library.
///
/// The descriptor names the plugin "gpu-hello" and carries a callback that
/// appends the `GpuHello` pass at the pass builder's "optimizer last"
/// extension point.
llvm::PassPluginLibraryInfo getPassPluginInfo() {
  const auto callback = [](llvm::PassBuilder &pb) {
    // The extension-point callback has no result and needs no captured state.
    // Its trailing arguments (e.g. the optimization level) are unused, so they
    // are accepted and ignored.
    pb.registerOptimizerLastEPCallback(
        [](llvm::ModulePassManager &mpm, auto &&...) {
          mpm.addPass(GpuHello());
        });
  };
  return {LLVM_PLUGIN_API_VERSION, "gpu-hello", LLVM_VERSION_STRING, callback};
}

/// Exported plugin entry point with C linkage and default visibility.
///
/// Hands back the descriptor produced by `getPassPluginInfo()`. This is the
/// symbol name LLVM's pass plugin loader expects a plugin library to provide.
extern "C" LLVM_ATTRIBUTE_WEAK
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT ::llvm::PassPluginLibraryInfo
llvmGetPassPluginInfo() {
  ::llvm::PassPluginLibraryInfo pluginInfo = getPassPluginInfo();
  return pluginInfo;
}
39 changes: 39 additions & 0 deletions compiler/plugins/target/ROCM/test/opt_pass_plugin/gpu_hello.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Verifies that an LLVM pass plugin given via --iree-hip-pass-plugin-path is
// loaded and executed while the HAL transformation pipeline runs for a
// gfx90a target. stderr is merged into stdout so that the message printed by
// the plugin can be matched by FileCheck.

// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=gfx90a \
// RUN: --iree-hip-pass-plugin-path=$IREE_BINARY_DIR/lib/libGPUHello$IREE_DYLIB_EXT %s 2>&1 | FileCheck %s

module attributes {
hal.device.targets = [
#hal.device.target<"hip", [
#hal.executable.target<"rocm", "rocm-hsaco-fb">
]> : !hal.device
]
} {

// Single dispatch that adds two 16-element f32 tensors element-wise.
stream.executable public @add_dispatch_0 {
stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
%0 = tensor.empty() : tensor<16xf32>
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
return
}
}
}

}

// The plugin prints this greeting for the kernel generated from the dispatch.
// CHECK: Hello From First Instruction of GPU Kernel: add_dispatch_0

0 comments on commit 8568efa

Please sign in to comment.