Skip to content

Commit

Permalink
Adding an optional workgroup count region to flow.dispatch.workgroups. (
Browse files Browse the repository at this point in the history
iree-org#9428)

This completes the plumbing as the region is already supported after
outlining.

I hate the syntax, but until we settle on where distribution happens
and potentially split the op into distributed/non-distributed variants
(like the old flow.dispatch.region we had) this will suffice.
  • Loading branch information
benvanik authored Jun 10, 2022
1 parent fdabffc commit 796d768
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 11 deletions.
91 changes: 81 additions & 10 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,8 @@ static Optional<BlockArgument> getBindingArgument(Value v) {
// custom<WorkgroupCountRegion>($body)
//===----------------------------------------------------------------------===//

static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser,
Region &body) {
if (failed(parser.parseOptionalKeyword("workgroups"))) {
// Omitted.
return success();
}

static ParseResult parseWorkgroupCountRegionWithoutKeyword(OpAsmParser &parser,
Region &body) {
SmallVector<OpAsmParser::Argument> args;
if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren,
/*allowType=*/true,
Expand Down Expand Up @@ -226,10 +221,11 @@ static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser,
return success();
}

static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op,
Region &body) {
static void printWorkgroupCountRegionWithoutKeyword(OpAsmPrinter &p,
Operation *op,
Region &body) {
if (body.empty()) return;
p << "workgroups(";
p << "(";
auto args = body.getArguments();
for (unsigned i = 0; i < args.size(); ++i) {
if (i > 0) p << ", ";
Expand All @@ -243,6 +239,38 @@ static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op,
/*printBlockTerminators=*/true);
}

// TODO(benvanik): make these keywords required or consistent.

static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser,
Region &body) {
if (failed(parser.parseOptionalKeyword("workgroups"))) {
return success(); // Omitted.
}
return parseWorkgroupCountRegionWithoutKeyword(parser, body);
}

static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op,
Region &body) {
if (body.empty()) return;
p << "workgroups";
printWorkgroupCountRegionWithoutKeyword(p, op, body);
}

static ParseResult parseDispatchWorkgroupsCountRegion(OpAsmParser &parser,
Region &body) {
if (failed(parser.parseOptionalKeyword("count"))) {
return success(); // Omitted.
}
return parseWorkgroupCountRegionWithoutKeyword(parser, body);
}

static void printDispatchWorkgroupsCountRegion(OpAsmPrinter &p, Operation *op,
Region &body) {
if (body.empty()) return;
p << " count";
printWorkgroupCountRegionWithoutKeyword(p, op, body);
}

//===----------------------------------------------------------------------===//
// flow.dispatch.tie_shape
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -574,6 +602,10 @@ void DispatchWorkgroupsOp::build(OpBuilder &builder, OperationState &state,
workgroupBody->addArgument(type, state.location);
}
assert(std::next(workgroupBody->begin()) == workgroupBody->end());

// NOTE: workgroup count region is empty; callers are expected to populate it
// if they want it.
state.addRegion();
}

static ParseResult parseDispatchWorkgroupBody(OpAsmParser &parser,
Expand Down Expand Up @@ -619,6 +651,38 @@ static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op,
/*printBlockTerminators=*/true);
}

LogicalResult verifyWorkgroupCountRegion(Operation *op, ValueRange workload,
Region &region) {
// Verify the workload operands match the expected capture args.
if (workload.size() != region.getNumArguments()) {
return op->emitOpError()
<< "workload operands and workgroup count args mismatch ("
<< workload.size() << " vs " << region.getNumArguments() << ")";
}
for (auto it : llvm::enumerate(llvm::zip(workload, region.getArguments()))) {
auto workloadValue = std::get<0>(it.value());
auto capturedArg = std::get<1>(it.value());
if (workloadValue.getType() != capturedArg.getType()) {
return op->emitOpError()
<< "workload value " << it.index() << " type mismatch; operand is "
<< workloadValue.getType() << " but region captures "
<< capturedArg.getType();
}
}

// Verify the return ops all provide XYZ values.
for (auto returnOp : region.getOps<IREE::Flow::ReturnOp>()) {
if (returnOp.getNumOperands() != 3 ||
!llvm::all_of(returnOp.getOperandTypes(),
[](Type type) { return type.isIndex(); })) {
return returnOp.emitOpError() << "workgroup count region must return "
"the XYZ dimension counts";
}
}

return success();
}

LogicalResult DispatchWorkgroupsOp::verify() {
Operation *op = getOperation();

Expand All @@ -644,6 +708,13 @@ LogicalResult DispatchWorkgroupsOp::verify() {
if (failed(verifyIOType(type))) return failure();
}

// Workgroup count region is optional.
if (!workgroup_count().empty()) {
if (failed(verifyWorkgroupCountRegion(op, workload(), workgroup_count()))) {
return failure();
}
}

return success();
}

Expand Down
6 changes: 5 additions & 1 deletion compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [
Variadic<AnyType>:$results
);

let regions = (region AnyRegion:$workgroup_body);
let regions = (region
AnyRegion:$workgroup_body,
AnyRegion:$workgroup_count
);

let assemblyFormat = [{
(`[` $workload^ `]`)? ``
Expand All @@ -94,6 +97,7 @@ def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [
custom<DispatchWorkgroupBody>(ref(type($operands)),
ref(type($results)),
$workgroup_body)
`` custom<DispatchWorkgroupsCountRegion>($workgroup_count)
}];

let skipDefaultBuilders = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,32 @@ func.func @inplaceDispatch(
// CHECK: return %[[OUTER_RET0]] : tensor<?x4xf32>
return %0 : tensor<?x4xf32>
}

// -----

// CHECK-LABEL: @dispatchWithCountRegion
// CHECK-SAME: (%[[ARG0:.+]]: tensor<4xi32>)
func.func @dispatchWithCountRegion(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-DAG: %[[WORKGROUP_COUNT_X:.+]] = arith.constant 100
%x = arith.constant 100 : index
// CHECK-DAG: %[[WORKGROUP_COUNT_Y:.+]] = arith.constant 50
%y = arith.constant 50 : index
// CHECK: %[[OUTER_RET0:.+]] = flow.dispatch.workgroups[
// CHECK-SAME: %[[WORKGROUP_COUNT_X]], %[[WORKGROUP_COUNT_Y]]
// CHECK-SAME: ](%[[ARG0]]) : (tensor<4xi32>) -> %[[ARG0]] =
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<4xi32>) -> %arg0 =
// CHECK-NEXT: (%{{.+}}: !flow.dispatch.tensor<readwrite:4xi32>) {
(%arg0_capture: !flow.dispatch.tensor<readwrite:4xi32>) {
// CHECK-NEXT: flow.return
flow.return
// CHECK-NEXT: count(%[[X_CAPTURE:.+]]: index, %[[Y_CAPTURE:.+]]: index)
// CHECK-SAME: -> (index, index, index)
} count(%x_capture: index, %y_capture: index) -> (index, index, index) {
// CHECK-NEXT: %[[Z:.+]] = arith.constant 1
%z = arith.constant 1 : index
// CHECK-NEXT: flow.return %[[X_CAPTURE]], %[[Y_CAPTURE]], %[[Z]]
flow.return %x_capture, %y_capture, %z : index, index, index
}
// CHECK: return %[[OUTER_RET0]] : tensor<4xi32>
return %0 : tensor<4xi32>
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ static LogicalResult outlineDispatchWorkgroupsOp(
regionOp.getLoc(), workgroupFuncOp.getName(),
SymbolRefAttr::get(workgroupFuncOp));

// Move over the workgroup count region, if present.
if (!regionOp.workgroup_count().empty()) {
exportOp.workgroup_count().takeBody(regionOp.workgroup_count());
}

// Finally convert the dispatch region into a dispatch to the outlined func.
return convertToDispatchOp(regionOp, executableOp, exportOp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,19 @@ func.func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor<?x?x1024
// CHECK-NEXT: return %[[RET0]]
return %ret0 : tensor<?x?x1024xf32>
}

// -----

// CHECK-LABEL: func.func @dispatchWithCountRegion
func.func @dispatchWithCountRegion(%arg0: tensor<4xi32>) -> tensor<4xi32> {
%x = arith.constant 100 : index
%y = arith.constant 50 : index
%0 = flow.dispatch.workgroups[%x, %y](%arg0) : (tensor<4xi32>) -> %arg0 =
(%arg0_capture: !flow.dispatch.tensor<readwrite:4xi32>) {
flow.return
} count(%x_capture: index, %y_capture: index) -> (index, index, index) {
%z = arith.constant 1 : index
flow.return %x_capture, %y_capture, %z : index, index, index
}
return %0 : tensor<4xi32>
}

0 comments on commit 796d768

Please sign in to comment.