Skip to content

Commit

Permalink
dialects: Add memref.expand_shape and memref.collapse_shape ops (#2036)
Browse files Browse the repository at this point in the history
Simple PR to add support for memref.expand_shape and memref.collapse_shape

Note that these two ops have quite a bit of verification happening in
upstream, which this PR does not implement.
  • Loading branch information
JosseVanDelm authored Feb 1, 2024
1 parent cc655ec commit 9864f8d
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/filecheck/dialects/memref/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,25 @@ builtin.module {
}

// CHECK: Alignment attribute 65 is not a power of 2

// -----

"func.func"() ({
%0 = "memref.alloc"() {"alignment" = 64 : i64, "operandSegmentSizes" = array<i32: 0, 0>}: () -> memref<10x2xindex>
%1 = "memref.collapse_shape"(%0) {"reassociation" = [[0 : i32 , 1 : i32]]} : (memref<10x2xindex>) -> memref<20xindex>
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "invalid_reassociation"} : () -> ()



// CHECK: Expected attribute i64

// -----

"func.func"() ({
%0 = "memref.alloc"() {"alignment" = 64 : i64, "operandSegmentSizes" = array<i32: 0, 0>}: () -> memref<20xindex>
%1 = "memref.expand_shape"(%0) {"reassociation" = [[0 : i32 , 1 : i32]]} : (memref<20xindex>) -> memref<2x10xindex>
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "invalid_reassociation"} : () -> ()

// CHECK: Expected attribute i64
4 changes: 4 additions & 0 deletions tests/filecheck/dialects/memref/memref_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ builtin.module {
%16 = "memref.alloc"(%12, %13, %14) {"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 3, 0>} : (index, index, index) -> memref<?x?x?xindex>
%17 = "memref.alloca"(%12) {"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 1, 0>} : (index) -> memref<?xindex>
%18 = "memref.alloca"(%12, %13, %14) {"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 3, 0>} : (index, index, index) -> memref<?x?x?xindex>
%19 = memref.collapse_shape %5 [[0, 1]] : memref<10x2xindex> into memref<20xindex>
%20 = memref.expand_shape %19 [[0, 1]] : memref<20xindex> into memref<2x10xindex>
"memref.dealloc"(%2) : (memref<1xindex>) -> ()
"memref.dealloc"(%5) : (memref<10x2xindex>) -> ()
"memref.dealloc"(%8) : (memref<1xindex>) -> ()
Expand Down Expand Up @@ -74,6 +76,8 @@ builtin.module {
// CHECK-NEXT: %{{.*}} = "memref.alloc"(%{{.*}}, %{{.*}}, %{{.*}}) <{"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 3, 0>}> : (index, index, index) -> memref<?x?x?xindex>
// CHECK-NEXT: %{{.*}} = "memref.alloca"(%{{.*}}) <{"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 1, 0>}> : (index) -> memref<?xindex>
// CHECK-NEXT: %{{.*}} = "memref.alloca"(%{{.*}}, %{{.*}}, %{{.*}}) <{"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 3, 0>}> : (index, index, index) -> memref<?x?x?xindex>
// CHECK-NEXT: %{{.*}} = memref.collapse_shape %{{.*}} [[0 : i64, 1 : i64]] : memref<10x2xindex> into memref<20xindex>
// CHECK-NEXT: %{{.*}} = memref.expand_shape %{{.*}} [[0 : i64, 1 : i64]] : memref<20xindex> into memref<2x10xindex>
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<1xindex>) -> ()
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<10x2xindex>) -> ()
// CHECK-NEXT: "memref.dealloc"(%{{.*}}) : (memref<1xindex>) -> ()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@ func.func @memref_alloca_scope() {
}) : () -> ()
func.return
}

%v0 = "test.op"() : () -> (i32)
%i0 = "test.op"() : () -> (index)
%i1 = "test.op"() : () -> (index)
%m = "test.op"() : () -> (memref<2x3xi32>)
%r = "test.op"() : () -> (memref<10x3xi32>)
memref.store %v0, %m[%i0, %i1] : memref<2x3xi32>
memref.store %v0, %m[%i0, %i1] {"nontemporal" = false} : memref<2x3xi32>
memref.store %v0, %m[%i0, %i1] {"nontemporal" = true} : memref<2x3xi32>
%v1 = memref.load %m[%i0, %i1] : memref<2x3xi32>
%v2 = memref.load %m[%i0, %i1] {"nontemporal" = false} : memref<2x3xi32>
%v3 = memref.load %m[%i0, %i1] {"nontemporal" = true} : memref<2x3xi32>
%r1 = memref.expand_shape %r [[0, 1], [2]] : memref<10x3xi32> into memref<5x2x3xi32>
%r2 = memref.collapse_shape %r [[0, 1]] : memref<10x3xi32> into memref<30xi32>

// CHECK: module {
// CHECK-NEXT: func.func @memref_alloca_scope() {
Expand All @@ -27,10 +31,13 @@ memref.store %v0, %m[%i0, %i1] {"nontemporal" = true} : memref<2x3xi32>
// CHECK-NEXT: %1 = "test.op"() : () -> index
// CHECK-NEXT: %2 = "test.op"() : () -> index
// CHECK-NEXT: %3 = "test.op"() : () -> memref<2x3xi32>
// CHECK-NEXT: %4 = "test.op"() : () -> memref<10x3xi32>
// CHECK-NEXT: memref.store %0, %3[%1, %2] : memref<2x3xi32>
// CHECK-NEXT: memref.store %0, %3[%1, %2] : memref<2x3xi32>
// CHECK-NEXT: memref.store %0, %3[%1, %2] {nontemporal = true} : memref<2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.load %3[%1, %2] : memref<2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.load %3[%1, %2] : memref<2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.load %3[%1, %2] {nontemporal = true} : memref<2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.expand_shape %4 [[0, 1], [2]] : memref<10x3xi32> into memref<5x2x3xi32>
// CHECK-NEXT: %{{.*}} = memref.collapse_shape %4 [[0, 1]] : memref<10x3xi32> into memref<30xi32>
// CHECK-NEXT: }
34 changes: 34 additions & 0 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,38 @@ def from_memref(memref: Operation | SSAValue):
return Rank.build(operands=[memref], result_types=[IndexType()])


ReassociationAttr = ArrayAttr[
ArrayAttr[IntegerAttr[Annotated[IntegerType, IntegerType(64)]]]
]


class AlterShapeOp(IRDLOperation):
src: Operand = operand_def(MemRefType)
result: OpResult = result_def(MemRefType)
reassociation = prop_def(ReassociationAttr)
assembly_format = (
"$src $reassociation attr-dict `:` type($src) `into` type($result)"
)


@irdl_op_definition
class CollapseShapeOp(AlterShapeOp):
"""
https://mlir.llvm.org/docs/Dialects/MemRef/#memrefcollapse_shape-memrefcollapseshapeop
"""

name = "memref.collapse_shape"


@irdl_op_definition
class ExpandShapeOp(AlterShapeOp):
"""
https://mlir.llvm.org/docs/Dialects/MemRef/#memrefexpand_shape-memrefexpandshapeop
"""

name = "memref.expand_shape"


@irdl_op_definition
class ExtractAlignedPointerAsIndexOp(IRDLOperation):
name = "memref.extract_aligned_pointer_as_index"
Expand Down Expand Up @@ -668,6 +700,8 @@ def verify_(self) -> None:
AllocaScopeOp,
AllocaScopeReturnOp,
CopyOp,
CollapseShapeOp,
ExpandShapeOp,
Dealloc,
GetGlobal,
Global,
Expand Down

0 comments on commit 9864f8d

Please sign in to comment.