Skip to content

Commit

Permalink
[metal] Enable compiling to Metal library when possible
Browse files Browse the repository at this point in the history
This commit extends Metal compilation to additionally invoke
Metal shader compilers on macOS to further compile MSL into
Metal library, so we don't need to JIT compile during runtime.
This reduces runtime overhead. For other platforms where we
don't have access to offline Metal compilers, we still embed
MSL source code in IREE flatbuffer.
  • Loading branch information
antiagainst committed Jun 14, 2023
1 parent 6e93f00 commit 59e67a7
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_compiler_cc_library(
srcs = ["MetalSPIRVTarget.cpp"],
hdrs = ["MetalSPIRVTarget.h"],
deps = [
":MSLToMetalLib",
":SPIRVToMSL",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
Expand All @@ -33,6 +34,8 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/Utils",
"//runtime/src/iree/schemas:metal_executable_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:LinalgDialect",
Expand All @@ -54,3 +57,15 @@ iree_compiler_cc_library(
"@spirv_cross//:spirv_cross_lib",
],
)

iree_compiler_cc_library(
name = "MSLToMetalLib",
srcs = [
"MSLToMetalLib.cpp",
],
hdrs = ["MSLToMetalLib.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Support",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ iree_cc_library(
SRCS
"MetalSPIRVTarget.cpp"
DEPS
::MSLToMetalLib
::SPIRVToMSL
LLVMSupport
LLVMTargetParser
MLIRAffineDialect
MLIRGPUDialect
MLIRLinalgDialect
Expand Down Expand Up @@ -53,4 +56,17 @@ iree_cc_library(
PUBLIC
)

iree_cc_library(
NAME
MSLToMetalLib
HDRS
"MSLToMetalLib.h"
SRCS
"MSLToMetalLib.cpp"
DEPS
LLVMSupport
MLIRSupport
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MSLToMetalLib.h"

#include <stdlib.h>

#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "iree-msl-to-metal-lib"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {

/// Returns the command to compile the given MSL source file into Metal library.
static std::string getMetalCompileCommand(StringRef mslFile,
StringRef libFile) {
return llvm::Twine("xcrun -sdk macosx metal -c ")
.concat(mslFile)
.concat(" -o - | xcrun -sdk macosx metallib - -o ")
.concat(libFile)
.str();
}

/// Returns the given command via system shell.
static LogicalResult runSystemCommand(StringRef command) {
LLVM_DEBUG(llvm::dbgs() << "Running system command: '" << command << "'\n");
int exitCode = system(command.data());
if (exitCode == 0) return success();
llvm::errs() << "Failed to run system command '" << command
<< "' with error code: " << exitCode << "\n";
return failure();
}

std::unique_ptr<llvm::MemoryBuffer> compileMSLToMetalLib(StringRef mslCode,
StringRef entryPoint) {
SmallString<32> mslFile, airFile, libFile;
int mslFd = 0;
llvm::sys::fs::createTemporaryFile(entryPoint, "metal", mslFd, mslFile);
llvm::sys::fs::createTemporaryFile(entryPoint, "metallib", libFile);
llvm::FileRemover mslRemover(mslFile.c_str());
llvm::FileRemover libRemover(libFile.c_str());

{ // Write input MSL code to the temporary file.
llvm::raw_fd_ostream inputStream(mslFd, /*shouldClose=*/true);
inputStream << mslCode << "\n";
}

std::string command = getMetalCompileCommand(mslFile, libFile);
if (failed(runSystemCommand(command))) return nullptr;

auto fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(libFile, /*isText=*/false);
if (std::error_code error = fileOrErr.getError()) {
llvm::errs() << "Failed to open generated metallib file '" << libFile
<< "' with error: " << error.message();
return nullptr;
}

return std::move(*fileOrErr);
}

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_METALSPIRV_MSLTOMETALLIB_H_
#define IREE_COMPILER_DIALECT_HAL_TARGET_METALSPIRV_MSLTOMETALLIB_H_

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {

// Invokes system commands to compile the given |mslCode| into a Metal library
// and returns the library binary code. |fileName| will be used as a hint for
// creating intermediate files.
std::unique_ptr<llvm::MemoryBuffer> compileMSLToMetalLib(
llvm::StringRef mslCode, llvm::StringRef fileName);

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir

#endif // IREE_COMPILER_DIALECT_HAL_TARGET_METALSPIRV_MSLTOMETALLIB_H_
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@

#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/SPIRV/SPIRVPasses.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MSLToMetalLib.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/schemas/metal_executable_def_builder.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -27,6 +31,13 @@ namespace iree_compiler {
namespace IREE {
namespace HAL {

static llvm::cl::opt<bool> clCompileToMetalLib(
"iree-metal-compile-to-metallib",
llvm::cl::desc(
"Compile to .metallib and embed in IREE deployable flatbuffer if true; "
"otherwise stop at and embed MSL source code"),
llvm::cl::init(true));

static spirv::TargetEnvAttr getMetalTargetEnv(MLIRContext *context) {
using spirv::Capability;
using spirv::Extension;
Expand Down Expand Up @@ -169,29 +180,43 @@ class MetalSPIRVTargetBackend : public TargetBackend {
mslEntryPointNames.push_back(std::move(msl->second));
}

// 3. Compile MSL to MTLLibrary.
// TODO(antiagainst): provide the option to compile the shaders into a
// library and embed in the FlatBuffer. Metal provides APIs for compiling
// shader sources into a MTLLibrary at run-time, but does not provie
// a way to serialize the generated MTLLibrary. The only way available is
// to use command-line tools like `metal` and `metallib`. Likely we need
// to invoke them in C++.

if (!options.dumpBinariesPath.empty()) {
for (auto shader : llvm::enumerate(mslShaders)) {
dumpDataToPath(
options.dumpBinariesPath, options.dumpBaseName,
(variantOp.getName() + std::to_string(shader.index())).str(),
".msl", shader.value().source);
".metal", shader.value().source);
}
}

// 3. Compile MSL to MTLLibrary.
SmallVector<std::unique_ptr<llvm::MemoryBuffer>> metalLibs;
if (clCompileToMetalLib) {
// We need to use offline Metal shader compilers.
// TODO(#14048): The toolchain can also exist on other platforms. Probe
// the PATH instead.
auto hostTriple = llvm::Triple(llvm::sys::getProcessTriple());
if (hostTriple.isMacOSX()) {
for (auto [shader, entryPoint] :
llvm::zip(mslShaders, mslEntryPointNames)) {
std::unique_ptr<llvm::MemoryBuffer> lib =
compileMSLToMetalLib(shader.source, entryPoint);
if (!lib) {
return variantOp.emitError()
<< "failed to compile to MTLLibrary from MSL:\n\n"
<< shader.source << "\n\n";
}
metalLibs.push_back(std::move(lib));
}
}
}

// 4. Pack the MTLLibrary and metadata into a FlatBuffer.
FlatbufferBuilder builder;
iree_hal_metal_ExecutableDef_start_as_root(builder);

auto shaderSourcesRef = builder.createStringVec(llvm::map_range(
mslShaders, [&](const MetalShader &shader) { return shader.source; }));
auto entryPointNamesRef = builder.createStringVec(mslEntryPointNames);
iree_hal_metal_ExecutableDef_entry_points_add(builder, entryPointNamesRef);

iree_hal_metal_ThreadgroupSize_vec_start(builder);
for (auto &shader : mslShaders) {
Expand All @@ -200,13 +225,26 @@ class MetalSPIRVTargetBackend : public TargetBackend {
shader.threadgroupSize.z);
}
auto threadgroupSizesRef = iree_hal_metal_ThreadgroupSize_vec_end(builder);

auto entryPointNamesRef = builder.createStringVec(mslEntryPointNames);

iree_hal_metal_ExecutableDef_entry_points_add(builder, entryPointNamesRef);
iree_hal_metal_ExecutableDef_threadgroup_sizes_add(builder,
threadgroupSizesRef);
iree_hal_metal_ExecutableDef_shader_sources_add(builder, shaderSourcesRef);

if (metalLibs.empty()) {
auto shaderSourcesRef = builder.createStringVec(llvm::map_range(
mslShaders,
[&](const MetalShader &shader) { return shader.source; }));
iree_hal_metal_ExecutableDef_shader_sources_add(builder,
shaderSourcesRef);
} else {
auto refs = llvm::to_vector<8>(llvm::map_range(
metalLibs, [&](const std::unique_ptr<llvm::MemoryBuffer> &buffer) {
return flatbuffers_string_create(builder, buffer->getBufferStart(),
buffer->getBufferSize());
}));
auto libsRef =
flatbuffers_string_vec_create(builder, refs.data(), refs.size());
iree_hal_metal_ExecutableDef_shader_libraries_add(builder, libsRef);
}

iree_hal_metal_ExecutableDef_end_as_root(builder);

// 5. Add the binary data to the target executable.
Expand Down
Loading

0 comments on commit 59e67a7

Please sign in to comment.