Skip to content

Commit

Permalink
dialects: Add lowering pass to convert snitch-runtime to external fun…
Browse files Browse the repository at this point in the history
…c.funcs (#1085)

This commit does two things:
* It introduces a lowering pass that lowers all snrt runtime operations
to func.call operations, each with an accompanying external func.func
declaration, for use with mlir-opt and clang -x ir. (Interop tests with
mlir-opt and clang are not included.)
* It slightly refactors the snitch runtime dialect using the recently
introduced generics to reduce code size (the 1D and 2D DMA operations now
each have a base class shared by the regular and `wideptr` variants).
  • Loading branch information
JosseVanDelm authored Jun 8, 2023
1 parent 1c22562 commit 98f75ff
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 95 deletions.
42 changes: 42 additions & 0 deletions tests/filecheck/dialects/snitch_runtime/lower_snrt_to_func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: xdsl-opt -p lower-snrt-to-func %s | filecheck %s
// Exercises the lower-snrt-to-func pass: each snrt op inside @main is expected
// to be rewritten into a "func.call" with the same operands and result types,
// whose callee symbol is the op name with "snrt." replaced by "snrt_".
"builtin.module"() ({
"func.func"() ({
// Runtime Info Getters
%cluster_num = "snrt.cluster_num"() : () -> i32
// CHECK: %cluster_num = "func.call"() {"callee" = @snrt_cluster_num} : () -> i32

// Barriers
"snrt.cluster_hw_barrier"() : () -> ()
// CHECK: "func.call"() {"callee" = @snrt_cluster_hw_barrier} : () -> ()
// DMA functions
"snrt.dma_wait_all"() : () -> ()
// CHECK: "func.call"() {"callee" = @snrt_dma_wait_all} : () -> ()

// 1D transfer, wide-pointer variant: dst/src are i64, size is index.
%dst_64 = "arith.constant"() {"value" = 100 : i64} : () -> i64
%src_64 = "arith.constant"() {"value" = 0 : i64} : () -> i64
%size = "arith.constant"() {"value" = 100 : index} : () -> index
%transfer_id = "snrt.dma_start_1d_wideptr"(%dst_64, %src_64, %size) : (i64, i64, index) -> i32
// CHECK: %transfer_id = "func.call"(%dst_64, %src_64, %size) {"callee" = @snrt_dma_start_1d_wideptr} : (i64, i64, index) -> i32

// 1D transfer, regular variant: dst/src are i32, size is index.
%dst_32 = "arith.constant"() {"value" = 100: i32} : () -> i32
%src_32 = "arith.constant"() {"value" = 0: i32} : () -> i32
%size_2 = "arith.constant"() {"value" = 100: index} : () -> index
%transfer_id_2 = "snrt.dma_start_1d"(%dst_32, %src_32, %size_2) : (i32, i32, index) -> i32
// CHECK: %transfer_id_2 = "func.call"(%dst_32, %src_32, %size_2) {"callee" = @snrt_dma_start_1d} : (i32, i32, index) -> i32
// 2D transfers: both pointer widths share %size_2 and the stride/repeat
// constants below; only the dst/src operand types differ (i64 vs. i32).
%repeat = "arith.constant"() {"value" = 1: index} : () -> index
%src_stride = "arith.constant"() {"value" = 1: index} : () -> index
%dst_stride = "arith.constant"() {"value" = 1: index} : () -> index
%transfer_id_3 = "snrt.dma_start_2d_wideptr"(%dst_64, %src_64, %dst_stride, %src_stride, %size_2, %repeat) : (i64, i64, index, index, index, index) -> i32
// CHECK: %transfer_id_3 = "func.call"(%dst_64, %src_64, %dst_stride, %src_stride, %size_2, %repeat) {"callee" = @snrt_dma_start_2d_wideptr} : (i64, i64, index, index, index, index) -> i32
%transfer_id_4 = "snrt.dma_start_2d"(%dst_32, %src_32, %dst_stride, %src_stride, %size_2, %repeat) : (i32, i32, index, index, index, index) -> i32
// CHECK: %transfer_id_4 = "func.call"(%dst_32, %src_32, %dst_stride, %src_stride, %size_2, %repeat) {"callee" = @snrt_dma_start_2d} : (i32, i32, index, index, index, index) -> i32
"func.return"() : () -> ()
}) {"sym_name" = "main", "function_type" = () -> (), "sym_visibility" = "private"} : () -> ()
// After @main, the output is expected to contain one private func.func
// declaration per called runtime function, with a signature matching the call
// and listed in the same order as the calls above.
// CHECK: func.func private @snrt_cluster_num() -> i32
// CHECK: func.func private @snrt_cluster_hw_barrier() -> ()
// CHECK: func.func private @snrt_dma_wait_all() -> ()
// CHECK: func.func private @snrt_dma_start_1d_wideptr(i64, i64, index) -> i32
// CHECK: func.func private @snrt_dma_start_1d(i32, i32, index) -> i32
// CHECK: func.func private @snrt_dma_start_2d_wideptr(i64, i64, index, index, index, index) -> i32
// CHECK: func.func private @snrt_dma_start_2d(i32, i32, index, index, index, index) -> i32
}) : () -> ()
60 changes: 31 additions & 29 deletions tests/filecheck/dialects/snitch_runtime/snitch_runtime_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
// RUN: xdsl-opt %s | xdsl-opt --print-op-generic | filecheck %s
"builtin.module"() ({
// Barriers
"snrt.cluster_hw_barrier"() : () -> ()
// CHECK: "snrt.cluster_hw_barrier"() : () -> ()
"func.func"() ({
// Barriers
"snrt.cluster_hw_barrier"() : () -> ()
// CHECK: "snrt.cluster_hw_barrier"() : () -> ()

// Runtime Info Getters
%cluster_num = "snrt.cluster_num"() : () -> ui32
// CHECK: %{{.*}} = "snrt.cluster_num"() : () -> ui32

// DMA Operations
%dst_64 = "arith.constant"() {"value" = 100 : ui64} : () -> ui64
%src_64 = "arith.constant"() {"value" = 0 : ui64} : () -> ui64
%size = "arith.constant"() {"value" = 100 : index} : () -> index
%transfer_id = "snrt.dma_start_1d_wideptr"(%dst_64, %src_64, %size) : (ui64, ui64, index) -> ui32
// CHECK: %transfer_id = "snrt.dma_start_1d_wideptr"(%dst_64, %src_64, %size) : (ui64, ui64, index) -> ui32
%dst_32 = "arith.constant"() {"value" = 100: ui32} : () -> ui32
%src_32 = "arith.constant"() {"value" = 0: ui32} : () -> ui32
%size_2 = "arith.constant"() {"value" = 100: index} : () -> index
%transfer_id_2 = "snrt.dma_start_1d"(%dst_32, %src_32, %size_2) : (ui32, ui32, index) -> ui32
// CHECK: %transfer_id_2 = "snrt.dma_start_1d"(%dst_32, %src_32, %size_2) : (ui32, ui32, index) -> ui32
"snrt.dma_wait"(%transfer_id) : (ui32) -> ()
// CHECK: "snrt.dma_wait"(%transfer_id) : (ui32) -> ()
%repeat = "arith.constant"() {"value" = 1: index} : () -> index
%src_stride = "arith.constant"() {"value" = 1: index} : () -> index
%dst_stride = "arith.constant"() {"value" = 1: index} : () -> index
%transfer_id_3 = "snrt.dma_start_2d_wideptr"(%dst_64, %src_64, %dst_stride, %src_stride, %size, %repeat) : (ui64, ui64, index, index, index, index) -> ui32
// CHECK: transfer_id_3 = "snrt.dma_start_2d_wideptr"(%dst_64, %src_64, %dst_stride, %src_stride, %size, %repeat) : (ui64, ui64, index, index, index, index) -> ui32
%transfer_id_4 = "snrt.dma_start_2d"(%dst_32, %src_32, %dst_stride, %src_stride, %size_2, %repeat) : (ui32, ui32, index, index, index, index) -> ui32
// CHECK: transfer_id_4 = "snrt.dma_start_2d"(%dst_32, %src_32, %dst_stride, %src_stride, %size_2, %repeat) : (ui32, ui32, index, index, index, index) -> ui32
"snrt.dma_wait_all"() : () -> ()
// CHECK: "snrt.dma_wait_all"() : () -> ()
// Runtime Info Getters
%cluster_num = "snrt.cluster_num"() : () -> i32
// CHECK: %{{.*}} = "snrt.cluster_num"() : () -> i32

// DMA Operations
%dst_64 = "arith.constant"() {"value" = 100 : i64} : () -> i64
%src_64 = "arith.constant"() {"value" = 0 : i64} : () -> i64
%size = "arith.constant"() {"value" = 100 : index} : () -> index
%transfer_id = "snrt.dma_start_1d_wideptr"(%dst_64, %src_64, %size) : (i64, i64, index) -> i32
// CHECK: %transfer_id = "snrt.dma_start_1d_wideptr"(%dst_64, %src_64, %size) : (i64, i64, index) -> i32
%dst_32 = "arith.constant"() {"value" = 100: i32} : () -> i32
%src_32 = "arith.constant"() {"value" = 0: i32} : () -> i32
%size_2 = "arith.constant"() {"value" = 100: index} : () -> index
%transfer_id_2 = "snrt.dma_start_1d"(%dst_32, %src_32, %size_2) : (i32, i32, index) -> i32
// CHECK: %transfer_id_2 = "snrt.dma_start_1d"(%dst_32, %src_32, %size_2) : (i32, i32, index) -> i32
"snrt.dma_wait"(%transfer_id) : (i32) -> ()
// CHECK: "snrt.dma_wait"(%transfer_id) : (i32) -> ()
%repeat = "arith.constant"() {"value" = 1: index} : () -> index
%src_stride = "arith.constant"() {"value" = 1: index} : () -> index
%dst_stride = "arith.constant"() {"value" = 1: index} : () -> index
%transfer_id_3 = "snrt.dma_start_2d_wideptr"(%dst_64, %src_64, %dst_stride, %src_stride, %size, %repeat) : (i64, i64, index, index, index, index) -> i32
// CHECK: transfer_id_3 = "snrt.dma_start_2d_wideptr"(%dst_64, %src_64, %dst_stride, %src_stride, %size, %repeat) : (i64, i64, index, index, index, index) -> i32
%transfer_id_4 = "snrt.dma_start_2d"(%dst_32, %src_32, %dst_stride, %src_stride, %size_2, %repeat) : (i32, i32, index, index, index, index) -> i32
// CHECK: transfer_id_4 = "snrt.dma_start_2d"(%dst_32, %src_32, %dst_stride, %src_stride, %size_2, %repeat) : (i32, i32, index, index, index, index) -> i32
"snrt.dma_wait_all"() : () -> ()
// CHECK: "snrt.dma_wait_all"() : () -> ()
"func.return"() : () -> ()
}) {"sym_name" = "main", "function_type" = () -> (), "sym_visibility" = "private"} : () -> ()
}) : () -> ()
119 changes: 53 additions & 66 deletions xdsl/dialects/snitch_runtime.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from abc import ABC
from xdsl.irdl import irdl_op_definition, IRDLOperation, Operand, Operation, SSAValue
from xdsl.irdl import (
irdl_op_definition,
IRDLOperation,
Operand,
Operation,
SSAValue,
ConstraintVar,
Attribute,
)
from xdsl.ir import OpResult, Dialect
from xdsl.dialects.builtin import IntegerType, Signedness, IndexType
from typing import Annotated
from xdsl.dialects.builtin import i32, i64, IndexType
from typing import Annotated, Generic, TypeVar

u32 = IntegerType(data=32, signedness=Signedness.UNSIGNED)
u64 = IntegerType(data=64, signedness=Signedness.UNSIGNED)
tx_id = u32
tx_id = i32


class SnitchRuntimeBaseOp(IRDLOperation, ABC):
Expand All @@ -27,12 +33,12 @@ class SnitchRuntimeGetInfo(SnitchRuntimeBaseOp, ABC):
A base class for snitch runtime functions that get a certain value at runtime
"""

result: Annotated[OpResult, u32]
result: Annotated[OpResult, i32]

def __init__(
self,
):
super().__init__(operands=[], result_types=[u32])
super().__init__(operands=[], result_types=[i32])


class SnitchRuntimeBarrier(SnitchRuntimeBaseOp, ABC):
Expand Down Expand Up @@ -64,36 +70,17 @@ class ClusterHwBarrierOp(SnitchRuntimeBarrier):
name = "snrt.cluster_hw_barrier"


@irdl_op_definition
class DmaStart1DWideptrOp(SnitchRuntimeBaseOp):
"""
Initiate an asynchronous 1D DMA transfer with wide 64-bit pointers
"""

name = "snrt.dma_start_1d_wideptr"
src: Annotated[Operand, u64]
dst: Annotated[Operand, u64]
size: Annotated[Operand, IndexType]
transfer_id: Annotated[OpResult, tx_id]

def __init__(
self,
src: Operation | SSAValue,
dst: Operation | SSAValue,
size: Operation | SSAValue,
):
super().__init__(operands=[src, dst, size], result_types=[tx_id])
_T = TypeVar("_T", bound=Attribute)


@irdl_op_definition
class DmaStart1DOp(SnitchRuntimeBaseOp):
class DmaStart1DBaseOp(SnitchRuntimeBaseOp, Generic[_T], ABC):
"""
Initiate an asynchronous 1D DMA transfer
"""

name = "snrt.dma_start_1d"
dst: Annotated[Operand, u32]
src: Annotated[Operand, u32]
T = Annotated[Attribute, ConstraintVar("T"), _T]
dst: Annotated[Operand, T]
src: Annotated[Operand, T]
size: Annotated[Operand, IndexType]
transfer_id: Annotated[OpResult, tx_id]

Expand All @@ -106,15 +93,14 @@ def __init__(
super().__init__(operands=[dst, src, size], result_types=[tx_id])


@irdl_op_definition
class DmaStart2DWideptrOp(SnitchRuntimeBaseOp):
class DmaStart2DBaseOp(SnitchRuntimeBaseOp, Generic[_T], ABC):
"""
Initiate an asynchronous 2D DMA transfer with wide 64-bit pointers
Generic base class for starting asynchronous 2D DMA transfers
"""

name = "snrt.dma_start_2d_wideptr"
dst: Annotated[Operand, u64]
src: Annotated[Operand, u64]
T = Annotated[Attribute, ConstraintVar("T"), _T]
dst: Annotated[Operand, T]
src: Annotated[Operand, T]
dst_stride: Annotated[Operand, IndexType]
src_stride: Annotated[Operand, IndexType]
size: Annotated[Operand, IndexType]
Expand All @@ -137,33 +123,39 @@ def __init__(


@irdl_op_definition
class DmaStart2DOp(SnitchRuntimeBaseOp):
class DmaStart1DOp(DmaStart1DBaseOp[Annotated[Attribute, i32]]):
"""
Initiate an asynchronous 1D DMA transfer with 32-bits pointers
"""

name = "snrt.dma_start_1d"


@irdl_op_definition
class DmaStart1DWideptrOp(DmaStart1DBaseOp[Annotated[Attribute, i64]]):
    """
    Initiate an asynchronous 1D DMA transfer with wide 64-bit pointers.

    The operand/result declarations and the constructor come from
    `DmaStart1DBaseOp`; this subclass only binds the base's generic
    parameter to `i64` and supplies the operation name.
    """

    # Operation name as it appears in the IR.
    name = "snrt.dma_start_1d_wideptr"


@irdl_op_definition
class DmaStart2DOp(DmaStart2DBaseOp[Annotated[Attribute, i32]]):
"""
Initiate an asynchronous 2D DMA transfer
Initiate an asynchronous 2D DMA transfer with 32-bits pointers
"""

name = "snrt.dma_start_2d"
dst: Annotated[Operand, u32]
src: Annotated[Operand, u32]
dst_stride: Annotated[Operand, IndexType]
src_stride: Annotated[Operand, IndexType]
size: Annotated[Operand, IndexType]
repeat: Annotated[Operand, IndexType]
transfer_id: Annotated[OpResult, tx_id]

def __init__(
self,
dst: Operation | SSAValue,
src: Operation | SSAValue,
dst_stride: Operation | SSAValue,
src_stride: Operation | SSAValue,
size: Operation | SSAValue,
repeat: Operation | SSAValue,
):
super().__init__(
operands=[dst, src, dst_stride, src_stride, size, repeat],
result_types=[tx_id],
)

@irdl_op_definition
class DmaStart2DWideptrOp(DmaStart2DBaseOp[Annotated[Attribute, i64]]):
    """
    Initiate an asynchronous 2D DMA transfer with wide 64-bit pointers.

    The operand/result declarations and the constructor come from
    `DmaStart2DBaseOp`; this subclass only binds the base's generic
    parameter to `i64` and supplies the operation name.
    """

    # Operation name as it appears in the IR.
    name = "snrt.dma_start_2d_wideptr"


@irdl_op_definition
Expand All @@ -180,18 +172,13 @@ def __init__(self, transfer_id: Operation | SSAValue):


@irdl_op_definition
class DmaWaitAllOp(SnitchRuntimeBaseOp):
class DmaWaitAllOp(SnitchRuntimeBarrier):
"""
Block until all operations on the DMA cease
"""

name = "snrt.dma_wait_all"

def __init__(
self,
):
super().__init__(operands=[], result_types=[])


SnitchRuntime = Dialect(
[
Expand Down
Loading

0 comments on commit 98f75ff

Please sign in to comment.