Skip to content

Commit

Permalink
Add utility function for EMITC_TYPEDEF_STRUCT (iree-org#9730)
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-camp authored Jul 20, 2022
1 parent d6dc9b7 commit 59c6678
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1563,10 +1563,10 @@ class ExportOpConversion : public OpConversionPattern<IREE::VM::ExportOp> {
return failure();
}

auto generateStructBody = [&funcOp](
ArrayRef<Type> types,
StringRef prefix) -> FailureOr<std::string> {
std::string structBody;
auto generateStructFields = [&funcOp](ArrayRef<Type> types,
StringRef prefix)
-> FailureOr<SmallVector<emitc_builders::StructField>> {
SmallVector<emitc_builders::StructField> result;

for (auto pair : llvm::enumerate(types)) {
Optional<std::string> cType = getCType(pair.value());
Expand All @@ -1575,11 +1575,12 @@ class ExportOpConversion : public OpConversionPattern<IREE::VM::ExportOp> {
"c type in argument struct declaration.";
return failure();
}
structBody += cType.getValue() + " " + prefix.str() +
std::to_string(pair.index()) + ";";

auto fieldName = prefix.str() + std::to_string(pair.index());
result.push_back({cType.getValue(), fieldName});
}

return structBody;
return result;
};

// TODO(simon-camp): Clean up; We generate calls to a macro that defines
Expand All @@ -1588,22 +1589,15 @@ class ExportOpConversion : public OpConversionPattern<IREE::VM::ExportOp> {

// To prevent scoping issues we prefix the struct name with module and
// function name.
auto typedefStruct = [&rewriter, &newFuncOp, &loc](std::string structName,
std::string structBody) {
auto ctx = rewriter.getContext();

auto typedefStruct = [&rewriter, &newFuncOp, &loc](
std::string structName,
ArrayRef<emitc_builders::StructField> fields) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(newFuncOp.getOperation());

rewriter.create<emitc::CallOp>(
/*location=*/loc,
/*type=*/TypeRange{},
/*callee=*/StringAttr::get(ctx, "EMITC_TYPEDEF_STRUCT"),
/*args=*/
ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, structName),
emitc::OpaqueAttr::get(ctx, structBody)}),
/*templateArgs=*/ArrayAttr{},
/*operands=*/ArrayRef<Value>{});
emitc_builders::structDefinition(/*builder=*/rewriter, /*location=*/loc,
/*structName=*/structName,
/*fields=*/fields);
};

FunctionType funcType = vmAnalysis.getValue().get().originalFunctionType;
Expand All @@ -1614,8 +1608,7 @@ class ExportOpConversion : public OpConversionPattern<IREE::VM::ExportOp> {
const bool needArgumentStruct = funcType.getNumInputs() > 0;

if (needArgumentStruct) {
FailureOr<std::string> structBody =
generateStructBody(funcType.getInputs(), "arg");
auto structBody = generateStructFields(funcType.getInputs(), "arg");
if (failed(structBody)) {
return failure();
}
Expand All @@ -1628,8 +1621,7 @@ class ExportOpConversion : public OpConversionPattern<IREE::VM::ExportOp> {
const bool needResultStruct = funcType.getNumResults() > 0;

if (needResultStruct) {
FailureOr<std::string> structBody =
generateStructBody(funcType.getResults(), "res");
auto structBody = generateStructFields(funcType.getResults(), "res");

if (failed(structBody)) {
return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ Value arrayElementAddress(OpBuilder builder, Location location, Type type,
.getResult(0);
}

void structDefinition(OpBuilder builder, Location location,
StringRef structName, ArrayRef<StructField> fields) {
std::string structBody;

for (auto &field : fields) {
structBody += field.type + " " + field.name + ";";
}

auto ctx = builder.getContext();

builder.create<emitc::CallOp>(
/*location=*/location, /*type=*/TypeRange{},
/*callee=*/StringAttr::get(ctx, "EMITC_TYPEDEF_STRUCT"), /*args=*/
ArrayAttr::get(ctx, {emitc::OpaqueAttr::get(ctx, structName),
emitc::OpaqueAttr::get(ctx, structBody)}),
/*templateArgs=*/ArrayAttr{}, /*operands=*/ArrayRef<Value>{});
}

Value structMember(OpBuilder builder, Location location, Type type,
StringRef memberName, Value operand) {
auto ctx = builder.getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_EMITCBUILDERS_H_
#define IREE_COMPILER_DIALECT_VM_CONVERSION_VMTOEMITC_EMITCBUILDERS_H_

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand All @@ -18,9 +19,17 @@ namespace mlir {
namespace iree_compiler {
namespace emitc_builders {

struct StructField {
std::string type;
std::string name;
};

Value arrayElementAddress(OpBuilder builder, Location location, Type type,
IntegerAttr index, Value operand);

void structDefinition(OpBuilder builder, Location location,
StringRef structName, ArrayRef<StructField> fields);

Value structMember(OpBuilder builder, Location location, Type type,
StringRef memberName, Value operand);

Expand Down

0 comments on commit 59c6678

Please sign in to comment.