Skip to content

Commit

Permalink
Add a foreach_thread_to_workgroup transform (iree-org#10131)
Browse files Browse the repository at this point in the history
This revision adds a transform dialect op that is very similar to `foreach_thread_to_gpu` to allow distributing to the first level of parallelism during codegen.

This is subject to conventions of what values are captured in the workgroup_count region and how these values are passed around in the `stream.cmd.dispatch` op.

For now the convention choosen is that `0` values are passed and the `tile_to_foreach_thread` may only use static sizes.

A followup PR will introduce a new transform dialect op to allow creating the sizes dynamically based on other conventions.
  • Loading branch information
nicolasvasilache authored Aug 31, 2022
1 parent 726c592 commit ff04220
Show file tree
Hide file tree
Showing 12 changed files with 430 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ iree_compiler_cc_library(
":CommonExtensionsOpGen",
"//compiler/src/iree/compiler/Codegen:PassHeaders",
"//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREEDialectsTransforms",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:Pass",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ iree_cc_library(
IREEDialectsTransforms
IREELinalgTransformDialect
LLVMSupport
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRPass
MLIRTransformDialect
iree::compiler::Codegen::Interfaces::BufferizationInterfaces
iree::compiler::Codegen::PassHeaders
iree::compiler::Dialect::HAL::IR
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

#include "CommonExtensions.h"

#include <iree/compiler/Dialect/HAL/IR/HALOps.h>

#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/Passes.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;
Expand Down Expand Up @@ -76,6 +80,9 @@ DiagnosedSilenceableFailure transform_dialect::ApplyPatternsOp::applyToOne(
// IREEBufferizeOp
//===---------------------------------------------------------------------===//

// TODO: Maybe we need both a transform.iree.cpu.bufferize and a
// transform.iree.gpu.bufferize rather than a single common bufferize op?
//
//===---------------------------------------------------------------------===//
// Default allocation functions for CPU backend
// TODO: register the bufferization behavior in a target-specific way.
Expand Down Expand Up @@ -172,6 +179,229 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
});
return DiagnosedSilenceableFailure(failure(res.wasInterrupted()));
}
/// Populate the workgroup_count region of `dispatchOp`.
/// For now, this only supports constant index ops and empty workload operands.
/// Assumes the HAL::ExecutableExportOp is built with an empty region.
static LogicalResult populateWorkgroupCountComputingRegion(
PatternRewriter &rewriter, scf::ForeachThreadOp foreachThreadOp,
HAL::ExecutableExportOp exportOp) {
Location loc = foreachThreadOp.getLoc();
OpBuilder::InsertionGuard g(rewriter);
Region &r = exportOp.getWorkgroupCount();
assert(r.empty() && "expected block-less workgroup_count region");
Block *block = rewriter.createBlock(&r);
// The HAL::DeviceType argument is always the first argument.
block->addArgument(HAL::DeviceType::get(rewriter.getContext()), loc);
rewriter.setInsertionPointToStart(block);

SmallVector<Value> results;
// For now, this assumes that we only pull in constants.
// TODO: Iteratively pull required operations.
for (Value v : foreachThreadOp.getNumThreads()) {
auto op = dyn_cast_or_null<arith::ConstantIndexOp>(v.getDefiningOp());
if (!op) return failure();
results.push_back(
cast<arith::ConstantIndexOp>(rewriter.clone(*op)).getResult());
}
// Pad to `3` to match assumptions hardcoded in IREE.
for (unsigned i = results.size(); i < 3; ++i) {
results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 1));
}
rewriter.create<HAL::ReturnOp>(loc, results);

return success();
}

/// Apply the permutation `perm` to `vals.
/// Return failure if perm is not a permutation.
// TODO: upstream as extraClassDeclaration once stabilized.
template <typename T>
static FailureOr<SmallVector<T>> permute(const SmallVector<T> &vals,
ArrayRef<int64_t> perm) {
if (vals.size() != perm.size()) return failure();
SmallVector<T> result(vals.size());
SmallVector<bool> seen(vals.size());
for (const auto &it : llvm::zip(perm, vals)) {
// Already seen, invalid thread_dim_mapping.
if (seen[std::get<0>(it)]) return failure();
result[std::get<0>(it)] = std::get<1>(it);
seen[std::get<0>(it)] = true;
}
// Some not seen, invalid thread_dim_mapping.
if (!llvm::all_of(seen, [](bool b) { return b; })) return failure();
return result;
}

/// Helper to get apply the `thread_dim_mapping` permutation of a
/// `foreachThreadOp` to `values`.
// TODO: upstream as extraClassDeclaration once stabilized.
template <typename T>
static FailureOr<SmallVector<T>> getPermuted(
scf::ForeachThreadOp foreachThreadOp, const SmallVector<T> &values) {
// Apply mapping permutation if specified.
auto mapping = foreachThreadOp.getThreadDimMapping();
if (mapping && !mapping.empty()) {
auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping));
if (failed(maybePermuted))
return foreachThreadOp->emitError("invalid permutation");
return *maybePermuted;
}
return values;
}

/// Helper to get the `num_threads` of a `foreachThreadOp` after applying the
/// `thread_dim_mapping` permutation.
// TODO: upstream as extraClassDeclaration once stabilized.
static FailureOr<SmallVector<OpFoldResult>> getNumThreads(
OpBuilder &b, scf::ForeachThreadOp foreachThreadOp) {
SmallVector<OpFoldResult> threadCount = foreachThreadOp.getNumThreads();
threadCount.resize(3, b.getIndexAttr(1));
return getPermuted(foreachThreadOp, threadCount);
}

/// Helper to get the thread indices of a `foreachThreadOp` after applying the
/// `thread_dim_mapping` permutation.
// TODO: upstream as extraClassDeclaration once stabilized.
static FailureOr<SmallVector<Value>> getThreadIndices(
OpBuilder &b, scf::ForeachThreadOp foreachThreadOp) {
SmallVector<Value> threadCount = foreachThreadOp.getThreadIndices();
threadCount.resize(3, Value());
return getPermuted(foreachThreadOp, threadCount);
}

//===---------------------------------------------------------------------===//
// Patterns for ForeachThreadToWorkgroup rewrite.
//===---------------------------------------------------------------------===//

LogicalResult rewriteForeachThreadToWorkgroup(
scf::ForeachThreadOp foreachThreadOp,
IREE::HAL::ExecutableExportOp exportOp, PatternRewriter &rewriter) {
if (foreachThreadOp.getNumResults() > 0)
return foreachThreadOp->emitError(
"only bufferized scf.foreach_thread lowers to workgroup");
if (foreachThreadOp.getNumThreads().size() > 3)
return foreachThreadOp->emitError(
"scf.foreach_thread with rank > 3 does not lower to workgroup");

// Step 0. Outline the compute workload region and set up the workload
// operands.
auto maybeWorkgroupCounts = getNumThreads(rewriter, foreachThreadOp);
if (failed(maybeWorkgroupCounts) ||
llvm::any_of(*maybeWorkgroupCounts, [](OpFoldResult ofr) {
return !getConstantIntValue(ofr).has_value();
}))
return foreachThreadOp->emitError(
"unsupported dynamic workgroup_count atm --- need to slice out "
"workgroup_count computation into ExecutableExport::workgroup_count. "
"This region may require arbitrary computations and cannot magically "
"match what the `stream.cmd.dispatch` has already imposed on us at a "
"distance. For now we must specify the number of values properly when "
"applying the topLevel tile_to_foreach_thread_op");

SmallVector<int64_t> workgroupCounts;
for (OpFoldResult ofr : *maybeWorkgroupCounts)
workgroupCounts.push_back(getConstantIntValue(ofr).value());
if (failed(populateWorkgroupCountComputingRegion(rewriter, foreachThreadOp,
exportOp)))
return foreachThreadOp->emitOpError(
"failed to populate workload region for dispatchOp: ")
<< exportOp;

// Step 1. Create the workgroup id and count ops.
Location loc = foreachThreadOp.getLoc();
BlockAndValueMapping bvm;
SmallVector<Value, 8> workgroupIdOps, workgroupCountOps;
for (int64_t rank :
llvm::seq<int64_t>(0, foreachThreadOp.getThreadIndices().size())) {
workgroupIdOps.push_back(
rewriter.create<HAL::InterfaceWorkgroupIDOp>(loc, rank));
workgroupCountOps.push_back(
rewriter.create<HAL::InterfaceWorkgroupCountOp>(loc, rank));
}
bvm.map(foreachThreadOp.getThreadIndices(), workgroupIdOps);
bvm.map(foreachThreadOp.getNumThreads(), workgroupCountOps);

// Step 2. Predicate omitted given unique topLevel scf::ForeachThreadOp.

// Step 3. Move the body of foreachThreadOp.
// Erase the terminator first, it will not be used since we are on buffers.
rewriter.eraseOp(foreachThreadOp.getTerminator());
Block *targetBlock;
Block::iterator insertionPoint;
targetBlock = foreachThreadOp->getBlock();
insertionPoint = Block::iterator(foreachThreadOp);
Block &sourceBlock = foreachThreadOp.getRegion().front();
targetBlock->getOperations().splice(insertionPoint,
sourceBlock.getOperations());

// Step 4. RAUW thread indices to thread ops.
SmallVector<Value> threadIndices =
*getThreadIndices(rewriter, foreachThreadOp);
for (auto it : llvm::zip(threadIndices, workgroupIdOps)) {
if (!std::get<0>(it)) continue;
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
}

// Step 5. Barriers omitted given unique topLevel scf::ForeachThreadOp.

// Step 6. Erase old op.
rewriter.eraseOp(foreachThreadOp);

return success();
}

//===---------------------------------------------------------------------===//
// IREE-specific transformations defined outside of iree_linalg_transform.
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform_dialect::ForeachThreadToWorkgroupOp::applyToOne(
func::FuncOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
state.getTopLevel()->emitOpError(
"requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel to "
"attach the workgroup size information to a nested ExecutableExportOp");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}

IREE::HAL::ExecutableExportOp exportOp;
state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) {
if (op.getSymName() == target.getName()) exportOp = op;
});
if (!exportOp) {
state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}
if (!exportOp.getWorkgroupCount().empty())
return emitDefaultSilenceableFailure(target)
<< "export op must have an empty workgroup count region that the "
"transform fills --- the transform is not applied";

scf::ForeachThreadOp topLevelForeachThreadOp;
auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
return WalkResult::advance();
if (topLevelForeachThreadOp) return WalkResult::interrupt();
topLevelForeachThreadOp = foreachThreadOp;
return WalkResult::advance();
});

if (walkResult.wasInterrupted()) {
state.getTopLevel()->emitOpError(
"could not find a unique topLevel scf.foreach_thread");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}

SimplePatternRewriter rewriter(topLevelForeachThreadOp);
if (failed(rewriteForeachThreadToWorkgroup(topLevelForeachThreadOp, exportOp,
rewriter)))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));

results.assign({target});

return DiagnosedSilenceableFailure(success());
}

#define GET_OP_CLASSES
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.cpp.inc"
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,22 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

namespace mlir {
class DialectRegistry;

namespace func {
class FuncOp;
} // namespace func

namespace scf {
class ForeachThreadOp;
} // namespace scf
} // namespace mlir

#define GET_OP_CLASSES
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace iree_compiler {

/// Registers common transformations that require IREE-specific information
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,59 @@ def IREEBufferizeOp : Op<Transform_Dialect, "iree.bufferize",
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

def ForeachThreadToWorkgroupOp : Op<Transform_Dialect,
"iree.foreach_thread_to_workgroup",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait]> {
let description = [{
Target the whole hal.executable_variant op and rewrite the unique topLevel
scf.foreach_thread to distributed workgroup_id and workgroup_count.

The mapping of threads to workgroup_id is currently one-to-one and in order.
Only **bufferized** scf.foreach_thread are currently supported.
Only scf.foreach_thread distributed to **at most 3 dimensions** are currently
supported.

Return modes:
=============
This operation ignores non-Func ops and drops them in the return.

If no unique scf.foreach_thread topLevel operation is found, then the
transform definitely fails.
If the unique topLevel scf.foreach_thread has results (i.e. tensors), then
the transform definitely fails.

If the unique topLevel scf.foreach_thread maps to a dynamic number of
threads, then the transform definitely fails. This is a temporary
limitation until the backward slice computing scf.foreach_thread.num_threads
can be extracted into the hal::executable_export workgroup_count region.
This region may require arbitrary computations and cannot magically match
what the `stream.cmd.dispatch` has already imposed on us at a distance.
For now we must specify the number of values properly when applying the
topLevel tile_to_foreach_thread_op.

If the unique topLevel scf.foreach_thread operation contained within the
FuncOp referred to by the `target` PDLOperation lowers to workgroup properly,
the transform succeeds. Otherwise the transform definitely fails.

The returned handle points to the same FuncOp operand, consuming it and
producing a new SSA value to satisfy chaining and linearity of the IR
properties.
}];

let arguments = (ins PDL_Operation:$target);
let results = (outs PDL_Operation:$transformed);

let assemblyFormat = "$target attr-dict";
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::func::FuncOp target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}

#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_COMMONEXTENSIONS
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ForeachThreadToGpuAndTranslationInfo :

The mapping of threads to gpu.thread_id is currently one-to-one and in order.
Only **bufferized** scf.foreach_thread are currently supported.
Only scf.foreach_thread distributed to *8at most 3 dimensions** are currently
Only scf.foreach_thread distributed to **at most 3 dimensions** are currently
supported.

Multiple scf.foreach_thread are supported per function in which case, the
Expand All @@ -39,7 +39,10 @@ def ForeachThreadToGpuAndTranslationInfo :
=============
This operation ignores non-Func ops and drops them in the return.

If all the scf::ForeachThread operations contained within the FuncOp
If any scf.foreach_thread with tensors is found, the transform definitely
fails.

If all the scf.foreach_thread operations contained within the FuncOp
referred to by the `target` PDLOperation lower to GPU properly, the
transform succeeds. Otherwise the transform definitely fails.

Expand Down
Loading

0 comments on commit ff04220

Please sign in to comment.