Skip to content

Commit

Permalink
[gccjit] finish type conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Oct 28, 2024
1 parent 559ac1d commit ecce253
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 13 deletions.
5 changes: 3 additions & 2 deletions include/mlir-gccjit/Conversion/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
#include "libgccjit.h"
#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::gccjit {
class GCCJITTypeConverter : public TypeConverter {
gcc_jit_context *tmpContext;
llvm::DenseMap<mlir::Type, gccjit::StructType> packedTypes;

public:
GCCJITTypeConverter();
Expand Down Expand Up @@ -55,7 +56,7 @@ class GCCJITTypeConverter : public TypeConverter {
getUnrankedMemrefDescriptorType(mlir::UnrankedMemRefType type);

private:
Type convertAndPackTypesIfNonSingleton(TypeRange types);
Type convertAndPackTypesIfNonSingleton(TypeRange types, FunctionType name);
};
} // namespace mlir::gccjit
#endif // MLIR_GCCJIT_CONVERSION_TYPECONVERTER_H
85 changes: 74 additions & 11 deletions src/Conversion/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
#include "libgccjit.h"
#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;
using namespace mlir::gccjit;

GCCJITTypeConverter::GCCJITTypeConverter() : TypeConverter() {
tmpContext = gcc_jit_context_acquire();
GCCJITTypeConverter::GCCJITTypeConverter() : TypeConverter(), packedTypes() {
addConversion([&](mlir::IndexType type) { return convertIndexType(type); });
addConversion(
[&](mlir::IntegerType type) { return convertIntegerType(type); });
Expand All @@ -36,10 +37,8 @@ GCCJITTypeConverter::GCCJITTypeConverter() : TypeConverter() {
[&](mlir::MemRefType type) { return getMemrefDescriptorType(type); });
}

GCCJITTypeConverter::~GCCJITTypeConverter() {
if (tmpContext)
gcc_jit_context_release(tmpContext);
}
// Nothing to do for now
GCCJITTypeConverter::~GCCJITTypeConverter() {}

gccjit::IntType GCCJITTypeConverter::convertIndexType(mlir::IndexType type) {
return IntType::get(type.getContext(), GCC_JIT_TYPE_SIZE_T);
Expand Down Expand Up @@ -131,7 +130,7 @@ GCCJITTypeConverter::convertFunctionType(mlir::FunctionType type,
argTypes.reserve(type.getNumInputs());
if (convertTypes(type.getInputs(), argTypes).failed())
return {};
auto resultType = convertAndPackTypesIfNonSingleton(type.getResults());
auto resultType = convertAndPackTypesIfNonSingleton(type.getResults(), type);
return FuncType::get(type.getContext(), argTypes, resultType, isVarArg);
}

Expand All @@ -144,16 +143,80 @@ GCCJITTypeConverter::convertFunctionTypeAsPtr(mlir::FunctionType type,

gccjit::StructType
GCCJITTypeConverter::getMemrefDescriptorType(mlir::MemRefType type) {
llvm_unreachable("NYI");
auto &cached = packedTypes[type];
if (!cached) {
auto name = Twine("__memref_")
.concat(Twine(
reinterpret_cast<uintptr_t>(type.getAsOpaquePointer())))
.str();
auto nameAttr = StringAttr::get(type.getContext(), name);
auto elementType = convertType(type.getElementType());
auto elementPtrType = PointerType::get(type.getContext(), elementType);
auto indexType = IntType::get(type.getContext(), GCC_JIT_TYPE_SIZE_T);
auto rank = type.getRank();
auto dimOrStrideType =
gccjit::ArrayType::get(type.getContext(), indexType, rank);
SmallVector<Attribute> fields;
for (auto [idx, field] : llvm::enumerate(
ArrayRef<Type>{elementPtrType, elementPtrType, indexType,
dimOrStrideType, dimOrStrideType})) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(type.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
cached = StructType::get(type.getContext(), nameAttr, fieldsAttr);
}
return cached;
}

gccjit::StructType GCCJITTypeConverter::getUnrankedMemrefDescriptorType(
mlir::UnrankedMemRefType type) {
llvm_unreachable("NYI");
auto &cached = packedTypes[type];
if (!cached) {
auto name = Twine("__unranked_memref_")
.concat(Twine(
reinterpret_cast<uintptr_t>(type.getAsOpaquePointer())))
.str();
auto nameAttr = StringAttr::get(type.getContext(), name);
auto indexType = IntType::get(type.getContext(), GCC_JIT_TYPE_SIZE_T);
auto opaquePtrType = PointerType::get(
type.getContext(), IntType::get(type.getContext(), GCC_JIT_TYPE_VOID));
SmallVector<Attribute> fields;
for (auto [idx, field] :
llvm::enumerate(ArrayRef<Type>{indexType, opaquePtrType})) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(type.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
cached = StructType::get(type.getContext(), nameAttr, fieldsAttr);
}
return cached;
}

Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton(TypeRange types) {
Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton(TypeRange types,
FunctionType func) {
if (types.size() == 0)
return VoidType::get(func.getContext());
if (types.size() == 1)
return types.front();
llvm_unreachable("NYI");
gccjit::StructType &cached = packedTypes[func];
if (!cached) {
auto name = Twine("__retpack_")
.concat(Twine(
reinterpret_cast<uintptr_t>(func.getAsOpaquePointer())))
.str();
SmallVector<Attribute> fields;
for (auto [idx, type] : llvm::enumerate(types)) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(func.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type, 0));
}
auto nameAttr = StringAttr::get(func.getContext(), name);
auto fieldsAttr = ArrayAttr::get(func.getContext(), fields);
auto structType = StructType::get(func.getContext(), nameAttr, fieldsAttr);
cached = structType;
}
return cached;
}

0 comments on commit ecce253

Please sign in to comment.