Skip to content

Commit

Permalink
dialects: Add optional alignment property to memref.global (#1976)
Browse files Browse the repository at this point in the history
This is supposed to be a i64 value that has to be a power of two.

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
JosseVanDelm and superlopuh authored Jan 16, 2024
1 parent 252cf2d commit 82b5663
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 2 deletions.
15 changes: 15 additions & 0 deletions tests/filecheck/dialects/memref/invalid_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: xdsl-opt %s --parsing-diagnostics --verify-diagnostics --split-input-file | filecheck %s

builtin.module {
"memref.global"() {"alignment" = 64 : i32, "sym_name" = "wrong_alignment_type", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>, "sym_visibility" = "public"} : () -> ()
}

// CHECK: Expected attribute i64 but got i32

// -----

builtin.module {
"memref.global"() {"alignment" = 65 : i64, "sym_name" = "non_power_of_two_alignment", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>, "sym_visibility" = "public"} : () -> ()
}

// CHECK: Alignment attribute 65 is not a power of 2
2 changes: 2 additions & 0 deletions tests/filecheck/dialects/memref/memref_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ builtin.module {
func.return
}
"memref.global"() {"sym_name" = "g", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>, "sym_visibility" = "public"} : () -> ()
"memref.global"() {"alignment" = 64 : i64, "sym_name" = "g_with_alignment", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>, "sym_visibility" = "public"} : () -> ()
func.func private @memref_test() {
%0 = "memref.get_global"() {"name" = @g} : () -> memref<1xindex>
%1 = arith.constant 0 : index
Expand Down Expand Up @@ -46,6 +47,7 @@ builtin.module {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: "memref.global"() <{"sym_name" = "g", "sym_visibility" = "public", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>}> : () -> ()
// CHECK-NEXT: "memref.global"() <{"sym_name" = "g_with_alignment", "sym_visibility" = "public", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>, "alignment" = 64 : i64}> : () -> ()
// CHECK-NEXT: func.func private @memref_test() {
// CHECK-NEXT: %{{.*}} = "memref.get_global"() <{"name" = @g}> : () -> memref<1xindex>
// CHECK-NEXT: %{{.*}} = arith.constant 0 : index
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: xdsl-opt %s | mlir-opt --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s

"builtin.module"() ({
"memref.global"() {"alignment" = 64 : i64, "sym_name" = "g_with_alignment", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>, "sym_visibility" = "public"} : () -> ()
"memref.global"() {"sym_name" = "g", "type" = memref<1xindex>, "initial_value" = dense<0> : tensor<1xindex>, "sym_visibility" = "public"} : () -> ()
"func.func"() ({
%0 = "memref.get_global"() {"name" = @g} : () -> memref<1xindex>
Expand Down Expand Up @@ -37,6 +38,7 @@


// CHECK: "builtin.module"() ({
// CHECK-NEXT: "memref.global"() <{"alignment" = 64 : i64, "initial_value" = dense<0> : tensor<1xindex>, "sym_name" = "g_with_alignment", "sym_visibility" = "public", "type" = memref<1xindex>}> : () -> ()
// CHECK-NEXT: "memref.global"() <{"initial_value" = dense<0> : tensor<1xindex>, "sym_name" = "g", "sym_visibility" = "public", "type" = memref<1xindex>}> : () -> ()
// CHECK-NEXT: "func.func"() <{"function_type" = () -> (), "sym_name" = "memref_test", "sym_visibility" = "private"}> ({
// CHECK-NEXT: %0 = "memref.get_global"() <{"name" = @g}> : () -> memref<1xindex>
Expand Down
23 changes: 22 additions & 1 deletion tests/utils/test_bitwise_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import pytest

from xdsl.utils.bitwise_casts import convert_f32_to_u32, convert_u32_to_f32
from xdsl.utils.bitwise_casts import (
convert_f32_to_u32,
convert_u32_to_f32,
is_power_of_two,
)


# http://bartaz.github.io/ieee754-visualization/
Expand All @@ -18,3 +22,20 @@
def test_float_bitwise_casts(i: int, f: float):
assert convert_f32_to_u32(f) == i
assert struct.pack(">f", convert_u32_to_f32(i)) == struct.pack(">f", f)


@pytest.mark.parametrize(
"i, p",
[
(-2, False),
(-1, False),
(0, False),
(1, True),
(2, True),
(3, False),
(4, True),
(5, False),
],
)
def test_is_power_of_two(i: int, p: bool):
assert is_power_of_two(i) == p
17 changes: 16 additions & 1 deletion xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast
from typing import TYPE_CHECKING, Annotated, Generic, TypeAlias, TypeVar, cast

from typing_extensions import Self

Expand Down Expand Up @@ -57,6 +57,7 @@
IsTerminator,
SymbolOpInterface,
)
from xdsl.utils.bitwise_casts import is_power_of_two
from xdsl.utils.deprecation import deprecated_constructor
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
Expand Down Expand Up @@ -472,6 +473,7 @@ class Global(IRDLOperation):
sym_visibility: StringAttr = prop_def(StringAttr)
type: Attribute = prop_def(Attribute)
initial_value: Attribute = prop_def(Attribute)
alignment = opt_prop_def(IntegerAttr[Annotated[IntegerType, IntegerType(64)]])

traits = frozenset([SymbolOpInterface()])

Expand All @@ -484,20 +486,33 @@ def verify_(self) -> None:
"Global initial value is expected to be a "
"dense type or an unit attribute"
)
if self.alignment is not None:
assert isinstance(self.alignment, IntegerAttr)
alignment_value = self.alignment.value.data
# Alignment has to be a power of two
if not (is_power_of_two(alignment_value)):
raise VerifyException(
f"Alignment attribute {alignment_value} is not a power of 2"
)

@staticmethod
def get(
sym_name: StringAttr,
sym_type: Attribute,
initial_value: Attribute,
sym_visibility: StringAttr = StringAttr("private"),
alignment: int | IntegerAttr[IntegerType] | None = None,
) -> Global:
if isinstance(alignment, int):
alignment = IntegerAttr.from_int_and_width(alignment, 64)

return Global.build(
properties={
"sym_name": sym_name,
"type": sym_type,
"initial_value": initial_value,
"sym_visibility": sym_visibility,
"alignment": alignment,
}
)

Expand Down
8 changes: 8 additions & 0 deletions xdsl/utils/bitwise_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,11 @@ def convert_u32_to_f32(value: int) -> float:
raw_int = ctypes.c_uint32(value)
raw_float = ctypes.c_float.from_address(ctypes.addressof(raw_int)).value
return raw_float


def is_power_of_two(value: int) -> bool:
"""
Return True if an integer is a power of two.
Powers of two have only one bit set to one
"""
return (value > 0) and (value.bit_count() == 1)

0 comments on commit 82b5663

Please sign in to comment.