Skip to content
This repository was archived by the owner on Dec 26, 2022. It is now read-only.

Commit

Permalink
Adding vm::ref<T> support for non-iree_vm_ref_t ref types. (iree-org#…
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik authored Nov 9, 2022
1 parent da03073 commit c149d61
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 134 deletions.
28 changes: 15 additions & 13 deletions compiler/src/iree/compiler/ConstEval/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/hal/drivers/local_task/registration/driver_module.h"
#include "iree/modules/hal/module.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

Expand Down Expand Up @@ -89,17 +88,17 @@ CompiledBinary::CompiledBinary() {}
CompiledBinary::~CompiledBinary() {}

void CompiledBinary::deinitialize() {
iree_vm_module_release(hal_module);
iree_vm_module_release(main_module);
iree_vm_context_release(context);
iree_hal_device_release(device);
hal_module.reset();
main_module.reset();
context.reset();
device.reset();
}

LogicalResult CompiledBinary::invokeNullary(Location loc, StringRef name,
ResultsCallback callback) {
iree_vm_function_t function;
if (auto status = iree_vm_module_lookup_function_by_name(
main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{name.data(), name.size()}, &function)) {
iree_status_ignore(status);
return emitError(loc) << "internal error evaling constant: func '" << name
Expand All @@ -114,7 +113,7 @@ LogicalResult CompiledBinary::invokeNullary(Location loc, StringRef name,
iree_allocator_system(), &outputs));

if (auto status =
iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inputs.get(), outputs.get(),
iree_allocator_system())) {
std::string message;
Expand Down Expand Up @@ -272,19 +271,22 @@ void CompiledBinary::initialize(void* data, size_t length) {
iree_hal_driver_release(driver);

// Create hal module.
IREE_CHECK_OK(iree_hal_module_create(runtime.instance, device,
IREE_CHECK_OK(iree_hal_module_create(runtime.instance.get(), device.get(),
IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module));

// Bytecode module.
IREE_CHECK_OK(iree_vm_bytecode_module_create(
runtime.instance, iree_make_const_byte_span(data, length),
runtime.instance.get(), iree_make_const_byte_span(data, length),
iree_allocator_null(), iree_allocator_system(), &main_module));

// Context.
std::array<iree_vm_module_t*, 2> modules = {hal_module, main_module};
std::array<iree_vm_module_t*, 2> modules = {
hal_module.get(),
main_module.get(),
};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
runtime.instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
runtime.instance.get(), IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
modules.data(), iree_allocator_system(), &context));
}

Expand All @@ -308,11 +310,11 @@ Runtime::Runtime() {
iree_hal_driver_registry_allocate(iree_allocator_system(), &registry));
IREE_CHECK_OK(iree_hal_local_task_driver_module_register(registry));
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
IREE_CHECK_OK(iree_hal_module_register_all_types(instance));
IREE_CHECK_OK(iree_hal_module_register_all_types(instance.get()));
}

Runtime::~Runtime() {
iree_vm_instance_release(instance);
instance.reset();
iree_hal_driver_registry_free(registry);
}

Expand Down
11 changes: 6 additions & 5 deletions compiler/src/iree/compiler/ConstEval/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -47,10 +48,10 @@ class CompiledBinary {
void deinitialize();
Attribute convertVariantToAttribute(Location loc, iree_vm_variant_t& variant);

iree_hal_device_t* device = nullptr;
iree_vm_module_t* hal_module = nullptr;
iree_vm_module_t* main_module = nullptr;
iree_vm_context_t* context = nullptr;
iree::vm::ref<iree_hal_device_t> device;
iree::vm::ref<iree_vm_module_t> hal_module;
iree::vm::ref<iree_vm_module_t> main_module;
iree::vm::ref<iree_vm_context_t> context;
};

// An in-memory compiled binary and accessors for working with it.
Expand All @@ -70,7 +71,7 @@ class Runtime {
static Runtime& getInstance();

iree_hal_driver_registry_t* registry = nullptr;
iree_vm_instance_t* instance = nullptr;
iree::vm::ref<iree_vm_instance_t> instance;

private:
Runtime();
Expand Down
14 changes: 7 additions & 7 deletions runtime/src/iree/modules/check/check_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ class CheckTest : public ::testing::Test {
/*outputs=*/nullptr, iree_allocator_system());
}

iree_status_t Invoke(const char* function_name,
std::vector<iree_vm_value_t> args) {
iree_status_t InvokeValue(const char* function_name,
std::vector<iree_vm_value_t> args) {
IREE_RETURN_IF_ERROR(
iree_vm_list_create(/*element_type=*/nullptr, args.size(),
iree_allocator_system(), &inputs_));
Expand Down Expand Up @@ -216,28 +216,28 @@ iree_vm_module_t* CheckTest::check_module_ = nullptr;
iree_vm_module_t* CheckTest::hal_module_ = nullptr;

TEST_F(CheckTest, ExpectTrueSuccess) {
IREE_ASSERT_OK(Invoke("expect_true", {iree_vm_value_make_i32(1)}));
IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(1)}));
}

TEST_F(CheckTest, ExpectTrueFailure) {
EXPECT_NONFATAL_FAILURE(
IREE_ASSERT_OK(Invoke("expect_true", {iree_vm_value_make_i32(0)})),
IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(0)})),
"Expected 0 to be nonzero");
}

TEST_F(CheckTest, ExpectFalseSuccess) {
IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(0)}));
IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(0)}));
}

TEST_F(CheckTest, ExpectFalseFailure) {
EXPECT_NONFATAL_FAILURE(
IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(1)})),
IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(1)})),
"Expected 1 to be zero");
}

TEST_F(CheckTest, ExpectFalseNotOneFailure) {
EXPECT_NONFATAL_FAILURE(
IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(42)})),
IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(42)})),
"Expected 42 to be zero");
}

Expand Down
3 changes: 3 additions & 0 deletions runtime/src/iree/vm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "iree/base/api.h"
#include "iree/vm/instance.h"
#include "iree/vm/module.h"
#include "iree/vm/ref.h"
#include "iree/vm/stack.h"

#ifdef __cplusplus
Expand Down Expand Up @@ -123,4 +124,6 @@ IREE_API_EXPORT iree_status_t iree_vm_context_notify(iree_vm_context_t* context,
} // extern "C"
#endif // __cplusplus

IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_context, iree_vm_context_t);

#endif // IREE_VM_CONTEXT_H_
3 changes: 3 additions & 0 deletions runtime/src/iree/vm/instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define IREE_VM_INSTANCE_H_

#include "iree/base/api.h"
#include "iree/vm/ref.h"

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -47,4 +48,6 @@ iree_vm_instance_allocator(iree_vm_instance_t* instance);
} // extern "C"
#endif // __cplusplus

IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_instance, iree_vm_instance_t);

#endif // IREE_VM_INSTANCE_H_
3 changes: 3 additions & 0 deletions runtime/src/iree/vm/invocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "iree/vm/context.h"
#include "iree/vm/list.h"
#include "iree/vm/module.h"
#include "iree/vm/ref.h"

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -344,4 +345,6 @@ IREE_API_EXPORT void iree_vm_invocation_cancel(
} // extern "C"
#endif // __cplusplus

IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_invocation, iree_vm_invocation_t);

#endif // IREE_VM_INVOCATION_H_
3 changes: 3 additions & 0 deletions runtime/src/iree/vm/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "iree/base/api.h"
#include "iree/base/internal/atomics.h"
#include "iree/base/string_builder.h"
#include "iree/vm/ref.h"

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -552,4 +553,6 @@ iree_vm_function_get_attr(iree_vm_function_t function, iree_host_size_t index,
} // extern "C"
#endif // __cplusplus

IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_module, iree_vm_module_t);

#endif // IREE_VM_MODULE_H_
35 changes: 4 additions & 31 deletions runtime/src/iree/vm/ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,37 +236,6 @@ IREE_API_EXPORT bool iree_vm_ref_equal(iree_vm_ref_t* lhs, iree_vm_ref_t* rhs);
// Type adapter utilities for interfacing with the VM
//===----------------------------------------------------------------------===//

#ifdef __cplusplus
namespace iree {
namespace vm {
template <typename T>
struct ref_type_descriptor {
static const iree_vm_ref_type_descriptor_t* get();
};
} // namespace vm
} // namespace iree
#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T) \
namespace iree { \
namespace vm { \
template <> \
struct ref_type_descriptor<T> { \
static const iree_vm_ref_type_descriptor_t* get() { \
return name##_get_descriptor(); \
} \
}; \
} \
}

#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor) \
descriptor.type_name = iree_make_cstring_view(name); \
descriptor.offsetof_counter = type::offsetof_counter(); \
descriptor.destroy = type::DirectDestroy; \
IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
#else
#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)
#endif // __cplusplus

// TODO(benvanik): make these macros standard/document them.
#define IREE_VM_DECLARE_TYPE_ADAPTERS(name, T) \
IREE_API_EXPORT iree_vm_ref_t name##_retain_ref(T* value); \
Expand Down Expand Up @@ -330,6 +299,10 @@ struct ref_type_descriptor {
// Optional C++ iree::vm::ref<T> wrapper.
#ifdef __cplusplus
#include "iree/vm/ref_cc.h"
#else
#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)
#define IREE_VM_DECLARE_CC_TYPE_ADAPTERS(name, T)
#endif // __cplusplus

#endif // IREE_VM_REF_H_
Loading

0 comments on commit c149d61

Please sign in to comment.