Skip to content

Commit

Permalink
Add scalar support to signature reflection and iree-run-mlir
Browse files Browse the repository at this point in the history
iree-run-mlir still only has support for i32, since lots of other parts don't play nicely with any other scalar type.

This at least makes it easier to play with small example IR, whereas tensors generate pages of dispatch goo.

Note: I fixed the split-input-file in materialize_exported_reflection.mlir. It was using four dashes instead of 5, which just fails silently. Man I do not like that testing functionality...

Also includes some extra log lines in error paths in the IREE HAL.

Closes iree-org#551

PiperOrigin-RevId: 291846417
  • Loading branch information
GMNGeoffrey authored and copybara-github committed Jan 28, 2020
1 parent 6d61b73 commit 654dbd1
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 50 deletions.
10 changes: 7 additions & 3 deletions docs/function_abi.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ signature ::= 'I' length-prefixed(type-sequence)
'R' length-prefixed(type-sequence)
type-sequence ::= (arg-result-type)*
arg-result-type ::= buffer-type | ref-object-type | unrecognized-type
buffer-type ::= 'B' length-prefixed(scalar-type? dim*)
scalar-type ::= 't' (
arg-result-type ::= buffer-type
| ref-object-type
| scalar-type
| unrecognized-type
buffer-type ::= 'B' length-prefixed(scalar-element-type? dim*)
scalar-type ::= 'S' length-prefixed(scalar-element-type?)
scalar-element-type ::= 't' (
'0' # IEEE float32 (default if not specified)
| '1' # IEEE float16
| '2' # IEEE float64
Expand Down
23 changes: 22 additions & 1 deletion iree/base/signature_mangle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,16 @@ void RawSignatureMangler::AddShapedNDBuffer(
item_builder.AppendTo(builder_, 'B');
}

void RawSignatureMangler::AddScalar(AbiConstants::ScalarType type) {
SignatureBuilder item_builder;
// Fields:
// 't': scalar type code
if (static_cast<unsigned>(type) != 0) {
item_builder.Integer(static_cast<unsigned>(type), 't');
}
item_builder.AppendTo(builder_, 'S');
}

// -----------------------------------------------------------------------------
// RawSignatureParser
// -----------------------------------------------------------------------------
Expand All @@ -199,9 +209,20 @@ void RawSignatureParser::Description::ToString(std::string& s) const {
absl::StrAppend(&s, "]>");
break;
}
case Type::kRefObject:
case Type::kRefObject: {
absl::StrAppend(&s, "RefObject<?>");
break;
}
case Type::kScalar: {
const char* type_name = "!BADTYPE!";
unsigned type_u = static_cast<unsigned>(scalar.type);
if (type_u >= 0 && type_u <= AbiConstants::kScalarTypeNames.size()) {
type_name =
AbiConstants::kScalarTypeNames[static_cast<unsigned>(type_u)];
}
absl::StrAppend(&s, type_name);
break;
}
default:
absl::StrAppend(&s, "!UNKNOWN!");
}
Expand Down
36 changes: 36 additions & 0 deletions iree/base/signature_mangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class RawSignatureMangler {
void AddShapedNDBuffer(AbiConstants::ScalarType element_type,
absl::Span<const int> shape);

void AddScalar(AbiConstants::ScalarType type);

const SignatureBuilder& builder() const { return builder_; }

private:
Expand All @@ -196,6 +198,7 @@ class RawSignatureParser {
enum class Type {
kBuffer = 0,
kRefObject = 1,
kScalar = 2,
};

// Description of an input or result.
Expand All @@ -211,6 +214,10 @@ class RawSignatureParser {
struct {
AbiConstants::ScalarType scalar_type;
} buffer;
// Further details for Type == kScalar.
struct {
AbiConstants::ScalarType type;
} scalar;
};

// Human readable description.
Expand Down Expand Up @@ -268,6 +275,11 @@ class RawSignatureParser {
return;
}
break;
case 'S':
if (!FillScalar(d, SignatureParser(item_parser.nested()))) {
return;
}
break;
default:
SetError("Unrecognized raw tag");
return;
Expand All @@ -278,6 +290,30 @@ class RawSignatureParser {
}
}

bool FillScalar(Description& d, SignatureParser p) {
d.type = Type::kScalar;
d.buffer.scalar_type = AbiConstants::ScalarType::kIeeeFloat32; // Default
while (!p.end_or_error()) {
switch (p.tag()) {
case 't':
if (p.ival() < 0 ||
p.ival() >
static_cast<int>(AbiConstants::ScalarType::kMaxScalarType)) {
SetError("Illegal ScalarType code");
return false;
}
d.buffer.scalar_type =
static_cast<AbiConstants::ScalarType>(p.ival());
break;
default:
SetError("Unrecognized scalar field tag");
return false;
}
p.Next();
}
return true;
}

bool FillBuffer(Description& d, SignatureParser p) {
d.type = Type::kBuffer;
d.buffer.scalar_type = AbiConstants::ScalarType::kIeeeFloat32; // Default
Expand Down
32 changes: 30 additions & 2 deletions iree/base/signature_mangle_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ TEST(RawSignatureManglerTest, FullBuffer) {
EXPECT_EQ("B13!t2d-1d128d64", sm.builder().encoded());
}

TEST(RawSignatureManglerTest, DefaultScalar) {
RawSignatureMangler sm;
sm.AddScalar(AbiConstants::ScalarType::kIeeeFloat32);
EXPECT_EQ("S1!", sm.builder().encoded());
}

TEST(RawSignatureManglerTest, FullScalar) {
RawSignatureMangler sm;
sm.AddScalar(AbiConstants::ScalarType::kSint32);
EXPECT_EQ("S3!t6", sm.builder().encoded());
}

TEST(RawSignatureManglerTest, AnyRef) {
RawSignatureMangler sm;
sm.AddAnyReference();
Expand Down Expand Up @@ -282,25 +294,41 @@ TEST(RawSignatureParserTest, DynamicNdArrayBuffer) {
EXPECT_EQ("(Buffer<float32[?x128x64]>) -> (Buffer<sint32[?x8x64]>)", *s);
}

TEST(RawSignatureParserTest, Scalar) {
RawSignatureMangler inputs;
inputs.AddScalar(AbiConstants::ScalarType::kSint32);
RawSignatureMangler results;
results.AddScalar(AbiConstants::ScalarType::kIeeeFloat64);

auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results);
EXPECT_EQ("I6!S3!t6R6!S3!t2", sig.encoded());

RawSignatureParser p;
auto s = p.FunctionSignatureToString(sig.encoded());
ASSERT_TRUE(s) << *p.GetError();
EXPECT_EQ("(sint32) -> (float64)", *s);
}

TEST(RawSignatureParserTest, AllTypes) {
RawSignatureMangler inputs;
inputs.AddAnyReference();
std::vector<int> dims = {-1, 128, 64};
inputs.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat32,
absl::MakeSpan(dims));
inputs.AddScalar(AbiConstants::ScalarType::kSint32);
RawSignatureMangler results;
std::vector<int> dims2 = {32, -1, 64};
results.AddShapedNDBuffer(AbiConstants::ScalarType::kUint64,
absl::MakeSpan(dims2));

auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results);
EXPECT_EQ("I18!O1!B11!d-1d128d64R17!B13!t11d32d-1d64", sig.encoded());
EXPECT_EQ("I23!O1!B11!d-1d128d64S3!t6R17!B13!t11d32d-1d64", sig.encoded());

RawSignatureParser p;
auto s = p.FunctionSignatureToString(sig.encoded());
ASSERT_TRUE(s) << *p.GetError();
EXPECT_EQ(
"(RefObject<?>, Buffer<float32[?x128x64]>) -> "
"(RefObject<?>, Buffer<float32[?x128x64]>, sint32) -> "
"(Buffer<uint64[32x?x64]>)",
*s);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,28 @@ llvm::Optional<RawSignatureMangler> mangleTensorType(TensorType t) {
}

RawSignatureMangler mangler;
// Tensors map to buffers in the ABI.
mangler.AddShapedNDBuffer(*scalarType, absl::MakeConstSpan(dims));
return mangler;
}

llvm::Optional<RawSignatureMangler> mangleScalarType(Type t) {
auto mappedType = mapScalarType(t);
if (!mappedType) return llvm::None;
RawSignatureMangler mangler;
mangler.AddScalar(*mappedType);
return mangler;
}

StringAttr mangleType(Builder builder, Type type, char tag) {
SignatureBuilder fBuilder;
auto mangledType = mangleScalarType(type);
if (auto tensorType = type.dyn_cast<TensorType>()) {
// Tensors map to buffers in the ABI.
auto mangledTensor = mangleTensorType(tensorType);
if (!mangledTensor) return nullptr;
mangledTensor->builder().AppendTo(fBuilder, tag);
return builder.getStringAttr(fBuilder.encoded());
mangledType = mangleTensorType(tensorType);
}

return nullptr;
if (!mangledType) return nullptr;
mangledType->builder().AppendTo(fBuilder, tag);
return builder.getStringAttr(fBuilder.encoded());
}

StringAttr unrecognizedTypeAttr(Builder builder, char tag) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-materialize-exported-reflection %s | IreeFileCheck %s

// ----
// -----
// CHECK-LABEL: func @notExported
// CHECK-NOT: iree.reflection
func @notExported(%arg0 : tensor<4x4xi64>) -> tensor<4x4xi64> {
return %arg0 : tensor<4x4xi64>
}

// ----
// -----
// CHECK-LABEL: func @exportedTensor
// CHECK-SAME: iree.reflection = {f_partial = "I10!B7!t7d4d4"}
// CHECK-SAME: iree.reflection = {f_partial = "I10!B7!t7d5d5"}
Expand All @@ -18,7 +32,7 @@ func @exportedTensor(%arg0 : tensor<4x4xi64>, %arg1 : tensor<5x5xi64>) -> tensor
return %arg1 : tensor<5x5xi64>
}

// ----
// -----
// CHECK-LABEL: func @unrecognizedArgument
// CHECK-SAME: iree.reflection = {f_partial = "I4!U1!"}
// expected-warning @+1 {{Argument #0 of function unrecognizedArgument is not a recognized public ABI type and the function may not be invokable by standard tools}}
Expand All @@ -28,7 +42,7 @@ func @unrecognizedArgument(%arg0 : i1) -> ()
return
}

// ----
// -----
// CHECK-LABEL: func @unrecognizedResult
// CHECK-SAME: iree.reflection = {f_partial = "R4!U1!"}
// expected-warning @+1 {{Result #0 of function unrecognizedResult is not a recognized public ABI type and the function may not be invokable by standard tools}}
Expand All @@ -39,74 +53,83 @@ func @unrecognizedResult() -> (i1)
return %0 : i1
}

// ----
// -----
// CHECK-LABEL: func @dynamicDim
// CHECK-SAME: iree.reflection = {f_partial = "I11!B8!t7d-1d4"}
func @dynamicDim(%arg0 : tensor<?x4xi64>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @scalari32
// CHECK-SAME: iree.reflection = {f_partial = "I6!S3!t6"}
func @scalari32(%arg0 : i32) -> () attributes {iree.module.export}
{
return
}

// -----
// CHECK-LABEL: func @tensorFloat32
// CHECK-SAME: iree.reflection = {f_partial = "I6!B3!d1"}
func @tensorFloat32(%arg0 : tensor<1xf32>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @tensorFloat64
// CHECK-SAME: iree.reflection = {f_partial = "I8!B5!t2d1"}
func @tensorFloat64(%arg0 : tensor<1xf64>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @tensorFloat16
// CHECK-SAME: iree.reflection = {f_partial = "I8!B5!t1d1"}
func @tensorFloat16(%arg0 : tensor<1xf16>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @tensorBfloat16
// CHECK-SAME: iree.reflection = {f_partial = "I8!B5!t3d1"}
func @tensorBfloat16(%arg0 : tensor<1xbf16>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @tensorSint8
// CHECK-SAME: iree.reflection = {f_partial = "I8!B5!t4d1"}
func @tensorSint8(%arg0 : tensor<1xi8>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @tensorSint16
// CHECK-SAME: iree.reflection = {f_partial = "I8!B5!t5d1"}
func @tensorSint16(%arg0 : tensor<1xi16>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @tensorSint32
// CHECK-SAME: iree.reflection = {f_partial = "I8!B5!t6d1"}
func @tensorSint32(%arg0 : tensor<1xi32>) -> () attributes {iree.module.export}
{
return
}

// ----
// -----
// CHECK-LABEL: func @tensorSint64
// CHECK-SAME: iree.reflection = {f_partial = "I8!B5!t7d1"}
func @tensorSint64(%arg0 : tensor<1xi64>) -> () attributes {iree.module.export}
{
return
}

6 changes: 4 additions & 2 deletions iree/hal/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,17 @@ IREE_API_EXPORT iree_status_t iree_hal_buffer_map(
IREE_TRACE_SCOPE0("iree_hal_buffer_map");

if (!out_mapped_memory) {
LOG(ERROR) << "output mapped memory not set";
return IREE_STATUS_INVALID_ARGUMENT;
}
std::memset(out_mapped_memory, 0, sizeof(*out_mapped_memory));

auto* buffer_handle = reinterpret_cast<Buffer*>(buffer);
if (!buffer_handle) {
if (!buffer) {
LOG(ERROR) << "buffer not set";
return IREE_STATUS_INVALID_ARGUMENT;
}

auto* buffer_handle = reinterpret_cast<Buffer*>(buffer);
IREE_API_ASSIGN_OR_RETURN(
auto mapping, buffer_handle->MapMemory<uint8_t>(
static_cast<MemoryAccessBitfield>(memory_access),
Expand Down
1 change: 1 addition & 0 deletions iree/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ cc_binary(
"//iree/modules/hal",
"//iree/vm",
"//iree/vm:bytecode_module",
"//iree/vm:value",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
Expand Down
Loading

0 comments on commit 654dbd1

Please sign in to comment.