Integrate llvm-project at 913f21ae5c46 (iree-org#14659)
* Reset third_party/llvm-project:
913f21ae5c460d6fdd73ac7663534e73f8e19044 (2023-08-11 10:17:01 -0700):
[TextAPI] Express MH_SIM_SUPPORT in tbd files.

Cherry-pick:

* llvm/llvm-project@81c326c
* llvm/llvm-project@2b06650
* llvm/llvm-project@69a3c9c
* llvm/llvm-project@1a38843

Revert commit:

* llvm/llvm-project@dad9de0
hanhanW authored Aug 15, 2023
1 parent 80bb1e0 commit 528fdd5
Showing 14 changed files with 40 additions and 59 deletions.
10 changes: 6 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTensorPad.cpp
@@ -99,10 +99,12 @@ void LLVMCPUTensorPadPass::runOnOperation() {
   SmallVector<bool> noFold(linalgOp.getNumDpsInputs(), nofold);
   noFold.append(linalgOp.getNumDpsInits(), false);
 
-  auto options = linalg::LinalgPaddingOptions()
-                     .setPaddingDimensions(paddingDims)
-                     .setPaddingValues(paddingValueAttributes)
-                     .setPackPaddings(noFold);
+  auto options =
+      linalg::LinalgPaddingOptions()
+          .setPaddingDimensions(paddingDims)
+          .setPaddingValues(paddingValueAttributes)
+          .setPackPaddings(noFold)
+          .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);
   FailureOr<linalg::LinalgOp> maybePaddedLinalgOp =
       linalg::padAndHoistLinalgOp(rewriter, linalgOp, options);
   if (failed(maybePaddedLinalgOp)) {
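Note on the hunk above: at this llvm-project pin, `LinalgPaddingOptions` carries a `copyBackOp` field selecting the op used to copy the padded result back into the original destination, and `CopyBackOp::None` keeps this pass's prior behavior of emitting no copy-back. A minimal sketch under those assumptions (enumerator names unverified beyond `None`, which the hunk itself uses):

// Sketch only, not part of this commit. Assumes the upstream header
// mlir/Dialect/Linalg/Transforms/Transforms.h declares CopyBackOp with at
// least a None enumerator (confirmed by the hunk above).
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

mlir::linalg::LinalgPaddingOptions makeNoCopyBackOptions() {
  mlir::linalg::LinalgPaddingOptions options;
  // None suppresses the copy-back entirely, matching the pass's old
  // copy_back = false behavior.
  options.setCopyBackOp(mlir::linalg::LinalgPaddingOptions::CopyBackOp::None);
  return options;
}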
@@ -70,7 +70,7 @@ hal.executable @mma_fused_fp16 {
 // CHECK: llvm.br
 // CHECK-COUNT-2: nvvm.ldmatrix {{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 // CHECK-COUNT-2: nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
 // CHECK: nvvm.cp.async.commit.group
 // CHECK: nvvm.cp.async.wait.group 2
 // CHECK-COUNT-2: nvvm.ldmatrix {{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
@@ -159,7 +159,7 @@ hal.executable @mma_fused_f32 {
 // CHECK: nvvm.ldmatrix{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
 // CHECK-COUNT-4: llvm.extractvalue{{.*}} : !llvm.struct<(i32, i32, i32, i32)>
 // CHECK-COUNT-2: nvvm.mma.sync {{.*}} {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
-// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
 // CHECK: nvvm.cp.async.commit.group
 // CHECK: nvvm.cp.async.wait.group 2
 // CHECK: nvvm.ldmatrix{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
@@ -186,4 +186,4 @@ hal.executable @mma_fused_f32 {
 // CHECK-COUNT: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr
 // CHECK-COUNT: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32>
 // CHECK-COUNT: llvm.fadd {{.*}} : vector<4xf32>
-// CHECK-COUNT: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr
\ No newline at end of file
+// CHECK-COUNT: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr
@@ -471,7 +471,7 @@ hal.executable @mma_fused {
 // CHECK: nvvm.cp.async.wait.group 3
 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)
 // CHECK-COUNT-2: nvvm.wmma.mma
-// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
 // CHECK: nvvm.cp.async.commit.group
 // CHECK: llvm.br
 // CHECK-NOT: nvvm.wmma.mma
@@ -554,7 +554,7 @@ hal.executable @mma_fused_fp16 {
 // CHECK: nvvm.cp.async.wait.group 3
 // CHECK-COUNT-2: nvvm.wmma.load{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)
 // CHECK-COUNT-1: nvvm.wmma.mma
-// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
 // CHECK: nvvm.cp.async.commit.group
 // CHECK: llvm.br
 // CHECK-NOT: nvvm.wmma.mma
@@ -634,7 +634,7 @@ hal.executable @mma_fused_fp16 {
 // CHECK: nvvm.cp.async.wait.group 3
 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)
 // CHECK-COUNT-2: nvvm.wmma.mma
-// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+// CHECK-COUNT-2: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
 // CHECK: nvvm.cp.async.commit.group
 // CHECK: llvm.br
 // CHECK-NOT: nvvm.wmma.mma
@@ -704,7 +704,7 @@ hal.executable @mma_fused_fp16 {
 // CHECK: nvvm.cp.async.wait.group 3
 // CHECK-COUNT-4: nvvm.wmma.load{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)
 // CHECK-COUNT-2: nvvm.wmma.mma
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" {{.*}}, {{.*}}, {{.*}}, {{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
 // CHECK: nvvm.cp.async.commit.group
 // CHECK: llvm.br
 // CHECK-NOT: nvvm.wmma.mma
@@ -42,7 +42,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
 // CHECK: transform.structured.fuse_into_containing_op
 // CHECK: transform.structured.tile_to_scf_for %{{.*}}[0, 0, 0, 16]
 // CHECK: transform.structured.fuse_into_containing_op
-// CHECK: transform.structured.pad %{{.*}} {copy_back = false, pack_paddings = [1, 0, 1], pad_to_multiple_of = [1, 1, 1, 1], padding_dimensions = [0, 1, 2, 3], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
+// CHECK: transform.structured.pad %{{.*}} {copy_back_op = "none", pack_paddings = [1, 0, 1], pad_to_multiple_of = [1, 1, 1, 1], padding_dimensions = [0, 1, 2, 3], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
 // CHECK: transform.structured.match ops{["linalg.fill"]}
 // CHECK: %[[RES:.+]] = get_producer_of_operand %{{.*}}[2]
 // CHECK: transform.structured.rewrite_in_destination_passing_style %[[RES]]
@@ -102,7 +102,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
 
 // CHECK: transform.sequence failures(propagate) {
 // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [] tile_sizes [1, 128, 128](mapping = [#gpu.block<z>, #gpu.block<y>, #gpu.block<x>])
-// CHECK: transform.structured.pad %{{.*}} {copy_back = false, pack_paddings = [0, 1, 1], pad_to_multiple_of = [1, 1, 1, 1], padding_dimensions = [0, 1, 2, 3], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
+// CHECK: transform.structured.pad %{{.*}} {copy_back_op = "none", pack_paddings = [0, 1, 1], pad_to_multiple_of = [1, 1, 1, 1], padding_dimensions = [0, 1, 2, 3], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
 // CHECK: %[[RES:.+]] = get_producer_of_operand %{{.*}}[2]
 // CHECK: transform.structured.rewrite_in_destination_passing_style %[[RES]]
 // CHECK: %[[LHS:.+]] = get_producer_of_operand %{{.*}}[0]
@@ -73,7 +73,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
 // CHECK: transform.structured.fuse_into_containing_op
 // CHECK: transform.iree.populate_workgroup_count_region_using_num_threads_slice
 // CHECK: transform.structured.tile %{{.*}}[0, 0, 16]
-// CHECK: transform.structured.pad %{{.*}} {copy_back = false, pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
+// CHECK: transform.structured.pad %{{.*}} {copy_back_op = "none", pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
 // CHECK: transform.structured.hoist_pad %{{.}} by 1 loops
 // CHECK: transform.structured.insert_slice_to_copy %{{.*}} : (!transform.any_op) -> !transform.any_op
 // CHECK: transform.structured.tile_to_forall_op %{{.*}} num_threads [32, 4] tile_sizes [](mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>])
@@ -135,7 +135,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
 // WITH_OPTIONS: transform.iree.populate_workgroup_count_region_using_num_threads_slice
 // The tiling is affected by td-matmul-strategy-reduc-size: 8.
 // WITH_OPTIONS: transform.structured.tile %{{.*}}[0, 0, 8]
-// WITH_OPTIONS: transform.structured.pad %{{.*}} {copy_back = false, pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
+// WITH_OPTIONS: transform.structured.pad %{{.*}} {copy_back_op = "none", pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
 // WITH_OPTIONS: transform.structured.hoist_pad %{{.}} by 1 loops
 // WITH_OPTIONS: transform.structured.insert_slice_to_copy %{{.*}} : (!transform.any_op) -> !transform.any_op
 // WITH_OPTIONS: transform.structured.tile_to_forall_op %{{.*}} num_threads [64, 2] tile_sizes [](mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>])
@@ -326,7 +326,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
 
 // Make sure we do not canonicalize because the result is still aligned.
 // CHECK-NEXT: transform.structured.pad %tiled_linalg_op
-// CHECK-SAME: copy_back = false
+// CHECK-SAME: copy_back_op = "none"
 // CHECK-SAME: pack_paddings = [1, 1, 1]
 // CHECK-SAME: padding_dimensions = [0, 1, 2]
 // CHECK-SAME: padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
@@ -398,7 +398,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
 
 // Make sure we do not canonicalize if the result is aligned to avoid folding the extract_slice on the iterator.
 // CHECK-NEXT: transform.structured.pad %tiled_linalg_op
-// CHECK-SAME: copy_back = false
+// CHECK-SAME: copy_back_op = "none"
 // CHECK-SAME: pack_paddings = [1, 1, 1]
 // CHECK-SAME: padding_dimensions = [0, 1, 2]
 // CHECK-SAME: padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
@@ -279,7 +279,7 @@ Value mlir::iree_compiler::buildPad(
       b.getI64ArrayAttr(padToMultipleOf),
       b.getI64ArrayAttr(packingDimensions),
       b.getArrayAttr(transposeAttrs),
-      /*copyBack=*/b.getBoolAttr(false))
+      /*copyBack=*/b.getStringAttr("none"))
       ->getResult(0);
 }
 
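Context for the one-line change above: upstream `transform.structured.pad` replaced its boolean `copy_back` attribute with a string `copy_back_op` attribute naming the op used for copy-back, so the strategy builder now passes a StringAttr. A hedged sketch of the mapping; only the "none" value is confirmed by this commit's tests, the other spellings are assumed from upstream:

#include "mlir/IR/Builders.h"

// Sketch: map the old bool onto the assumed new string values. Only "none"
// appears in this commit; "bufferization.materialize_in_destination" is the
// assumed upstream default and "linalg.copy" the assumed alternative.
mlir::StringAttr makeCopyBackOpAttr(mlir::OpBuilder &b, bool oldCopyBack) {
  return b.getStringAttr(
      oldCopyBack ? "bufferization.materialize_in_destination" : "none");
}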
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -402,7 +402,7 @@ ParseResult DispatchRegionOp::parse(OpAsmParser &parser,
 
   result.addRegion(std::move(bodyRegion));
   result.addRegion(std::move(workloadCountRegion));
-  result.addAttribute("operand_segment_sizes",
+  result.addAttribute("operandSegmentSizes",
                       parser.getBuilder().getDenseI32ArrayAttr(
                           {static_cast<int32_t>(allOperands.size()),
                            static_cast<int32_t>(workloadOperands.size())}));
@@ -422,7 +422,7 @@ ParseResult DispatchRegionOp::parse(OpAsmParser &parser,
 
 void DispatchRegionOp::print(OpAsmPrinter &p) {
   SmallVector<StringRef, 1> elidedAttrs;
-  elidedAttrs.push_back("operand_segment_sizes");
+  elidedAttrs.push_back("operandSegmentSizes");
   if (!getWorkload().empty()) {
     p << "[" << getWorkload() << "]";
   }
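The two edits above track MLIR's rename of the generated segment-size attributes from snake_case (`operand_segment_sizes`, `result_segment_sizes`) to camelCase (`operandSegmentSizes`, `resultSegmentSizes`); custom parsers and printers that spell the names out must follow suit. A minimal sketch of the parser side for a hypothetical op with two variadic operand groups:

#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"

// Sketch: the helper is hypothetical. An op with the AttrSizedOperandSegments
// trait records its group sizes under the camelCase name that its generated
// accessors now expect.
static void addSegmentSizes(mlir::OperationState &result, int32_t numA,
                            int32_t numB) {
  mlir::Builder builder(result.getContext());
  result.addAttribute("operandSegmentSizes",  // was operand_segment_sizes
                      builder.getDenseI32ArrayAttr({numA, numB}));
}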
30 changes: 8 additions & 22 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -1824,14 +1824,13 @@ std::pair<unsigned, unsigned> AsyncExecuteOp::getTiedResultsIndexAndLength() {
 }
 
 OperandRange
-AsyncExecuteOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+AsyncExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
   assert(index && index.value() == 0 && "invalid region index");
   return getResourceOperands();
 }
 
 void AsyncExecuteOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // Unconditional control flow into the region and back to the parent, so
   // return the correct RegionSuccessor purely based on the index being None or
   // 0.
@@ -1980,14 +1979,13 @@ LogicalResult AsyncConcurrentOp::verify() {
 }
 
 OperandRange
-AsyncConcurrentOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+AsyncConcurrentOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
   assert(index && index.value() == 0 && "invalid region index");
   return getResourceOperands();
 }
 
 void AsyncConcurrentOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // Unconditional control flow into the region and back to the parent, so
   // return the correct RegionSuccessor purely based on the index being None or
   // 0.
@@ -2792,14 +2790,13 @@ LogicalResult CmdExecuteOp::verify() {
 }
 
 OperandRange
-CmdExecuteOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
+CmdExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
   assert(index && index.value() == 0 && "invalid region index");
   return getResourceOperands();
 }
 
 void CmdExecuteOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // Unconditional control flow into the region and back to the parent, so
   // return the correct RegionSuccessor purely based on the index being None or
   // 0.
@@ -2868,8 +2865,7 @@ LogicalResult CmdSerialOp::verify() {
 }
 
 void CmdSerialOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // Unconditional control flow into the region and back to the parent, so
   // return the correct RegionSuccessor purely based on the index being None or
   // 0.
@@ -2894,8 +2890,7 @@ LogicalResult CmdConcurrentOp::verify() {
 }
 
 void CmdConcurrentOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   // Unconditional control flow into the region and back to the parent, so
   // return the correct RegionSuccessor purely based on the index being None or
   // 0.
@@ -3093,15 +3088,6 @@ LogicalResult BindingSubspanOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// stream.return
-//===----------------------------------------------------------------------===//
-
-MutableOperandRange
-ReturnOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
-  return getOperandsMutable();
-}
-
 //===----------------------------------------------------------------------===//
 // stream.yield
 //===----------------------------------------------------------------------===//
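Two upstream RegionBranchOpInterface changes drive every edit in this file: `getSuccessorEntryOperands` was renamed to `getEntrySuccessorOperands`, and `getSuccessorRegions` dropped its `ArrayRef<Attribute> operands` parameter. A sketch of the post-integrate hook signatures on a hypothetical single-region op; the body logic is assumed from the comments in the hunks, and later LLVM pins change these signatures again:

// Sketch only: MyExecuteOp is hypothetical; using-declarations for the mlir
// namespace are assumed, as in StreamOps.cpp.
OperandRange
MyExecuteOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
  assert(index && index.value() == 0 && "invalid region index");
  return getResourceOperands();  // operands forwarded into the entry region
}

void MyExecuteOp::getSuccessorRegions(
    std::optional<unsigned> index,  // note: no ArrayRef<Attribute> operands
    SmallVectorImpl<RegionSuccessor> &regions) {
  // std::nullopt means branching from the parent op into the region; any
  // index means the region branching back out to the op's results.
  if (index) {
    regions.push_back(RegionSuccessor(getResults()));
  } else {
    regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
  }
}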
14 changes: 4 additions & 10 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -2137,7 +2137,7 @@ def Stream_AsyncExecuteOp : Stream_Op<"async.execute", [
   AttrSizedOperandSegments,
   RecursiveMemoryEffects,
   DeclareOpInterfaceMethods<RegionBranchOpInterface, [
-    "getSuccessorEntryOperands",
+    "getEntrySuccessorOperands",
   ]>,
   SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
   Stream_AffinityOp,
@@ -2232,7 +2232,7 @@ def Stream_AsyncConcurrentOp : Stream_Op<"async.concurrent", [
   AttrSizedOperandSegments,
   RecursiveMemoryEffects,
   DeclareOpInterfaceMethods<RegionBranchOpInterface, [
-    "getSuccessorEntryOperands",
+    "getEntrySuccessorOperands",
   ]>,
   SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
   Stream_AffinityOp,
@@ -2815,7 +2815,7 @@ def Stream_CmdExecuteOp : Stream_Op<"cmd.execute", [
   AttrSizedOperandSegments,
   RecursiveMemoryEffects,
   DeclareOpInterfaceMethods<RegionBranchOpInterface, [
-    "getSuccessorEntryOperands",
+    "getEntrySuccessorOperands",
   ]>,
   SingleBlockImplicitTerminator<"IREE::Stream::YieldOp">,
   Stream_AffinityOp,
@@ -3588,9 +3588,6 @@ def Stream_ReturnOp : Stream_Op<"return", [
     "IREE::Stream::ExecutableExportOp",
   ]>,
   Pure,
-  DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface, [
-    "getMutableSuccessorOperands",
-  ]>,
   ReturnLike,
   Terminator,
 ]> {
@@ -3618,10 +3615,7 @@ def Stream_YieldOp : Stream_Op<"yield", [
     "IREE::Stream::CmdConcurrentOp",
   ]>,
   Pure,
-  DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface, [
-    "getMutableSuccessorOperands",
-  ]>,
-  ReturnLike,
+  DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
   Terminator,
   SameVariadicOperandSize,
   Util_SizeAwareOp,
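These .td deletions pair with the removal of `ReturnOp::getMutableSuccessorOperands` from StreamOps.cpp above: at this pin the `ReturnLike` trait is assumed to imply `RegionBranchTerminatorOpInterface` with a default `getMutableSuccessorOperands` that forwards all operands, which is exactly what the deleted method did, so `stream.return` needs no explicit declaration (while `stream.yield` keeps an explicit one). For reference, the deleted method whose behavior the default is assumed to match:

// The hand-written hook this commit deletes; the ReturnLike-provided default
// is assumed to behave identically (every operand flows to the parent).
MutableOperandRange
ReturnOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
  return getOperandsMutable();
}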
@@ -505,7 +505,7 @@ TraversalResult Explorer::walkIncomingBranchOperands(
   regionOp.getSuccessorRegions(/*index=*/std::nullopt, entrySuccessors);
   for (auto &entrySuccessor : entrySuccessors) {
     if (fn(regionOp->getBlock(),
-           regionOp.getSuccessorEntryOperands(
+           regionOp.getEntrySuccessorOperands(
                entrySuccessor.getSuccessor()->getRegionNumber()))
             .wasInterrupted()) {
       break;
@@ -68,8 +68,8 @@ void copyOperationAttrs(Operation *oldOp, Operation *newOp) {
   for (NamedAttribute oldAttr : oldOp->getAttrs()) {
     // Don't copy segment attributes as these correspond to the number operands,
     // which may be different.
-    if (oldAttr.getName() == "operand_segment_sizes" ||
-        oldAttr.getName() == "result_segment_sizes")
+    if (oldAttr.getName() == "operandSegmentSizes" ||
+        oldAttr.getName() == "resultSegmentSizes")
       continue;
 
     newOp->setAttr(oldAttr.getName(), oldAttr.getValue());
@@ -68,8 +68,7 @@ void linalg::transform::LinalgTransformDialect::initialize() {
 //===---------------------------------------------------------------------===//
 
 void linalg::transform::ScopeOp::getSuccessorRegions(
-    std::optional<unsigned> index, ArrayRef<Attribute> operands,
-    SmallVectorImpl<RegionSuccessor> &regions) {
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
   if (index)
     regions.emplace_back(getResults());
   else
2 changes: 1 addition & 1 deletion third_party/llvm-project
Submodule pointer updated to 913f21ae5c460d6fdd73ac7663534e73f8e19044, the pin named in the commit message above.

0 comments on commit 528fdd5

Please sign in to comment.