From f022d29ad6d9c9ba793e08e529c0472a6b43af12 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 15 Aug 2023 18:51:46 -0700 Subject: [PATCH] Reworking constant upload with a HAL file API. (#14665) This changes the compiler-generated map/staging IR into file reads that are managed by the runtime HAL targets. In this initial version only one type of file is supported: one that wraps host memory and exposes it via the file API. For targets that don't natively support file streaming, a utility for asynchronously streaming files using an iree_loop_t is available, and all targets currently use it. In the future, targets like CUDA/Metal/Direct3D12 can use cuFile/MTLIOCommandBuffer/DirectStorage to directly stream file contents from/to disk. Future changes will expand the iree_hal_file_t vtable with what can be reliably implemented or guarded with feature flags. When any form of asynchronous loop is implemented (such as a task-system fiber loop), we'll be able to perform overlapped copies to reduce total transfer latency by letting the CPU and GPU do their respective staging operations concurrently. Today only synchronous loops are used, so transfers are fully serialized. Now that the compiler doesn't emit staging with iree_allocator_allocate_buffer's initial_data, it can be removed in #14605 to remove the primary use of the existing buffer_transfer util. The remaining places using that util, such as numpy IO, can be switched to memory file streaming instead and support much larger content. This is a breaking HAL change and all existing VMFBs will fail to load on newer runtimes. CPU and CUDA targets benefit from this, as should all other backends, though some may need to have zero-copy paths added - they'll at least benefit from streaming and bounded staging buffer sizes. Vulkan has been optimized a bit and the follow-on sparse bindings support improves it further.
Vulkan dedicated GPU w/ 1GB model before this change (1024MB + 1024MB of staging, even when mmapped): ![image](https://github.com/openxla/iree/assets/75337/62b4b144-efe6-4116-8192-3f7bbfbbfaeb) Vulkan dedicated GPU w/ 1GB model now (1024MB + ~64MB of staging used as streamed from a memory mapped file): ![image](https://github.com/openxla/iree/assets/75337/4e44a186-9c11-4d83-951b-b6bc1ad6e3f2) Vulkan integrated GPU w/ 1GB model now (1024MB of host memory imported and used directly with no copies): ![image](https://github.com/openxla/iree/assets/75337/0d036ede-0aac-45be-b069-6bb8d6e06cb6) Progress on #14607. --- .../HALToVM/ConvertAllocatorOps.cpp | 49 +- .../Conversion/HALToVM/ConvertDeviceOps.cpp | 4 + .../HALToVM/ConvertExperimentalOps.cpp | 2 + .../HALToVM/test/allocator_ops.mlir | 24 +- .../Conversion/HALToVM/test/device_ops.mlir | 35 + .../HAL/Conversion/StreamToHAL/BUILD.bazel | 1 + .../HAL/Conversion/StreamToHAL/CMakeLists.txt | 1 + .../HAL/Conversion/StreamToHAL/Patterns.cpp | 171 +++- .../Conversion/StreamToHAL/test/BUILD.bazel | 1 + .../StreamToHAL/test/CMakeLists.txt | 1 + .../Conversion/StreamToHAL/test/file_ops.mlir | 41 + .../StreamToHAL/test/resource_ops.mlir | 19 +- .../iree/compiler/Dialect/HAL/IR/HALBase.td | 52 +- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 42 +- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 186 +++- .../iree/compiler/Dialect/HAL/IR/HALTypes.cpp | 14 +- .../iree/compiler/Dialect/HAL/IR/HALTypes.h | 9 +- .../Dialect/HAL/IR/test/allocator_ops.mlir | 21 +- .../Dialect/HAL/IR/test/device_ops.mlir | 76 +- .../Dialect/HAL/IR/test/experimental_ops.mlir | 29 + .../Dialect/HAL/Transforms/ConvertToHAL.cpp | 4 +- .../Transforms/DumpExecutableBenchmarks.cpp | 5 +- .../test/dump_executable_benchmarks.mlir | 2 +- .../compiler/Dialect/HAL/hal.imports.mlir | 62 +- .../compiler/Dialect/Stream/IR/StreamBase.td | 22 + .../Dialect/Stream/IR/StreamOpFolders.cpp | 28 +- .../compiler/Dialect/Stream/IR/StreamOps.cpp | 64 +- 
.../compiler/Dialect/Stream/IR/StreamOps.td | 222 ++++- .../Dialect/Stream/IR/test/BUILD.bazel | 1 + .../Dialect/Stream/IR/test/CMakeLists.txt | 1 + .../Dialect/Stream/IR/test/file_ops.mlir | 37 + .../Dialect/Stream/IR/test/resource_ops.mlir | 11 - .../Stream/Transforms/PackConstants.cpp | 322 ++---- .../Transforms/test/dump_statistics.mlir | 22 +- .../Transforms/test/pack_constants.mlir | 125 ++- .../Dialect/Util/IR/UtilInterfaces.td | 7 +- .../Conversion/StreamToHALInline/Patterns.cpp | 23 +- .../StreamToHALInline/test/resource_ops.mlir | 16 - experimental/cuda2/CMakeLists.txt | 2 + experimental/cuda2/cuda_allocator.c | 39 +- experimental/cuda2/cuda_buffer.c | 8 +- experimental/cuda2/cuda_device.c | 67 +- experimental/cuda2/event_semaphore.c | 13 +- experimental/rocm/CMakeLists.txt | 2 + experimental/rocm/rocm_device.c | 64 ++ experimental/webgpu/BUILD.bazel | 2 + experimental/webgpu/CMakeLists.txt | 2 + experimental/webgpu/webgpu_device.c | 66 ++ runtime/bindings/python/tests/hal_test.py | 12 +- runtime/src/iree/base/loop.h | 2 +- runtime/src/iree/hal/BUILD.bazel | 2 + runtime/src/iree/hal/CMakeLists.txt | 2 + runtime/src/iree/hal/allocator_heap.c | 6 +- runtime/src/iree/hal/api.h | 1 + runtime/src/iree/hal/buffer.c | 10 +- runtime/src/iree/hal/cts/CMakeLists.txt | 13 + runtime/src/iree/hal/cts/file_test.h | 139 +++ runtime/src/iree/hal/device.c | 97 ++ runtime/src/iree/hal/device.h | 65 ++ runtime/src/iree/hal/drivers/cuda/BUILD.bazel | 2 + .../src/iree/hal/drivers/cuda/CMakeLists.txt | 2 + .../iree/hal/drivers/cuda/cuda_allocator.c | 37 +- .../src/iree/hal/drivers/cuda/cuda_buffer.c | 8 +- .../src/iree/hal/drivers/cuda/cuda_device.c | 67 +- .../iree/hal/drivers/local_sync/BUILD.bazel | 2 + .../hal/drivers/local_sync/CMakeLists.txt | 2 + .../hal/drivers/local_sync/cts/CMakeLists.txt | 4 +- .../iree/hal/drivers/local_sync/sync_device.c | 64 ++ .../hal/drivers/local_sync/sync_semaphore.c | 13 +- .../iree/hal/drivers/local_task/BUILD.bazel | 2 + 
.../hal/drivers/local_task/CMakeLists.txt | 2 + .../iree/hal/drivers/local_task/task_device.c | 64 ++ .../hal/drivers/local_task/task_semaphore.c | 13 +- .../src/iree/hal/drivers/metal/CMakeLists.txt | 2 + .../src/iree/hal/drivers/metal/metal_device.m | 58 ++ .../src/iree/hal/drivers/metal/shared_event.m | 10 +- .../src/iree/hal/drivers/vulkan/BUILD.bazel | 2 + .../iree/hal/drivers/vulkan/CMakeLists.txt | 2 + .../hal/drivers/vulkan/cts/CMakeLists.txt | 2 + .../drivers/vulkan/direct_command_buffer.cc | 18 + .../hal/drivers/vulkan/native_allocator.cc | 2 +- .../iree/hal/drivers/vulkan/native_buffer.cc | 8 +- .../hal/drivers/vulkan/native_semaphore.cc | 24 +- runtime/src/iree/hal/drivers/vulkan/tracing.h | 12 +- .../iree/hal/drivers/vulkan/vma_allocator.cc | 8 +- .../iree/hal/drivers/vulkan/vulkan_device.cc | 69 +- runtime/src/iree/hal/file.c | 36 + runtime/src/iree/hal/file.h | 166 ++++ runtime/src/iree/hal/semaphore.c | 1 - runtime/src/iree/hal/semaphore.h | 24 + runtime/src/iree/hal/utils/BUILD.bazel | 22 + runtime/src/iree/hal/utils/CMakeLists.txt | 28 + runtime/src/iree/hal/utils/buffer_transfer.c | 4 +- runtime/src/iree/hal/utils/file_transfer.c | 935 ++++++++++++++++++ runtime/src/iree/hal/utils/file_transfer.h | 93 ++ runtime/src/iree/hal/utils/memory_file.c | 354 +++++++ runtime/src/iree/hal/utils/memory_file.h | 73 ++ runtime/src/iree/modules/hal/exports.inl | 8 +- runtime/src/iree/modules/hal/module.c | 225 +++-- runtime/src/iree/modules/hal/types.c | 6 + runtime/src/iree/modules/hal/types.h | 1 + runtime/src/iree/vm/shims.c | 6 +- runtime/src/iree/vm/shims.h | 45 +- tools/iree-e2e-matmul-test.c | 4 +- 104 files changed, 4057 insertions(+), 839 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir create mode 100644 compiler/src/iree/compiler/Dialect/Stream/IR/test/file_ops.mlir create mode 100644 runtime/src/iree/hal/cts/file_test.h create mode 100644 runtime/src/iree/hal/file.c create mode 100644 
runtime/src/iree/hal/file.h create mode 100644 runtime/src/iree/hal/utils/file_transfer.c create mode 100644 runtime/src/iree/hal/utils/file_transfer.h create mode 100644 runtime/src/iree/hal/utils/memory_file.c create mode 100644 runtime/src/iree/hal/utils/memory_file.h diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp index 2b2ba05150f9..1c23b2d4934a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp @@ -13,21 +13,20 @@ namespace mlir { namespace iree_compiler { namespace { -class AllocatorAllocateInitializedOpConversion - : public OpConversionPattern { +class AllocatorAllocateOpConversion + : public OpConversionPattern { public: - AllocatorAllocateInitializedOpConversion(TypeConverter &typeConverter, - MLIRContext *context, - SymbolTable &importSymbols) + AllocatorAllocateOpConversion(TypeConverter &typeConverter, + MLIRContext *context, + SymbolTable &importSymbols) : OpConversionPattern(typeConverter, context) { - importOp = importSymbols.lookup( - "hal.allocator.allocate.initialized"); + importOp = + importSymbols.lookup("hal.allocator.allocate"); assert(importOp); } LogicalResult - matchAndRewrite(IREE::HAL::AllocatorAllocateInitializedOp op, - OpAdaptor adaptor, + matchAndRewrite(IREE::HAL::AllocatorAllocateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto callOp = rewriter.replaceOpWithNewOp( op, importOp.getName(), @@ -36,14 +35,13 @@ class AllocatorAllocateInitializedOpConversion }, ArrayRef{ adaptor.getAllocator(), + castToImportType(adaptor.getQueueAffinity(), rewriter.getI64Type(), + rewriter), rewriter.createOrFold( op.getLoc(), op.getMemoryTypesAttr().getInt()), rewriter.createOrFold( op.getLoc(), op.getBufferUsageAttr().getInt()), - adaptor.getSource(), - 
castToImportType(adaptor.getOffset(), rewriter.getI64Type(), - rewriter), - castToImportType(adaptor.getLength(), rewriter.getI64Type(), + castToImportType(adaptor.getResultSize(), rewriter.getI64Type(), rewriter), }); copyImportAttrs(importOp, callOp); @@ -54,19 +52,18 @@ class AllocatorAllocateInitializedOpConversion mutable IREE::VM::ImportOp importOp; }; -class AllocatorTryMapOpConversion - : public OpConversionPattern { +class AllocatorImportOpConversion + : public OpConversionPattern { public: - AllocatorTryMapOpConversion(TypeConverter &typeConverter, + AllocatorImportOpConversion(TypeConverter &typeConverter, MLIRContext *context, SymbolTable &importSymbols) : OpConversionPattern(typeConverter, context) { - importOp = importSymbols.lookup( - "hal.allocator.map.byte_buffer"); + importOp = importSymbols.lookup("hal.allocator.import"); assert(importOp); } LogicalResult - matchAndRewrite(IREE::HAL::AllocatorTryMapOp op, OpAdaptor adaptor, + matchAndRewrite(IREE::HAL::AllocatorImportOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto callOp = rewriter.create( op.getLoc(), importOp.getName(), @@ -76,6 +73,8 @@ class AllocatorTryMapOpConversion ArrayRef{ adaptor.getAllocator(), rewriter.createOrFold(op.getLoc(), /*try=*/1), + castToImportType(adaptor.getQueueAffinity(), rewriter.getI64Type(), + rewriter), rewriter.createOrFold( op.getLoc(), op.getMemoryTypesAttr().getInt()), rewriter.createOrFold( @@ -88,9 +87,9 @@ class AllocatorTryMapOpConversion }); copyImportAttrs(importOp, callOp); auto result = callOp.getResults().front(); - auto didMap = rewriter.create( + auto didImport = rewriter.create( op.getLoc(), rewriter.getI32Type(), result); - rewriter.replaceOp(op, {didMap, result}); + rewriter.replaceOp(op, {didImport, result}); return success(); } @@ -104,11 +103,9 @@ void populateHALAllocatorToVMPatterns(MLIRContext *context, SymbolTable &importSymbols, TypeConverter &typeConverter, RewritePatternSet &patterns) { - 
patterns.insert>( - context, importSymbols, typeConverter, "hal.allocator.allocate"); - patterns.insert( - typeConverter, context, importSymbols); - patterns.insert(typeConverter, context, + patterns.insert(typeConverter, context, + importSymbols); + patterns.insert(typeConverter, context, importSymbols); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp index 321143c7692f..86d44458bea6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp @@ -129,6 +129,10 @@ void populateHALDeviceToVMPatterns(MLIRContext *context, context, importSymbols, typeConverter, "hal.device.queue.alloca"); patterns.insert>( context, importSymbols, typeConverter, "hal.device.queue.dealloca"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.read"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.write"); patterns.insert>( context, importSymbols, typeConverter, "hal.device.queue.execute"); patterns.insert>( diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp index 2490b676dcf1..3dad2460a926 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExperimentalOps.cpp @@ -17,6 +17,8 @@ void populateHALExperimentalToVMPatterns(MLIRContext *context, RewritePatternSet &patterns) { patterns.insert>( context, importSymbols, typeConverter, "hal.ex.shared_device"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.ex.file.from_memory"); } } // namespace iree_compiler diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir index 0558a83e68c2..3d47ef590494 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir @@ -2,22 +2,28 @@ // CHECK-LABEL: vm.func private @allocatorAllocate func.func @allocatorAllocate(%arg0 : !hal.allocator) -> !hal.buffer { - // CHECK: %[[SIZE:.+]] = vm.const.i64 1024 - %c1024 = arith.constant 1024 : index - // CHECK: %ref = vm.call @hal.allocator.allocate(%arg0, %c70, %c3075, %[[SIZE]]) : (!vm.ref, i32, i32, i64) -> !vm.ref - %0 = hal.allocator.allocate<%arg0 : !hal.allocator> type("HostLocal") usage("DispatchStorage|Transfer") : !hal.buffer{%c1024} + // CHECK-DAG: %[[SIZE:.+]] = vm.const.i64 1024 + %size = arith.constant 1024 : index + // CHECK-DAG: %[[AFFINITY:.+]] = vm.const.i64 -1 + %affinity = arith.constant -1 : i64 + // CHECK: %ref = vm.call @hal.allocator.allocate(%arg0, %[[AFFINITY]], %c70, %c3075, %[[SIZE]]) : (!vm.ref, i64, i32, i32, i64) -> !vm.ref + %0 = hal.allocator.allocate<%arg0 : !hal.allocator> affinity(%affinity) type("HostLocal") usage("DispatchStorage|Transfer") : !hal.buffer{%size} return %0 : !hal.buffer } // ----- -// CHECK-LABEL: vm.func private @allocatorMapByteBuffer -func.func @allocatorMapByteBuffer(%arg0 : !hal.allocator, %arg1 : !util.buffer) -> !hal.buffer { +// CHECK-LABEL: vm.func private @allocatorImport +func.func @allocatorImport(%arg0 : !hal.allocator, %arg1 : !util.buffer) -> (i1, !hal.buffer) { // CHECK-DAG: %[[OFFSET:.+]] = vm.const.i64 128 %offset = arith.constant 128 : index // CHECK-DAG: %[[LENGTH:.+]] = vm.const.i64 256 %length = arith.constant 256 : index - // CHECK: = vm.call @hal.allocator.allocate.initialized(%arg0, %c6, %c3, %arg1, %[[OFFSET]], %[[LENGTH]]) : (!vm.ref, i32, i32, !vm.buffer, i64, i64) -> !vm.ref - %buffer 
= hal.allocator.allocate.initialized<%arg0 : !hal.allocator> source(%arg1 : !util.buffer)[%offset, %length] type("HostVisible|HostCoherent") usage("Transfer") : !hal.buffer - return %buffer : !hal.buffer + // CHECK-DAG: %[[AFFINITY:.+]] = vm.const.i64 -1 + %affinity = arith.constant -1 : i64 + // CHECK: %[[IMPORTED:.+]] = vm.call @hal.allocator.import(%arg0, %c1, %[[AFFINITY]], %c6, %c3, %arg1, %[[OFFSET]], %[[LENGTH]]) : (!vm.ref, i32, i64, i32, i32, !vm.buffer, i64, i64) -> !vm.ref + %did_import, %buffer = hal.allocator.import<%arg0 : !hal.allocator> source(%arg1 : !util.buffer)[%offset, %length] affinity(%affinity) type("HostVisible|HostCoherent") usage("Transfer") : i1, !hal.buffer + // CHECK: %[[DID_IMPORT:.+]] = vm.cmp.nz.ref %[[IMPORTED]] + // CHECK: return %[[DID_IMPORT]], %[[IMPORTED]] + return %did_import, %buffer : i1, !hal.buffer } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir index 6ffca9b6ec43..ce88e3fa22ff 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir @@ -143,6 +143,41 @@ func.func @device_queue_dealloca( // ----- +// CHECK-LABEL: @device_queue_read +func.func @device_queue_read( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SOURCE_FILE:.+]]: !vm.ref, + %source_file: !hal.file, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !vm.ref) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[SOURCE_OFFSET:.+]] = vm.const.i64 100 + %source_offset = arith.constant 100 : i64 + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = vm.const.i64 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = 
vm.const.i64 300 + %length = arith.constant 300 : index + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i32.zero + // CHECK: vm.call @hal.device.queue.read( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: %[[SOURCE_FILE]], %[[SOURCE_OFFSET]], + // CHECK-SAME: %[[TARGET_BUFFER]], %[[TARGET_OFFSET]], + // CHECK-SAME: %[[LENGTH]], %[[FLAGS]]) + hal.device.queue.read<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + source(%source_file : !hal.file)[%source_offset] + target(%target_buffer : !hal.buffer)[%target_offset] + length(%length) + flags(0) + return +} + +// ----- + // CHECK-LABEL: @device_queue_execute func.func @device_queue_execute( // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel index 7554932e702a..c9803ce4809d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/BUILD.bazel @@ -33,6 +33,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Transforms", ], ) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt index 5c339fa311b7..1fe1cbe23fa9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/CMakeLists.txt @@ -23,6 +23,7 @@ iree_cc_library( MLIRFuncDialect MLIRIR MLIRPass + MLIRSCFDialect MLIRTransforms iree::compiler::Dialect::HAL::Conversion iree::compiler::Dialect::HAL::IR diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 9e087cb13eeb..95f2cfea7d09 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -16,6 +16,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Dominance.h" #include "mlir/Transforms/DialectConversion.h" @@ -71,6 +72,26 @@ static Value lookupAllocatorFor(Operation *op, OpBuilder &builder) { return allocatorOp.getResult(); } +static std::tuple +lookupAllocatorAndQueueAffinityFor(Operation *op, OpBuilder &builder) { + // NOTE: we have this combined method so that we can reuse any expensive + // lookups we need to do. Today we aren't duplicating the lookups and don't + // bother. + + // Get a device handle used to create resources and schedule work. + // It may be shared across many mutually-exclusive devices at runtime. + Value device = lookupDeviceFor(op, builder); + + // Each device has a single allocator that may itself present multiple. + Value allocator = + builder.create(op->getLoc(), device); + + // Derive the queue affinity mask from the op and device combination. + Value queueAffinity = buildQueueAffinityMaskFor(op, device, builder); + + return std::make_tuple(allocator, queueAffinity); +} + // Returns the |timepointFence| or a util.null. 
static Value getOrCreateWaitFence(Location loc, Value timepointFence, OpBuilder &builder) { @@ -307,7 +328,8 @@ struct ResourceAllocOpPattern LogicalResult matchAndRewrite(IREE::Stream::ResourceAllocOp allocOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto allocator = lookupAllocatorFor(allocOp, rewriter); + auto [allocator, queueAffinity] = + lookupAllocatorAndQueueAffinityFor(allocOp, rewriter); auto bufferType = rewriter.getType(); SmallVector results; @@ -324,8 +346,8 @@ struct ResourceAllocOpPattern } auto allocateOp = rewriter.create( - allocOp.getLoc(), bufferType, allocator, memoryTypes, bufferUsage, - storageSize); + allocOp.getLoc(), bufferType, allocator, queueAffinity, memoryTypes, + bufferUsage, storageSize); results.push_back(allocateOp.getResult()); } @@ -412,37 +434,14 @@ struct ResourceSizeOpPattern } }; -struct ResourceMapOpPattern - : public StreamConversionPattern { - using StreamConversionPattern::StreamConversionPattern; - LogicalResult - matchAndRewrite(IREE::Stream::ResourceMapOp mapOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto allocator = lookupAllocatorFor(mapOp, rewriter); - auto bufferType = rewriter.getType(); - - // We know this is a staging buffer. We could refine usage here by seeing - // whether this was upload or download. 
- auto memoryTypes = IREE::HAL::MemoryTypeBitfield::HostLocal | - IREE::HAL::MemoryTypeBitfield::DeviceVisible; - auto bufferUsage = IREE::HAL::BufferUsageBitfield::Mapping | - IREE::HAL::BufferUsageBitfield::Transfer; - - rewriter.replaceOpWithNewOp( - mapOp, bufferType, allocator, memoryTypes, bufferUsage, - adaptor.getSource(), adaptor.getSourceOffset(), - adaptor.getResultSize()); - return success(); - } -}; - struct ResourceTryMapOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; LogicalResult matchAndRewrite(IREE::Stream::ResourceTryMapOp tryMapOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto allocator = lookupAllocatorFor(tryMapOp, rewriter); + auto [allocator, queueAffinity] = + lookupAllocatorAndQueueAffinityFor(tryMapOp, rewriter); auto resourceType = llvm::cast(tryMapOp.getResult().getType()); auto bufferType = rewriter.getType(); @@ -459,27 +458,33 @@ struct ResourceTryMapOpPattern memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal; bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::SharingImmutable; + // TODO(benvanik): refine usage based on analysis. + bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer | + IREE::HAL::BufferUsageBitfield::DispatchStorage; + break; + case IREE::Stream::Lifetime::Variable: + // Device local; copies required to get into external resources. + memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal; + // TODO(benvanik): refine usage based on analysis. + bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer | + IREE::HAL::BufferUsageBitfield::DispatchStorage; break; case IREE::Stream::Lifetime::Staging: // Host local; copies required to get into device resources. // We could vary this based on staging usage (upload/download) by // making it device-local|host-visible, but host-local means we have // a better chance of mapping it during uploads. 
- memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::HostLocal | + memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::HostVisible | IREE::HAL::MemoryTypeBitfield::DeviceVisible; - bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer | - IREE::HAL::BufferUsageBitfield::Mapping; + bufferUsage = + bufferUsage | IREE::HAL::BufferUsageBitfield::TransferSource; break; } - // TODO(benvanik): refine usage based on analysis. - bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Transfer | - IREE::HAL::BufferUsageBitfield::DispatchStorage; - - rewriter.replaceOpWithNewOp( - tryMapOp, rewriter.getI1Type(), bufferType, allocator, memoryTypes, - bufferUsage, adaptor.getSource(), adaptor.getSourceOffset(), - adaptor.getResultSize()); + rewriter.replaceOpWithNewOp( + tryMapOp, rewriter.getI1Type(), bufferType, allocator, queueAffinity, + memoryTypes, bufferUsage, adaptor.getSource(), + adaptor.getSourceOffset(), adaptor.getResultSize()); return success(); } }; @@ -527,6 +532,80 @@ struct ResourceSubviewOpPattern } }; +struct FileConstantOpPattern + : public StreamConversionPattern { + using StreamConversionPattern::StreamConversionPattern; + LogicalResult + matchAndRewrite(IREE::Stream::FileConstantOp constantOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto [device, queueAffinity] = + lookupDeviceAndQueueAffinityFor(constantOp, rewriter); + rewriter.replaceOpWithNewOp( + constantOp, rewriter.getType(), device, + queueAffinity, IREE::HAL::MemoryAccessBitfield::Read, + constantOp.getSource(), constantOp.getSourceOffset(), + constantOp.getSourceLength(), + rewriter.create(constantOp.getLoc(), 0, 32)); + return success(); + } +}; + +struct FileReadOpPattern + : public StreamConversionPattern { + using StreamConversionPattern::StreamConversionPattern; + LogicalResult + matchAndRewrite(IREE::Stream::FileReadOp readOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = 
readOp.getLoc(); + auto [device, queueAffinity] = + lookupDeviceAndQueueAffinityFor(readOp, rewriter); + + // Gather wait/signal fence, which are optional. + Value waitFence = + getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter); + Value signalFence = getOrCreateSignalFence( + loc, device, readOp.getResultTimepoint(), rewriter); + + // Queue read. + rewriter.create( + loc, device, queueAffinity, waitFence, signalFence, adaptor.getSource(), + adaptor.getSourceOffset(), adaptor.getTarget(), + adaptor.getTargetOffset(), adaptor.getLength(), + /*flags=*/0); + + rewriter.replaceOp(readOp, {signalFence}); + return success(); + } +}; + +struct FileWriteOpPattern + : public StreamConversionPattern { + using StreamConversionPattern::StreamConversionPattern; + LogicalResult + matchAndRewrite(IREE::Stream::FileWriteOp writeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = writeOp.getLoc(); + auto [device, queueAffinity] = + lookupDeviceAndQueueAffinityFor(writeOp, rewriter); + + // Gather wait/signal fence, which are optional. + Value waitFence = + getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter); + Value signalFence = getOrCreateSignalFence( + loc, device, writeOp.getResultTimepoint(), rewriter); + + // Queue write. + rewriter.create( + loc, device, queueAffinity, waitFence, signalFence, adaptor.getSource(), + adaptor.getSourceOffset(), adaptor.getTarget(), + adaptor.getTargetOffset(), adaptor.getLength(), + /*flags=*/0); + + rewriter.replaceOp(writeOp, {signalFence}); + return success(); + } +}; + // Inserts IR to assert that the underlying buffer storage is compatible with // the intended usage in the program. 
The allocator used to allocate the // buffer must have compatibility with our target device allocator and the @@ -1416,6 +1495,12 @@ void populateStreamToHALPatterns(MLIRContext *context, return success(); }); + typeConverter.addConversion( + [=](IREE::Stream::FileType type, SmallVectorImpl &results) { + results.push_back(IREE::HAL::FileType::get(context)); + return success(); + }); + typeConverter.addConversion( [=](IREE::Stream::ResourceType type, SmallVectorImpl &results) { // Resources are just buffers (no shape/encoding/etc). @@ -1435,9 +1520,11 @@ void populateStreamToHALPatterns(MLIRContext *context, auto mapping = std::make_shared(); patterns.insert(mapping, typeConverter, context); + ResourceTryMapOpPattern, ResourceLoadOpPattern, + ResourceStoreOpPattern, ResourceSubviewOpPattern>( + mapping, typeConverter, context); + patterns.insert( + mapping, typeConverter, context); patterns.insert(mapping, typeConverter, context); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel index 2cf2072d6cd5..832da446424f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel @@ -18,6 +18,7 @@ iree_lit_test_suite( [ "channel_ops.mlir", "cmd_ops.mlir", + "file_ops.mlir", "resource_ops.mlir", "timepoint_ops.mlir", "transfer_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt index 19aecb7fdfa5..5c25a89a7d7d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "channel_ops.mlir" "cmd_ops.mlir" + "file_ops.mlir" "resource_ops.mlir" 
"timepoint_ops.mlir" "transfer_ops.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir new file mode 100644 index 000000000000..2ac1385882d5 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/file_ops.mlir @@ -0,0 +1,41 @@ +// RUN: iree-opt --split-input-file --iree-hal-conversion %s | FileCheck %s + +// CHECK-LABEL: @file_constant +// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer) +func.func @file_constant(%buffer: !util.buffer) { + %c0 = arith.constant 0 : index + %c1088 = arith.constant 1088 : index + // CHECK: = hal.ex.file.from_memory device(%device : !hal.device) affinity(%c-1_i64) access(Read) buffer(%[[BUFFER]] : !util.buffer)[%c0 for %c1088] flags(%c0_i32) : !hal.file + %file = stream.file.constant %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file + return +} + +// ----- + +// CHECK-LABEL: @file_read +// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer) +func.func @file_read(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1088 = arith.constant 1088 : index + // CHECK: %[[SIGNAL:.+]] = hal.fence.create + // CHECK: hal.device.queue.read<%device : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[FILE]] : !hal.file)[%c0_i64] target(%[[RESOURCE]] : !hal.buffer)[%c0] length(%c1088) flags(0) + %signal = stream.file.read await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource{%c1088} => !stream.timepoint + // CHECK: return %[[SIGNAL]] + return %signal : !stream.timepoint +} + +// ----- + +// CHECK-LABEL: @file_write +// CHECK-SAME: (%[[WAIT:.+]]: !hal.fence, %[[FILE:.+]]: !hal.file, %[[RESOURCE:.+]]: !hal.buffer) +func.func @file_write(%wait: !stream.timepoint, 
%file: !stream.file, %resource: !stream.resource) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1088 = arith.constant 1088 : index + // CHECK: %[[SIGNAL:.+]] = hal.fence.create + // CHECK: hal.device.queue.write<%device : !hal.device> affinity(%c-1_i64) wait(%[[WAIT]]) signal(%[[SIGNAL]]) source(%[[RESOURCE]] : !hal.buffer)[%c0] target(%[[FILE]] : !hal.file)[%c0_i64] length(%c1088) flags(0) + %signal = stream.file.write await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource{%c1088} -> !stream.file => !stream.timepoint + // CHECK: return %[[SIGNAL]] + return %signal : !stream.timepoint +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir index 5fc5f9d08593..e70f8b04263a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/resource_ops.mlir @@ -101,31 +101,16 @@ func.func @resourceSize(%arg0: !stream.resource) -> index { // ----- -// CHECK-LABEL: @resourceMap -func.func @resourceMap(%arg0: !util.buffer) -> !stream.resource { - %c0 = arith.constant 0 : index - %c128 = arith.constant 128 : index - // CHECK: %[[MAPPING:.+]] = hal.allocator.allocate.initialized - // CHECK-SAME: source(%arg0 : !util.buffer)[%c0, %c128] - // CHECK-SAME: type("HostVisible|HostCoherent|HostLocal|DeviceVisible") - // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Mapping{{.+}}") : !hal.buffer - %mapping = stream.resource.map %arg0[%c0] : !util.buffer -> !stream.resource{%c128} - // CHECK: return %[[MAPPING]] - return %mapping : !stream.resource -} - -// ----- - // CHECK-LABEL: @resourceTryMap func.func @resourceTryMap(%arg0: !util.buffer) -> (i1, !stream.resource) { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index - // CHECK: %[[DID_MAP:.+]], %[[MAPPING:.+]] = 
hal.allocator.try_map + // CHECK: %[[DID_IMPORT:.+]], %[[IMPORTED:.+]] = hal.allocator.import // CHECK-SAME: source(%arg0 : !util.buffer)[%c0, %c128] // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}SharingImmutable") : i1, !hal. %did_map, %mapping = stream.resource.try_map %arg0[%c0] : !util.buffer -> i1, !stream.resource{%c128} - // CHECK: return %[[DID_MAP]], %[[MAPPING]] + // CHECK: return %[[DID_IMPORT]], %[[IMPORTED]] return %did_map, %mapping : i1, !stream.resource } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td index 8be7d61acdf9..e01d81e771ee 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td @@ -62,6 +62,26 @@ def HAL_MemoryTypeBitfieldAttr : let cppNamespace = "mlir::iree_compiler::IREE::HAL"; } +def HAL_MemoryAccess_None : I32BitEnumAttrCase<"None", 0x00000000>; +def HAL_MemoryAccess_Read : I32BitEnumAttrCase<"Read", 0x00000001>; +def HAL_MemoryAccess_Write : I32BitEnumAttrCase<"Write", 0x00000002>; +def HAL_MemoryAccess_Discard : I32BitEnumAttrCase<"Discard", 0x00000004>; +def HAL_MemoryAccess_MayAlias : I32BitEnumAttrCase<"MayAlias", 0x00000008>; +def HAL_MemoryAccess_Unaligned : I32BitEnumAttrCase<"Unaligned", 0x00000010>; +def HAL_MemoryAccess_Any : I32BitEnumAttrCase<"Any", 0x00000020>; +def HAL_MemoryAccessBitfieldAttr : + I32BitEnumAttr<"MemoryAccessBitfield", "valid MemoryAccess", [ + HAL_MemoryAccess_None, + HAL_MemoryAccess_Read, + HAL_MemoryAccess_Write, + HAL_MemoryAccess_Discard, + HAL_MemoryAccess_MayAlias, + HAL_MemoryAccess_Unaligned, + HAL_MemoryAccess_Any, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + def HAL_BufferUsage_None : I32BitEnumAttrCase<"None", 0x00000000>; def HAL_BufferUsage_TransferSource : I32BitEnumAttrCase<"TransferSource", 0x00000001>; def HAL_BufferUsage_TransferTarget : 
I32BitEnumAttrCase<"TransferTarget", 0x00000002>; @@ -422,16 +442,6 @@ def HAL_Executable : DialectType< let builderCall = "$_builder.getType()"; } -def HAL_PipelineLayout : DialectType< - HAL_Dialect, - CPred<"$_self.isa()">, - "pipeline_layout"> { - let description = [{ - An pipeline layout describing the descriptor sets and push constants used. - }]; - let builderCall = "$_builder.getType()"; -} - def HAL_Fence : DialectType< HAL_Dialect, CPred<"$_self.isa()">, @@ -443,6 +453,27 @@ def HAL_Fence : DialectType< let builderCall = "$_builder.getType()"; } +def HAL_File : DialectType< + HAL_Dialect, + CPred<"$_self.isa()">, + "buffer"> { + let description = [{ + A stateless file handle that can be read/written using queue-ordered + transfer operations. + }]; + let builderCall = "$_builder.getType()"; +} + +def HAL_PipelineLayout : DialectType< + HAL_Dialect, + CPred<"$_self.isa()">, + "pipeline_layout"> { + let description = [{ + A pipeline layout describing the descriptor sets and push constants used. 
+ }]; + let builderCall = "$_builder.getType()"; +} + def HAL_ObjectType : AnyTypeOf<[ HAL_Allocator, HAL_Buffer, @@ -453,6 +484,7 @@ def HAL_ObjectType : AnyTypeOf<[ HAL_Event, HAL_Executable, HAL_Fence, + HAL_File, HAL_PipelineLayout, ]>; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index d702e1de761d..fd4c86713d29 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -119,7 +119,7 @@ static void printDescriptorSetBindings(OpAsmPrinter &p, Operation *op, } //===----------------------------------------------------------------------===// -// hal.ex.shared_device +// hal.ex.* //===----------------------------------------------------------------------===// void ExSharedDeviceOp::getAsmResultNames( @@ -127,6 +127,11 @@ void ExSharedDeviceOp::getAsmResultNames( setNameFn(getResult(), "device"); } +void ExFileFromMemoryOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "memory_file"); +} + //===----------------------------------------------------------------------===// // hal.return //===----------------------------------------------------------------------===// @@ -335,35 +340,18 @@ Value AllocatorAllocateOp::getResultSize(unsigned idx) { } //===----------------------------------------------------------------------===// -// hal.allocator.allocate.initialized -//===----------------------------------------------------------------------===// - -void AllocatorAllocateInitializedOp::getAsmResultNames( - function_ref setNameFn) { - setNameFn(getResult(), "mapped"); -} - -Value AllocatorAllocateInitializedOp::getOperandSize(unsigned idx) { - return {}; -} - -Value AllocatorAllocateInitializedOp::getResultSize(unsigned idx) { - return getLength(); -} - -//===----------------------------------------------------------------------===// -// hal.allocator.try_map +// hal.allocator.import 
//===----------------------------------------------------------------------===// -void AllocatorTryMapOp::getAsmResultNames( +void AllocatorImportOp::getAsmResultNames( function_ref setNameFn) { - setNameFn(getDidMap(), "did_map"); + setNameFn(getDidImport(), "did_import"); setNameFn(getResult(), "mapped"); } -Value AllocatorTryMapOp::getOperandSize(unsigned idx) { return {}; } +Value AllocatorImportOp::getOperandSize(unsigned idx) { return {}; } -Value AllocatorTryMapOp::getResultSize(unsigned idx) { return getLength(); } +Value AllocatorImportOp::getResultSize(unsigned idx) { return getLength(); } //===----------------------------------------------------------------------===// // hal.buffer.subspan @@ -681,6 +669,14 @@ LogicalResult DeviceQueueDeallocaOp::verify() { return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); } +LogicalResult DeviceQueueReadOp::verify() { + return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); +} + +LogicalResult DeviceQueueWriteOp::verify() { + return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); +} + LogicalResult DeviceQueueExecuteOp::verify() { return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 6c4844f1e5e2..1f83ba3e4e11 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -55,6 +55,51 @@ def HAL_ExSharedDeviceOp : HAL_PureOp<"ex.shared_device", [ ]; } +def HAL_ExFileFromMemoryOp : HAL_Op<"ex.file.from_memory", [ + DeclareOpInterfaceMethods, + ]> { + let summary = [{creates a file mapped into a byte range of a host buffer}]; + let description = [{ + Returns a file handle that is backed by the given `buffer` contents. + Behavior is undefined if the buffer contents change while the accesses are + in-flight. 
+ + Experimental as the exact interface for getting files from module contents + still needs iteration. Most hardware APIs require a file descriptor or + native platform handle but here we only have host pointers. When + memory-mapped some systems allow for retrieval of the platform handle from + a virtual address (GetMappedFileNameA/posix_mem_offset) but the APIs are + sketchy and likely slow. Instead we should probably have a way to query for + a file handle derived from the calling module by stack-walking and asking + the VM module for its handle. Until we can figure this out this method will + be marked experimental. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_MemoryAccessBitfieldAttr:$access, + Util_BufferType:$buffer, + HAL_DeviceSize:$offset, + HAL_DeviceSize:$length, + I32:$flags + ); + let results = (outs + HAL_File:$result + ); + + let assemblyFormat = [{ + `device` `(` $device `:` type($device) `)` + `affinity` `(` $queue_affinity `)` + `access` `(` $access `)` + `buffer` `(` $buffer `:` type($buffer) `)` + `` `[` $offset `for` $length `]` + `flags` `(` $flags `)` + `:` type($result) + attr-dict-with-keyword + }]; +} + } // OpGroupExperimentalOps //===----------------------------------------------------------------------===// @@ -270,6 +315,7 @@ def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [ let arguments = (ins HAL_Allocator:$allocator, + HAL_DeviceQueueAffinity:$queue_affinity, HAL_MemoryTypeBitfieldAttr:$memory_types, HAL_BufferUsageBitfieldAttr:$buffer_usage, HAL_DeviceSize:$result_size @@ -281,6 +327,7 @@ def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [ // TODO(benvanik): change type/usage to ref params. 
let assemblyFormat = [{ `<` $allocator `:` type($allocator) `>` + `affinity` `(` $queue_affinity `)` `type` `(` $memory_types `)` `usage` `(` $buffer_usage `)` `:` custom(type($result), $result_size) @@ -288,56 +335,22 @@ def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [ }]; } -def HAL_AllocatorAllocateInitializedOp : HAL_Op<"allocator.allocate.initialized", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - ]> { - let summary = [{allocator-supported host buffer wrapping operation}]; - let description = [{ - Wraps a !hal.buffer around host read-only memory backed by the given byte - buffer. The returned buffer may be host-only and not directly usable on - devices. - }]; - - let arguments = (ins - HAL_Allocator:$allocator, - HAL_MemoryTypeBitfieldAttr:$memory_types, - HAL_BufferUsageBitfieldAttr:$buffer_usage, - // TODO(benvanik): support other types (and mutable buffers). - Util_BufferType:$source, - HAL_DeviceSize:$offset, - HAL_DeviceSize:$length - ); - let results = (outs - HAL_Buffer:$result - ); - - // TODO(benvanik): change type/usage to ref params. - let assemblyFormat = [{ - `<` $allocator `:` type($allocator) `>` - `source` `(` $source `:` type($source) `)` `` `[` $offset `,` $length `]` - `type` `(` $memory_types `)` - `usage` `(` $buffer_usage `)` - `:` type($result) - attr-dict-with-keyword - }]; -} - -def HAL_AllocatorTryMapOp : HAL_Op<"allocator.try_map", [ +def HAL_AllocatorImportOp : HAL_Op<"allocator.import", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, ]> { - let summary = [{allocator-supported host buffer wrapping operation}]; + let summary = [{allocator-supported host buffer import operation}]; let description = [{ - Tries wrapping a !hal.buffer around host read-only memory backed by the - given byte buffer. The returned buffer may be host-only and not directly - usable on devices. 
If the mapping cannot be completed (such as trying to - map the host memory as device-local on devices with discrete memory) then - did_map will indicate that the returned buffer is null. + Tries importing host memory backed by the given byte buffer into a + device accessible `!hal.buffer`. The returned buffer may be host-only and + not directly usable on devices. If the mapping cannot be completed (such as + trying to map the host memory as device-local on devices with discrete + memory) then `did_import` will indicate that the returned buffer is null. }]; let arguments = (ins HAL_Allocator:$allocator, + HAL_DeviceQueueAffinity:$queue_affinity, HAL_MemoryTypeBitfieldAttr:$memory_types, HAL_BufferUsageBitfieldAttr:$buffer_usage, // TODO(benvanik): support other types (and mutable buffers). @@ -346,7 +359,7 @@ def HAL_AllocatorTryMapOp : HAL_Op<"allocator.try_map", [ HAL_DeviceSize:$length ); let results = (outs - I1:$did_map, + I1:$did_import, HAL_Buffer:$result ); @@ -354,9 +367,10 @@ def HAL_AllocatorTryMapOp : HAL_Op<"allocator.try_map", [ let assemblyFormat = [{ `<` $allocator `:` type($allocator) `>` `source` `(` $source `:` type($source) `)` `` `[` $offset `,` $length `]` + `affinity` `(` $queue_affinity `)` `type` `(` $memory_types `)` `usage` `(` $buffer_usage `)` - `:` type($did_map) `,` type($result) + `:` type($did_import) `,` type($result) attr-dict-with-keyword }]; } @@ -1621,6 +1635,90 @@ def HAL_DeviceQueueDeallocaOp : HAL_Op<"device.queue.dealloca"> { let hasVerifier = 1; } +def HAL_DeviceQueueReadOp : HAL_Op<"device.queue.read"> { + let summary = [{reads a segment from a file into a device buffer}]; + let description = [{ + Enqueues a file read operation that streams a segment of the source file + defined by the source offset and length into the target HAL buffer at the + specified target offset. The queue affinity should be set to where the + target buffer will be consumed. 
The source file must have read permission + and the target buffer must have transfer-target usage. Read failure will + result in propagated semaphore failure or device loss. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + HAL_File:$source_file, + I64:$source_offset, + HAL_Buffer:$target_buffer, + HAL_DeviceSize:$target_offset, + HAL_DeviceSize:$length, + I32Attr:$flags + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + `source` `(` $source_file `:` type($source_file) `)` + `` `[` $source_offset `]` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; + + let hasVerifier = 1; +} + +def HAL_DeviceQueueWriteOp : HAL_Op<"device.queue.write"> { + let summary = [{writes a segment from a device buffer into a file}]; + let description = [{ + Enqueues a file write operation that streams a segment of the source HAL + buffer defined by the source offset and length into the target file at the + specified target offset. The queue affinity should be set to where the + source buffer was produced. The source buffer must have transfer-source + usage and the target file must have write permission. Write failure will + result in propagated semaphore failure or device loss. 
+ }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + HAL_Buffer:$source_buffer, + HAL_DeviceSize:$source_offset, + HAL_File:$target_file, + I64:$target_offset, + HAL_DeviceSize:$length, + I32Attr:$flags + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + `source` `(` $source_buffer `:` type($source_buffer) `)` + `` `[` $source_offset `]` + `target` `(` $target_file `:` type($target_file) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; + + let hasVerifier = 1; +} + def HAL_DeviceQueueExecuteOp : HAL_Op<"device.queue.execute"> { let summary = [{enqueues command buffer execution}]; let description = [{ diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index 4a15790b55a3..6f436fe32984 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp @@ -1104,7 +1104,7 @@ void HALDialect::registerAttributes() { void HALDialect::registerTypes() { addTypes(); } @@ -1153,9 +1153,9 @@ Type HALDialect::parseType(DialectAsmParser &parser) const { .Case("device", DeviceType::get(getContext())) .Case("event", EventType::get(getContext())) .Case("executable", ExecutableType::get(getContext())) - .Case("pipeline_layout", PipelineLayoutType::get(getContext())) .Case("fence", FenceType::get(getContext())) - .Case("ring_buffer", RingBufferType::get(getContext())) + .Case("file", FileType::get(getContext())) + .Case("pipeline_layout", PipelineLayoutType::get(getContext())) .Case("semaphore", SemaphoreType::get(getContext())) .Default(nullptr); if (!type) { @@ -1184,12 +1184,12 @@ void HALDialect::printType(Type type, DialectAsmPrinter &p) const { p 
<< "event"; } else if (llvm::isa(type)) { p << "executable"; - } else if (llvm::isa(type)) { - p << "pipeline_layout"; } else if (llvm::isa(type)) { p << "fence"; - } else if (llvm::isa(type)) { - p << "ring_buffer"; + } else if (llvm::isa(type)) { + p << "file"; + } else if (llvm::isa(type)) { + p << "pipeline_layout"; } else if (llvm::isa(type)) { p << "semaphore"; } else { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h index ad7123a2a104..e77d258f5993 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h @@ -122,17 +122,16 @@ struct ExecutableType using Base::Base; }; -struct PipelineLayoutType - : public Type::TypeBase { +struct FenceType : public Type::TypeBase { using Base::Base; }; -struct FenceType : public Type::TypeBase { +struct FileType : public Type::TypeBase { using Base::Base; }; -struct RingBufferType - : public Type::TypeBase { +struct PipelineLayoutType + : public Type::TypeBase { using Base::Base; }; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir index edc0d2de0790..cd35bf681ce7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir @@ -3,6 +3,8 @@ // CHECK-LABEL: @allocator_allocate // CHECK-SAME: (%[[ALLOCATOR:.+]]: !hal.allocator) func.func @allocator_allocate(%allocator: !hal.allocator) { + // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 + %affinity = arith.constant -1 : i64 // CHECK-DAG: %[[SIZE:.+]] = arith.constant 123 %size = arith.constant 123 : index // CHECK: %[[REF:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator> @@ -10,26 +12,29 @@ func.func @allocator_allocate(%allocator: !hal.allocator) { // CHECK-SAME: usage("TransferSource|TransferTarget|Transfer") // CHECK-SAME: : 
!hal.buffer{%[[SIZE]]} %ref = hal.allocator.allocate<%allocator : !hal.allocator> - type(HostLocal) usage(Transfer) : !hal.buffer{%size} + affinity(%affinity) type(HostLocal) usage(Transfer) : !hal.buffer{%size} return } // ----- -// CHECK-LABEL: @allocator_map_byte_buffer +// CHECK-LABEL: @allocator_import // CHECK-SAME: %[[ALLOCATOR:.+]]: !hal.allocator -func.func @allocator_map_byte_buffer(%allocator: !hal.allocator, %arg1: !util.buffer) { +func.func @allocator_import(%allocator: !hal.allocator, %arg1: !util.buffer) { // CHECK-DAG: %[[OFFSET:.+]] = arith.constant 100 %offset = arith.constant 100 : index // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 200 %length = arith.constant 200 : index - // CHECK: = hal.allocator.allocate.initialized<%[[ALLOCATOR]] : !hal.allocator> + // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 + %affinity = arith.constant -1 : i64 + // CHECK: = hal.allocator.import<%[[ALLOCATOR]] : !hal.allocator> // CHECK-SAME: source(%arg1 : !util.buffer)[%[[OFFSET]], %[[LENGTH]]] + // CHECK-SAME: affinity(%[[AFFINITY]]) // CHECK-SAME: type("DeviceVisible|DeviceLocal") // CHECK-SAME: usage("TransferSource|TransferTarget|Transfer") - // CHECK-SAME: : !hal.buffer - %ref = hal.allocator.allocate.initialized<%allocator : !hal.allocator> - source(%arg1 : !util.buffer)[%offset, %length] - type(DeviceLocal) usage(Transfer) : !hal.buffer + // CHECK-SAME: : i1, !hal.buffer + %ok, %ref = hal.allocator.import<%allocator : !hal.allocator> + source(%arg1 : !util.buffer)[%offset, %length] + affinity(%affinity) type(DeviceLocal) usage(Transfer) : i1, !hal.buffer return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir index 94ba6d0f5f5d..c163e1440d15 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir @@ -60,7 +60,7 @@ func.func @device_queue_alloca( %device: !hal.device, 
%affinity: i64, // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, %wait_fence: !hal.fence, %signal_fence: !hal.fence, - // CHECK-SAME: %[[SIZE:.+]]: index) + // CHECK-SAME: %[[SIZE:.+]]: index) %size: index) -> !hal.buffer { %c100_i64 = arith.constant 100 : i64 // CHECK: = hal.device.queue.alloca<%[[DEVICE]] : !hal.device> @@ -86,7 +86,7 @@ func.func @device_queue_dealloca( %device: !hal.device, %affinity: i64, // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, %wait_fence: !hal.fence, %signal_fence: !hal.fence, - // CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer) + // CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer) %buffer: !hal.buffer) { // CHECK: hal.device.queue.dealloca<%[[DEVICE]] : !hal.device> hal.device.queue.dealloca<%device : !hal.device> @@ -101,13 +101,83 @@ func.func @device_queue_dealloca( // ----- +// CHECK-LABEL: @device_queue_read +func.func @device_queue_read( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SOURCE_FILE:.+]]: !hal.file, + %source_file: !hal.file, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !hal.buffer) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[SOURCE_OFFSET:.+]] = arith.constant 100 + %source_offset = arith.constant 100 : i64 + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = arith.constant 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 300 + %length = arith.constant 300 : index + // CHECK: hal.device.queue.read<%[[DEVICE]] : !hal.device> + hal.device.queue.read<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: source(%[[SOURCE_FILE]] : !hal.file)[%[[SOURCE_OFFSET]]] 
+ source(%source_file : !hal.file)[%source_offset] + // CHECK-SAME: target(%[[TARGET_BUFFER]] : !hal.buffer)[%[[TARGET_OFFSET]]] + target(%target_buffer : !hal.buffer)[%target_offset] + // CHECK-SAME: length(%[[LENGTH]]) + length(%length) + // CHECK-SAME: flags(0) + flags(0) + return +} + +// ----- + +// CHECK-LABEL: @device_queue_write +func.func @device_queue_write( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !hal.buffer, + %source_buffer: !hal.buffer, + // CHECK-SAME: %[[TARGET_FILE:.+]]: !hal.file) + %target_file: !hal.file) { + // CHECK-DAG: %[[SOURCE_OFFSET:.+]] = arith.constant 100 + %source_offset = arith.constant 100 : index + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = arith.constant 200 + %target_offset = arith.constant 200 : i64 + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 300 + %length = arith.constant 300 : index + // CHECK: hal.device.queue.write<%[[DEVICE]] : !hal.device> + hal.device.queue.write<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: source(%[[SOURCE_BUFFER]] : !hal.buffer)[%[[SOURCE_OFFSET]]] + source(%source_buffer : !hal.buffer)[%source_offset] + // CHECK-SAME: target(%[[TARGET_FILE]] : !hal.file)[%[[TARGET_OFFSET]]] + target(%target_file : !hal.file)[%target_offset] + // CHECK-SAME: length(%[[LENGTH]]) + length(%length) + // CHECK-SAME: flags(0) + flags(0) + return +} + +// ----- + // CHECK-LABEL: @device_queue_execute func.func @device_queue_execute( // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, %device: !hal.device, %affinity: i64, // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, 
%wait_fence: !hal.fence, %signal_fence: !hal.fence, - // CHECK-SAME: %[[CMD0:.+]]: !hal.command_buffer, %[[CMD1:.+]]: !hal.command_buffer) + // CHECK-SAME: %[[CMD0:.+]]: !hal.command_buffer, %[[CMD1:.+]]: !hal.command_buffer) %cmd0: !hal.command_buffer, %cmd1: !hal.command_buffer) { // CHECK: hal.device.queue.execute<%[[DEVICE]] : !hal.device> hal.device.queue.execute<%device : !hal.device> diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir index 166cb019677e..3fb6585837c1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir @@ -6,3 +6,32 @@ func.func @shared_device() -> !hal.device { %device = hal.ex.shared_device : !hal.device return %device : !hal.device } + +// ----- + +// CHECK-LABEL: @file_from_memory +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[BUFFER:.+]]: !util.buffer) +func.func @file_from_memory(%device: !hal.device, %buffer: !util.buffer) -> !hal.file { + // CHECK-DAG: %[[AFFINITY:.+]] = arith.constant -1 + %affinity = arith.constant -1 : i64 + // CHECK-DAG: %[[OFFSET:.+]] = arith.constant 100 + %offset = arith.constant 100 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 200 + %length = arith.constant 200 : index + // CHECK-DAG: %[[FLAGS:.+]] = arith.constant 0 : i32 + %flags = arith.constant 0 : i32 + // CHECK: = hal.ex.file.from_memory + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: affinity(%[[AFFINITY]]) + // CHECK-SAME: access(Read) + // CHECK-SAME: buffer(%[[BUFFER]] : !util.buffer) + // CHECK-SAME: [%[[OFFSET]] for %[[LENGTH]]] + // CHECK-SAME: flags(%[[FLAGS]]) : !hal.file + %file = hal.ex.file.from_memory + device(%device : !hal.device) + affinity(%affinity) + access(Read) + buffer(%buffer : !util.buffer)[%offset for %length] + flags(%flags) : !hal.file + return %file : !hal.file +} diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp index 63b2a4c72dbe..7809131dea8a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConvertToHAL.cpp @@ -23,6 +23,7 @@ #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -44,8 +45,9 @@ class ConvertToHALPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToHALPass) void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); registry.insert(); + registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 9976784be964..a8232ac3834d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -165,12 +165,13 @@ appendGlobalBuffer(Location loc, StringRef baseName, auto allocator = initBuilder.create(loc, device).getResult(); + auto queueAffinity = initBuilder.create(loc, -1, 64); auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal; auto bufferUsage = IREE::HAL::BufferUsageBitfield::Transfer | IREE::HAL::BufferUsageBitfield::DispatchStorage; auto allocateOp = initBuilder.create( - loc, globalOp.getType(), allocator, memoryTypes, bufferUsage, - indexSet.get(totalLength)); + loc, globalOp.getType(), allocator, queueAffinity, memoryTypes, + bufferUsage, indexSet.get(totalLength)); initBuilder.create(loc, allocateOp.getResult(), globalOp.getNameAttr()); diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir index c13bb5927221..fb1d3287a463 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/dump_executable_benchmarks.mlir @@ -63,7 +63,7 @@ module attributes {hal.device.targets = [#device_target_cpu]} { // CHECK: util.global private mutable @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer // CHECK-NEXT: util.initializer { - // CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%{{.+}} : !hal.allocator> type("DeviceVisible|DeviceLocal") usage("{{.+}}Dispatch{{.+}}") : !hal.buffer{%c768} + // CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%{{.+}} : !hal.allocator> affinity(%{{.+}}) type("DeviceVisible|DeviceLocal") usage("{{.+}}Dispatch{{.+}}") : !hal.buffer{%c768} // CHECK-NEXT: util.global.store %[[BUFFER]], @ex0_embedded_elf_x86_64_dispatch0_512_buffer : !hal.buffer // CHECK: func.func @ex0_embedded_elf_x86_64_dispatch0_512(%arg0: i32) diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 9d2be1d4bb11..5166f5a17b67 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -11,41 +11,47 @@ vm.module @hal { vm.import private @ex.shared_device() -> !vm.ref attributes {nosideeffects} +// Creates a file mapped into a byte range of a host buffer. +// EXPERIMENTAL: may be removed in future versions. 
+vm.import private @ex.file.from_memory( + %device : !vm.ref, + %queue_affinity : i64, + %access : i32, + %buffer : !vm.buffer, + %offset : i64, + %length : i64, + %flags : i32 +) -> !vm.ref + //===----------------------------------------------------------------------===// // iree_hal_allocator_t //===----------------------------------------------------------------------===// -// Allocates a buffer from the allocator. +// Allocates a buffer from the allocator. The resulting buffer will have a +// length of at least that requested. vm.import private @allocator.allocate( %allocator : !vm.ref, + %queue_affinity : i64, %memory_types : i32, %buffer_usage : i32, %allocation_size : i64 ) -> !vm.ref +attributes {minimum_version = 1 : i32} -// Allocates a buffer from the allocator with an initial value provided by a -// VM byte buffer. -vm.import private @allocator.allocate.initialized( - %allocator : !vm.ref, - %memory_types : i32, - %buffer_usage : i32, - %source : !vm.buffer, - %offset : i64, - %length : i64 -) -> !vm.ref - -// Maps a host byte buffer into a device buffer. +// Imports a host byte buffer into a device visible buffer. // If try!=0 then returns null if the given memory type cannot be mapped. // Host-local+constant requests will always succeed. -vm.import private @allocator.map.byte_buffer( +vm.import private @allocator.import( %allocator : !vm.ref, %try : i32, + %queue_affinity : i64, %memory_types : i32, %buffer_usage : i32, %source : !vm.buffer, %offset : i64, %length : i64 ) -> !vm.ref +attributes {minimum_version = 1 : i32} //===----------------------------------------------------------------------===// // iree_hal_buffer_t @@ -364,6 +370,34 @@ vm.import private @device.queue.dealloca( %buffer : !vm.ref ) +// Reads a segment from a file into a device buffer. 
+vm.import private @device.queue.read( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %source_file : !vm.ref, + %source_offset : i64, + %target_buffer : !vm.ref, + %target_offset : i64, + %length : i64, + %flags : i32 +) + +// Writes a segment from device buffer into a file. +vm.import private @device.queue.write( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %source_buffer : !vm.ref, + %source_offset : i64, + %target_file : !vm.ref, + %target_offset : i64, + %length : i64, + %flags : i32 +) + // Executes one or more command buffers on a device queue. // The command buffers are executed in order as if they were recorded as one. // No commands will execute until the wait fence has been reached and the signal diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td index 06eb330ae809..cd0495d4a5f3 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td @@ -618,6 +618,28 @@ def Stream_AnyStreamResource : AnyTypeOf<[ Stream_ConstantResource, ]>; +//===----------------------------------------------------------------------===// +// File resources +//===----------------------------------------------------------------------===// + +def Stream_File : TypeDef { + let mnemonic = "file"; + + let summary = [{a file handle used for I/O operations}]; + let description = [{ + A file handle that can be asynchronously read and written into/from + stream resources. 
+ }]; + + let parameters = (ins); + + let builders = [ + TypeBuilder<(ins), [{ + return $_get($_ctxt); + }]>, + ]; +} + //===----------------------------------------------------------------------===// // Executable bindings //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 5ad454f4ba39..3951a0377bcc 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -667,16 +667,6 @@ void ResourceSizeOp::getCanonicalizationPatterns(RewritePatternSet &results, results.insert(context); } -//===----------------------------------------------------------------------===// -// stream.resource.map -//===----------------------------------------------------------------------===// - -void ResourceMapOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - // TODO(benvanik): fold subviews up into maps to limit range. 
- results.insert>(context); -} - //===----------------------------------------------------------------------===// // stream.resource.try_map //===----------------------------------------------------------------------===// @@ -998,6 +988,24 @@ void ResourceSubviewOp::getCanonicalizationPatterns(RewritePatternSet &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// stream.file.read +//===----------------------------------------------------------------------===// + +void FileReadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert>(context); +} + +//===----------------------------------------------------------------------===// +// stream.file.write +//===----------------------------------------------------------------------===// + +void FileWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert>(context); +} + //===----------------------------------------------------------------------===// // stream.tensor.import //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index 9ea09388361a..4de2d3139359 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -633,18 +633,6 @@ LogicalResult ResourceAllocOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// stream.resource.map -//===----------------------------------------------------------------------===// - -LogicalResult ResourceMapOp::verify() { - ResourceMapOp op = *this; - if (failed(verifyOpValueSizes(op, op.getResult(), op.getResultSize()))) { - return failure(); - } - return success(); -} - 
//===----------------------------------------------------------------------===// // stream.resource.try_map //===----------------------------------------------------------------------===// @@ -791,6 +779,58 @@ IREE::Stream::ResourceSubviewOp ResourceSubviewOp::findSubviewOp(Value value) { return {}; } +//===----------------------------------------------------------------------===// +// stream.file.constant +//===----------------------------------------------------------------------===// + +void FileConstantOp::getAsmResultNames(mlir::OpAsmSetValueNameFn setNameFn) { + setNameFn(getResult(), "file"); +} + +IREE::Util::SubrangeOperand +FileConstantOp::getSubrangeOperand(unsigned operandIndex) { + if (operandIndex == 0) { + return IREE::Util::SubrangeOperand{getSource(), getSourceSize(), + getSourceOffset(), getSourceLength()}; + } else { + assert(false && "only source is a subrange"); + return {}; + } +} + +void FileConstantOp::setSubrangeOperand(unsigned operandIndex, + IREE::Util::SubrangeOperand operand) { + assert(operandIndex == 0 && "only source is a subrange"); + getSourceMutable().assign(operand.resource); + getSourceSizeMutable().assign(operand.resourceSize); + getSourceOffsetMutable().assign(operand.offset); + getSourceLengthMutable().assign(operand.length); +} + +//===----------------------------------------------------------------------===// +// stream.file.read +//===----------------------------------------------------------------------===// + +LogicalResult FileReadOp::verify() { + FileReadOp op = *this; + if (failed(verifyOpValueSizes(op, op.getTarget(), op.getTargetSize()))) { + return failure(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// stream.file.write +//===----------------------------------------------------------------------===// + +LogicalResult FileWriteOp::verify() { + FileWriteOp op = *this; + if (failed(verifyOpValueSizes(op, op.getSource(), op.getSourceSize()))) { + 
return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // stream.tensor.import //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index bbf2641a5f71..616fd27d20e0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -227,48 +227,6 @@ def Stream_ResourceSizeOp : Stream_PureOp<"resource.size", [ let hasFolder = 1; } -def Stream_ResourceMapOp : Stream_Op<"resource.map", [ - Stream_AffinityOp, - Util_SizeAwareOp, - MemoryEffects<[MemAlloc]>, -]> { - let summary = [{maps read-only memory into a staging resource}]; - let description = [{ - Synchronously maps a host heap buffer into a stream-accessible staging - resource. Will never fail but may induce a copy if required and as such the - mapped resource is not coherent with the original source buffer: changing - the source buffer after mapping has undefined behavior. - }]; - - let arguments = (ins - Util_BufferType:$source, - Stream_Offset:$source_offset, - Stream_Size:$result_size, - OptionalAttr:$affinity - ); - let results = (outs - Stream_StagingResource:$result - ); - - let assemblyFormat = [{ - (`on` `(` $affinity^ `)`)? 
- $source `[` $source_offset `]` `:` - type($source) - `->` - type($result) `` `{` $result_size `}` - attr-dict-with-keyword - }]; - - let extraClassDeclaration = [{ - Value getOperandSize(unsigned idx) { return {}; } - Value getResultSize(unsigned idx) { return getResultSize(); } - }]; - - let hasVerifier = 1; - - let hasCanonicalizer = 1; -} - def Stream_ResourceTryMapOp : Stream_PureOp<"resource.try_map", [ Stream_AffinityOp, Util_SizeAwareOp, @@ -277,10 +235,10 @@ def Stream_ResourceTryMapOp : Stream_PureOp<"resource.try_map", [ let summary = [{maps read-only memory into a resource}]; let description = [{ Synchronously maps a host heap buffer into a stream-accessible resource - with constant lifetime. If the given source cannot be mapped into a constant - a failure is returned and the resulting resource value is null. As with - `stream.resource.map` the resulting resource is not coherent with the source - and changes will not be reflected. + with the requested lifetime. If the given source cannot be mapped the + `did_map` result will be 0 and users must find another route into memory + (such as file I/O). The resulting resource is not coherent with the source + and behavior is undefined if the underlying contents change. 
}]; let arguments = (ins @@ -291,7 +249,7 @@ def Stream_ResourceTryMapOp : Stream_PureOp<"resource.try_map", [ ); let results = (outs I1:$did_map, - Stream_ConstantResource:$result + Stream_AnyStreamResource:$result ); let assemblyFormat = [{ @@ -611,6 +569,176 @@ def Stream_ResourceSubviewOp : Stream_PureOp<"resource.subview", [ } // OpGroupResourceOps +//===----------------------------------------------------------------------===// +// File ops +//===----------------------------------------------------------------------===// + +def OpGroupFileOps : OpDocGroup { + let summary = "File ops"; + let description = "File ops."; +} + +let opDocGroup = OpGroupFileOps in { + +def Stream_FileConstantOp : Stream_PureOp<"file.constant", [ + Stream_AffinityOp, + Util_SizeAwareOp, + MemoryEffects<[MemAlloc]>, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, +]> { + let summary = [{creates a file backed by the provided constant host memory}]; + let description = [{ + Synchronously wraps a host heap buffer into a stream-accessible file handle. + Changing the source buffer after definition has undefined behavior. + }]; + + let arguments = (ins + Util_BufferType:$source, + Util_Size:$source_size, + Stream_Offset:$source_offset, + Stream_Size:$source_length, + OptionalAttr:$affinity + ); + let results = (outs + Stream_File:$result + ); + + let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? 
+ $source `[` $source_offset `for` $source_length `]` `:` + type($source) `` `{` $source_size `}` + `->` + type($result) + attr-dict-with-keyword + }]; + + let extraClassDeclaration = [{ + Value getOperandSize(unsigned idx) { return getSourceSize(); } + Value getResultSize(unsigned idx) { return {}; } + }]; +} + +def Stream_FileReadOp : Stream_Op<"file.read", [ + DeclareOpInterfaceMethods, + Stream_CmdPhaseOp, + Stream_TimelineOp, + Util_SizeAwareOp, +]> { + let summary = [{reads a segment of a file into a resource}]; + let description = [{ + Asynchronously reads a segment of a file into a resource. + + In some implementations this can stream directly from the file into + device-local memory and should be preferred to manually staging memory + through host buffers. + }]; + + let arguments = (ins + Stream_File:$source, + I64:$source_offset, + Stream_AnyStreamResource:$target, + Stream_Size:$target_size, + Stream_Offset:$target_offset, + Stream_Size:$length, + Optional:$await_timepoint, + OptionalAttr:$affinity + ); + let results = (outs + Stream_Timepoint:$result_timepoint + ); + + let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? + (`await` `(` $await_timepoint^ `)` `=` `` `>`):(`:`)? 
+ $source `[` $source_offset `]` `,` + $target `[` $target_offset `]` `,` + $length `:` + type($source) `->` + type($target) `` `{` $target_size `}` + `=` `` `>` + type($result_timepoint) + attr-dict-with-keyword + }]; + + let extraClassDeclaration = [{ + Value getOperandSize(unsigned idx) { return getTargetSize(); } + Value getResultSize(unsigned idx) { return {}; } + SmallVector getAwaitTimepoints() { + if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; + } + }]; + + let hasVerifier = 1; + + let hasCanonicalizer = 1; +} + +def Stream_FileWriteOp : Stream_Op<"file.write", [ + DeclareOpInterfaceMethods, + Stream_CmdPhaseOp, + Stream_TimelineOp, + Util_SizeAwareOp, +]> { + let summary = [{writes a segment of a file from a resource}]; + let description = [{ + Asynchronously writes a segment of a resource into a file. + The file range must be valid within the file as this operation cannot + grow the underlying file storage. + + In some implementations this can stream directly from device-local memory into + the file and should be preferred to manually staging memory + through host buffers. + }]; + + let arguments = (ins + Stream_AnyStreamResource:$source, + Stream_Size:$source_size, + Stream_Offset:$source_offset, + Stream_File:$target, + I64:$target_offset, + Stream_Size:$length, + Optional:$await_timepoint, + OptionalAttr:$affinity + ); + let results = (outs + Stream_Timepoint:$result_timepoint + ); + + let assemblyFormat = [{ + (`on` `(` $affinity^ `)`)? + (`await` `(` $await_timepoint^ `)` `=` `` `>`):(`:`)? 
+ $source `[` $source_offset `]` `,` + $target `[` $target_offset `]` `,` + $length `:` + type($source) `` `{` $source_size `}` `->` + type($target) + `=` `` `>` + type($result_timepoint) + attr-dict-with-keyword + }]; + + let extraClassDeclaration = [{ + Value getOperandSize(unsigned idx) { return getSourceSize(); } + Value getResultSize(unsigned idx) { return {}; } + SmallVector getAwaitTimepoints() { + if (getAwaitTimepoint()) return {getAwaitTimepoint()}; else return {}; + } + }]; + + let hasVerifier = 1; + + let hasCanonicalizer = 1; +} + +} // OpGroupFileOps + //===----------------------------------------------------------------------===// // Pseudo ops for conversion support //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel index 8470c43cc03c..e168414ec75d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel @@ -23,6 +23,7 @@ iree_lit_test_suite( "cmd_folding.mlir", "cmd_ops.mlir", "executable_ops.mlir", + "file_ops.mlir", "resource_folding.mlir", "resource_ops.mlir", "tensor_folding.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt index 3b513a831e9c..4835de2f811d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/CMakeLists.txt @@ -21,6 +21,7 @@ iree_lit_test_suite( "cmd_folding.mlir" "cmd_ops.mlir" "executable_ops.mlir" + "file_ops.mlir" "resource_folding.mlir" "resource_ops.mlir" "tensor_folding.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/file_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/file_ops.mlir new file mode 100644 index 000000000000..dbed4df040dc --- /dev/null +++ 
b/compiler/src/iree/compiler/Dialect/Stream/IR/test/file_ops.mlir @@ -0,0 +1,37 @@ +// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s + +// CHECK-LABEL: @file_constant +// CHECK-SAME: (%[[BUFFER:.+]]: !util.buffer) +func.func @file_constant(%buffer: !util.buffer) { + %c0 = arith.constant 0 : index + %c1088 = arith.constant 1088 : index + // CHECK: %file = stream.file.constant %[[BUFFER]][%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file + %file = stream.file.constant %buffer[%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file + return +} + +// ----- + +// CHECK-LABEL: @file_read +// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[FILE:.+]]: !stream.file, %[[RESOURCE:.+]]: !stream.resource) +func.func @file_read(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1088 = arith.constant 1088 : index + // CHECK: = stream.file.read await(%[[WAIT]]) => %[[FILE]][%c0_i64], %[[RESOURCE]][%c0], %c1088 : !stream.file -> !stream.resource{%c1088} => !stream.timepoint + %0 = stream.file.read await(%wait) => %file[%c0_i64], %resource[%c0], %c1088 : !stream.file -> !stream.resource{%c1088} => !stream.timepoint + return +} + +// ----- + +// CHECK-LABEL: @file_write +// CHECK-SAME: (%[[WAIT:.+]]: !stream.timepoint, %[[FILE:.+]]: !stream.file, %[[RESOURCE:.+]]: !stream.resource) +func.func @file_write(%wait: !stream.timepoint, %file: !stream.file, %resource: !stream.resource) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1088 = arith.constant 1088 : index + // CHECK: = stream.file.write await(%[[WAIT]]) => %[[RESOURCE]][%c0], %[[FILE]][%c0_i64], %c1088 : !stream.resource{%c1088} -> !stream.file => !stream.timepoint + %0 = stream.file.write await(%wait) => %resource[%c0], %file[%c0_i64], %c1088 : !stream.resource{%c1088} -> !stream.file => !stream.timepoint + return +} diff --git 
a/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir index c1c0920dbfb2..c42ccfee0a14 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/resource_ops.mlir @@ -40,17 +40,6 @@ func.func @resourceSize(%arg0: !stream.resource<*>) -> index { // ----- -// CHECK-LABEL: @resourceMap -func.func @resourceMap(%arg0: !util.buffer) -> !stream.resource { - %c0 = arith.constant 0 : index - %c128 = arith.constant 128 : index - // CHECK: = stream.resource.map %arg0[%c0] : !util.buffer -> !stream.resource{%c128} - %0 = stream.resource.map %arg0[%c0] : !util.buffer -> !stream.resource{%c128} - return %0 : !stream.resource -} - -// ----- - // CHECK-LABEL: @resourceTryMap func.func @resourceTryMap(%arg0: !util.buffer) -> (i1, !stream.resource) { %c0 = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp index b1fe2eb62d5c..74ad0dd93266 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp @@ -237,196 +237,83 @@ computePackingMap(ArrayRef slices, // Upload materialization //===----------------------------------------------------------------------===// -struct AllocatedStorage { - // Resource storing all packed constants. +struct TimepointResource { + Value timepoint; Value resource; - // Total size, in bytes, of the storage resource. Value resourceSize; }; -struct UploadResult { - // Timepoint when the storage is initialized with the constant values. - Value timepoint; - // Each (resource, resourceSize) allocated. - SmallVector allocations; -}; - -// Maps constants as a staging buffer and then issues copy commands. 
-// Per-storage resource we map the source rodata, allocate the result, and then -// issue an async copy from source to result. To avoid a bunch of overhead when -// there are multiple storage buffers we invert the logic so that we put all the -// async copies into a single region. -static UploadResult -buildStagingUpload(Location loc, IREE::Stream::AffinityAttr affinityAttr, - IREE::Stream::ResourceType resourceType, - ArrayRef storageResources, - ArrayRef storageBuffers, IndexSet &indexSet, - OpBuilder &builder) { - UploadResult uploadResult; - auto stagingType = builder.getType( - IREE::Stream::Lifetime::Staging); - - // Map all of the storage data and allocate the result buffers. - // This will produce a list of copies we should perform from staging->final. - struct Copy { - Location loc; - Value source; - Value sourceSize; - Value sourceOffset; - Value target; - Value targetSize; - Value targetOffset; - Value length; - }; - SmallVector copies; - SmallVector capturedResources; - SmallVector capturedResourceSizes; - for (auto [storageResource, storageBuffer] : - llvm::zip_equal(storageResources, storageBuffers)) { - // Today we assume 1:1 lengths of storage data and uploaded data, but this - // need not be the case if we want to pad the buffer for runtime. - auto totalLength = indexSet.get(storageResource.totalSize); - - // Map the source staging resource rodata. - auto mapOp = builder.create( - storageResource.loc, stagingType, storageBuffer, indexSet.get(0), - totalLength, affinityAttr); - - // Allocate the resulting storage resource of the final resource type. - auto allocOp = builder.create( - storageResource.loc, resourceType, mapOp.getResultSize(), - /*uninitialized=*/builder.getUnitAttr(), affinityAttr); - - uploadResult.allocations.push_back({ - allocOp.getResults().front(), - allocOp.getStorageSizes().front(), - }); - - // Queue copy for processing below. 
- Copy copy{ - storageResource.loc, - mapOp.getResult(), - mapOp.getResultSize(), - indexSet.get(0), - allocOp.getResults().front(), - allocOp.getStorageSizes().front(), - indexSet.get(0), - totalLength, - }; - capturedResources.push_back(copy.source); - capturedResourceSizes.push_back(copy.sourceSize); - capturedResources.push_back(copy.target); - capturedResourceSizes.push_back(copy.targetSize); - copies.push_back(std::move(copy)); - } - - // Create the execution op capturing the resources. - auto executeOp = builder.create( - loc, /*awaitTimepoint=*/Value{}, capturedResources, - capturedResourceSizes); - if (affinityAttr) - executeOp.setAffinityAttr(affinityAttr); - uploadResult.timepoint = executeOp.getResultTimepoint(); - - // Map captured resources into the execution region. - IRMapping mapping; - auto *entryBlock = new Block(); - executeOp.getBody().push_back(entryBlock); - for (auto outerValue : capturedResources) { - auto arg = - entryBlock->addArgument(outerValue.getType(), outerValue.getLoc()); - mapping.map(outerValue, arg); - } - - // Issue copies. Note that we use the captured resources. - auto executionBuilder = OpBuilder::atBlockBegin(entryBlock); - for (auto © : copies) { - executionBuilder.create( - copy.loc, mapping.lookup(copy.source), copy.sourceSize, - copy.sourceOffset, mapping.lookup(copy.target), copy.targetSize, - copy.targetOffset, copy.length); - } - executionBuilder.create(executeOp.getLoc()); - - return uploadResult; +static TimepointResource buildFileRead( + Location loc, Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr, + IREE::Stream::ResourceType resourceType, StorageResource storageResource, + Value storageResourceSize, Value storageBuffer, Value storageBufferSize, + IndexSet &indexSet, OpBuilder &builder) { + // Allocate the resulting storage resource of the final resource type. 
+ auto allocOp = builder.create( + storageResource.loc, resourceType, storageResourceSize, + /*uninitialized=*/builder.getUnitAttr(), affinityAttr); + + // Create the file backed by the constant resource buffer. + auto fileOp = builder.create( + storageResource.loc, storageBuffer, storageBufferSize, indexSet.get(0), + storageResourceSize, affinityAttr); + + // Issue asynchronous file read into the buffer. + auto zeroI64 = + builder.create(storageResource.loc, 0, 64); + auto readOp = builder.create( + storageResource.loc, fileOp.getResult(), zeroI64, allocOp.getResult(0), + allocOp.getResultSize(0), indexSet.get(0), storageResourceSize, + awaitTimepoint, affinityAttr); + + return TimepointResource{readOp.getResultTimepoint(), readOp.getTarget(), + readOp.getTargetSize()}; } -// Emits IR to first try mapping the storage resources directly into usable -// constant resources. If the mapping fails (the target can't use the memory) +// Emits IR to first try mapping the storage resource directly into a usable +// constant resource. If the mapping fails (the target can't use the memory) // then fall back to staging uploads. -static UploadResult buildTryMapConstantResources( - Location loc, IREE::Stream::AffinityAttr affinityAttr, - IREE::Stream::ResourceType resourceType, - ArrayRef storageResources, ArrayRef storageBuffers, +// Returns a timepoint indicating the operation has completed. +static TimepointResource buildTryMapConstantResource( + Location loc, Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr, + IREE::Stream::ResourceType resourceType, StorageResource storageResource, + Value storageResourceSize, Value storageBuffer, Value storageBufferSize, IndexSet &indexSet, OpBuilder &builder) { - // Try mapping each resource. We do this as an all-or-nothing across the - // storage: if any fails we fallback to the allocation path. 
This is mostly - // just to get more predictable behavior in the face of weird platform - // requirements: we want something like misaligned mappings to be easily - // visible in tracing. - SmallVector mappedResources; - SmallVector resultTypes; - Value ok; - auto zero = indexSet.get(0); - for (auto [storageResource, storageBuffer] : - llvm::zip_equal(storageResources, storageBuffers)) { - auto tryMapOp = builder.create( - storageResource.loc, builder.getI1Type(), resourceType, storageBuffer, - zero, indexSet.get(storageResource.totalSize), affinityAttr); - if (!ok) { - ok = tryMapOp.getDidMap(); - } else { - ok = builder.createOrFold(tryMapOp.getLoc(), ok, - tryMapOp.getDidMap()); - } - mappedResources.push_back(tryMapOp.getResult()); - resultTypes.push_back(tryMapOp.getResult().getType()); - } + // Try mapping; this may fail if the device can't use the storage buffer as + // the type of resource requested. + auto tryMapOp = builder.create( + storageResource.loc, builder.getI1Type(), resourceType, storageBuffer, + indexSet.get(0), storageResourceSize, affinityAttr); // If we are able to directly map the resources then we don't need to wait. - auto timepointType = builder.getType(); - resultTypes.push_back(timepointType); - - // if ok: return mapped resources - // else: allocate and upload + // Otherwise we need to stage the storage buffer into memory via the file + // streaming API. auto ifOp = builder.create( - loc, ok, + loc, tryMapOp.getDidMap(), [&](OpBuilder &thenBuilder, Location loc) { // Just return the resources + an immediate timepoint. 
- SmallVector ifResults = mappedResources; - ifResults.push_back( - thenBuilder.create(loc)); - thenBuilder.create(loc, ifResults); + thenBuilder.create(loc, ValueRange{ + awaitTimepoint, + tryMapOp.getResult(), + }); }, [&](OpBuilder &elseBuilder, Location loc) { - // Fallback to upload and then - auto stagingResult = buildStagingUpload( - loc, affinityAttr, resourceType, storageResources, storageBuffers, - indexSet, builder); - SmallVector ifResults; - for (auto &allocation : stagingResult.allocations) { - ifResults.push_back(allocation.resource); - } - ifResults.push_back(stagingResult.timepoint); - elseBuilder.create(loc, ifResults); + auto readResult = + buildFileRead(loc, awaitTimepoint, affinityAttr, resourceType, + storageResource, storageResourceSize, storageBuffer, + storageBufferSize, indexSet, elseBuilder); + elseBuilder.create(loc, ValueRange{ + readResult.timepoint, + readResult.resource, + }); }); - auto ifTimepoint = ifOp.getResults().back(); - auto ifResources = ifOp.getResults().slice(0, ifOp.getResults().size() - 1); - - // Use the result of either the direct mapping or the staging upload. 
- UploadResult uploadResult; - uploadResult.timepoint = ifTimepoint; - for (auto [storageResource, ifResource] : - llvm::zip_equal(storageResources, ifResources)) { - uploadResult.allocations.push_back({ - ifResource, - indexSet.get(storageResource.totalSize), - }); - } - return uploadResult; + auto ifTimepoint = ifOp.getResults().front(); + auto ifResource = ifOp.getResults().back(); + return TimepointResource{ifTimepoint, ifResource, storageResourceSize}; } -static Value generateUpload(IREE::Stream::ResourceConstantsOp constantsOp, +static Value generateUpload(Value awaitTimepoint, + IREE::Stream::ResourceConstantsOp constantsOp, IREE::Stream::Lifetime lifetime, IREE::Stream::ResourceConfigAttr resourceConfig, IndexSet &indexSet, OpBuilder &builder) { @@ -454,49 +341,53 @@ static Value generateUpload(IREE::Stream::ResourceConstantsOp constantsOp, if (storageResources.empty()) return nullptr; + // TODO(benvanik): should be able to have a single buffer constant and + // subrange it so that we don't need so many files. + + auto anyResult = slices.front().result; + auto resourceType = + llvm::cast(anyResult.getType()); + // Emit rodata storage for the constant values. // As our upload paths may vary this ensures that we are only emitting // them once regardless of how many strategies we emit IR for. - SmallVector storageBuffers; + Value currentTimepoint = awaitTimepoint; for (auto &storageResource : storageResources) { - auto rodataOp = builder.create( + Value storageBuffer = builder.create( storageResource.loc, /*name=*/nullptr, storageResource.data, builder.getIndexAttr(resourceConfig.getMinBufferOffsetAlignment()), /*mimeType=*/nullptr); - storageBuffers.push_back(rodataOp); - } - - // If this is producing constants (vs variables) we can try to go on a - // fast-path where we directly map the constant memory. If producing - // variables then we always need to stage and clone. 
- auto anyResult = slices.front().result; - auto resourceType = - llvm::cast(anyResult.getType()); - UploadResult uploadResult; - if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant) { - uploadResult = buildTryMapConstantResources( - constantsOp.getLoc(), constantsOp.getAffinityAttr(), resourceType, - storageResources, storageBuffers, indexSet, builder); - } else { - uploadResult = buildStagingUpload( - constantsOp.getLoc(), constantsOp.getAffinityAttr(), resourceType, - storageResources, storageBuffers, indexSet, builder); - } + auto resourceSize = indexSet.get(storageResource.totalSize); + + // If this is producing constants (vs variables) we can try to go on a + // fast-path where we directly map the constant memory. If producing + // variables then we always need to stage and clone. + TimepointResource uploadedResource; + if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant) { + uploadedResource = buildTryMapConstantResource( + constantsOp.getLoc(), currentTimepoint, constantsOp.getAffinityAttr(), + resourceType, storageResource, resourceSize, storageBuffer, + resourceSize, indexSet, builder); + } else { + uploadedResource = buildFileRead( + constantsOp.getLoc(), currentTimepoint, constantsOp.getAffinityAttr(), + resourceType, storageResource, resourceSize, storageBuffer, + resourceSize, indexSet, builder); + } - // Build subviews for all packed spans back into storage buffers. 
- for (auto [storageResource, allocatedStorage] : - llvm::zip_equal(storageResources, uploadResult.allocations)) { for (auto &span : storageResource.spans) { auto loc = span.slice.result.getLoc(); auto subviewOp = builder.create( - loc, allocatedStorage.resource, allocatedStorage.resourceSize, + loc, uploadedResource.resource, uploadedResource.resourceSize, indexSet.get(span.offset), span.slice.resultSize); span.slice.result.replaceAllUsesWith(subviewOp.getResult()); } + + currentTimepoint = uploadedResource.timepoint; } // Join on storage timepoints for our transitive dependencies to await. - return uploadResult.timepoint; + return currentTimepoint; } //===----------------------------------------------------------------------===// @@ -540,31 +431,20 @@ class PackConstantsPass : public PackConstantsBase { indexSet.populate(constantsOp.getResultSizes()); // Perform upload/processing for immutable and mutable constants. - SmallVector timepoints; - if (auto timepoint = - generateUpload(constantsOp, IREE::Stream::Lifetime::Constant, - resourceConfig, indexSet, builder)) { - timepoints.push_back(timepoint); - } - if (auto timepoint = - generateUpload(constantsOp, IREE::Stream::Lifetime::Variable, - resourceConfig, indexSet, builder)) { - timepoints.push_back(timepoint); + Value currentTimepoint = + builder.create( + constantsOp.getLoc()); + if (auto uploadTimepoint = generateUpload( + currentTimepoint, constantsOp, IREE::Stream::Lifetime::Constant, + resourceConfig, indexSet, builder)) { + currentTimepoint = uploadTimepoint; } - if (timepoints.empty()) - return; - - // Join on storage timepoints for our transitive dependencies to await. - // We could do this at a finer granularity if we were to split the - // constants op into multiple units earlier on. 
- Value joinTimepoint; - if (timepoints.size() > 1) { - joinTimepoint = builder.create( - constantsOp.getLoc(), timepoints.front().getType(), timepoints); - } else { - joinTimepoint = timepoints.front(); + if (auto uploadTimepoint = generateUpload( + currentTimepoint, constantsOp, IREE::Stream::Lifetime::Variable, + resourceConfig, indexSet, builder)) { + currentTimepoint = uploadTimepoint; } - constantsOp.getResultTimepoint().replaceAllUsesWith(joinTimepoint); + constantsOp.getResultTimepoint().replaceAllUsesWith(currentTimepoint); constantsOp.erase(); }); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir index ce9fcd2b04a6..0146b930bad5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/dump_statistics.mlir @@ -5,19 +5,19 @@ // CHECK-PRETTY: Constants: 1, estimated storage of 192 B // CHECK-PRETTY: Variables: 0, (TBD) // CHECK-PRETTY: D->H Syncs: 2 -// CHECK-PRETTY: Submissions: 3, using cumulative 0 B +// CHECK-PRETTY: Submissions: 2, using cumulative 0 B // CHECK-PRETTY: DMA Fills: 0 -// CHECK-PRETTY: DMA Copies: 2 +// CHECK-PRETTY: DMA Copies: 1 // CHECK-PRETTY: Collectives: 0 // CHECK-PRETTY: Dispatches: 3 // CHECK-PRETTY: Executables: 2, 33% reuse // CHECK-CSV: ; Aggregate Statistics // CHECK-CSV: "Constants","Constant Size","Variables","Variable Size","Awaits","Submissions","Transient Size","Fills","Copies","Dispatches","Async Calls","Executables" -// CHECK-CSV: 1,192,0,0,2,3,0,0,2,3,0,2 +// CHECK-CSV: 1,192,0,0,2,2,0,0,1,3,0,2 // CHECK-CSV: ; Execution // CHECK-CSV: "Depth","Command","Symbol","Length","Invocations","Workload","Operands","Resources" -// CHECK-CSV: 0,"copy",,192,,,, +// CHECK-CSV: 0,"copy",,16,,,, // CHECK-CSV: 0,"dispatch","@func_a_ex_0::@dispatch_0",,4,"4;1;1",0,3 util.global private mutable @_constant__timepoint = 
#stream.timepoint @@ -41,18 +41,8 @@ util.initializer { dense<0> : vector<20xi8>, ]> %did_map, %result = stream.resource.try_map %1[%c0] : !util.buffer -> i1, !stream.resource{%c192} - %2:2 = scf.if %did_map -> (!stream.resource, !stream.timepoint) { - scf.yield %result, %0 : !stream.resource, !stream.timepoint - } else { - %3 = stream.resource.map %1[%c0] : !util.buffer -> !stream.resource{%c192} - %4 = stream.resource.alloc uninitialized : !stream.resource{%c192} - %5 = stream.cmd.execute with(%3 as %arg0: !stream.resource{%c192}, %4 as %arg1: !stream.resource{%c192}) { - stream.cmd.copy %arg0[%c0], %arg1[%c0], %c192 : !stream.resource{%c192} -> !stream.resource{%c192} - } => !stream.timepoint - scf.yield %4, %5 : !stream.resource, !stream.timepoint - } - util.global.store %2#0, @_constant : !stream.resource - util.global.store %2#1, @_constant__timepoint : !stream.timepoint + util.global.store %result, @_constant : !stream.resource + util.global.store %0, @_constant__timepoint : !stream.timepoint util.initializer.return } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir index f03ada6f34c9..2b5adc304f78 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_constants.mlir @@ -19,6 +19,8 @@ func.func @resourceConstants() -> (!stream.resource, !stream.resource< %c8 = arith.constant 8 : index %c48 = arith.constant 48 : index + // CHECK-DAG: %[[IMMEDIATE:.+]] = stream.timepoint.immediate => !stream.timepoint + // Fetch the read-only host data containing the constants. 
// CHECK: %[[RODATA:.+]] = util.buffer.constant {alignment = 64 : index} : !util.buffer = #composite_of_192b %0:4 = stream.resource.constants : @@ -31,32 +33,59 @@ func.func @resourceConstants() -> (!stream.resource, !stream.resource< // succeeds we are done and can avoid allocation/complete immediately. // CHECK: %[[DID_MAP:.+]], %[[TRY_MAP:.+]] = stream.resource.try_map %[[RODATA]][%c0] : // CHECK-SAME: !util.buffer -> i1, !stream.resource{%c192} - // CHECK: %[[IF:.+]]:2 = scf.if %[[DID_MAP]] -> (!stream.resource, !stream.timepoint) { - // CHECK-NEXT: %[[IMMEDIATE:.+]] = stream.timepoint.immediate => !stream.timepoint - // CHECK-NEXT: scf.yield %[[TRY_MAP]], %[[IMMEDIATE]] + // CHECK: %[[IF:.+]]:2 = scf.if %[[DID_MAP]] -> (!stream.timepoint, !stream.resource) { + // CHECK-NEXT: scf.yield %[[IMMEDIATE]], %[[TRY_MAP]] // CHECK-NEXT: } else { // If the mapping fails we need to perform an upload via a staging buffer. - // CHECK: %[[STAGING:.+]] = stream.resource.map %[[RODATA]][%c0] : !util.buffer -> !stream.resource{%c192} // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%c192} - // CHECK: %[[EXEC_TIMEPOINT:.+]] = stream.cmd.execute - // CHECK-SAME: with(%[[STAGING]] as %[[STAGING_CAPTURE:.+]]: !stream.resource{%c192}, - // CHECK-SAME: %[[ALLOC]] as %[[ALLOC_CAPTURE:.+]]: !stream.resource{%c192}) { - // CHECK: stream.cmd.copy %[[STAGING_CAPTURE]][%c0], %[[ALLOC_CAPTURE]][%c0], %c192 : !stream.resource{%c192} -> !stream.resource{%c192} - // CHECK: } => !stream.timepoint - // CHECK: scf.yield %[[ALLOC]], %[[EXEC_TIMEPOINT]] + // CHECK: %[[FILE:.+]] = stream.file.constant %[[RODATA]][%c0 for %c192] : !util.buffer{%c192} -> !stream.file + // CHECK: %[[READ_TIMEPOINT:.+]] = stream.file.read await(%[[IMMEDIATE]]) => %[[FILE]][%c0_i64], %[[ALLOC]][%c0], %c192 : !stream.file -> !stream.resource{%c192} => !stream.timepoint + // CHECK: scf.yield %[[READ_TIMEPOINT]], %[[ALLOC]] // Get subviews pointing to the subresources within the packed 
resource. - // CHECK: %[[RES0:.+]] = stream.resource.subview %[[IF]]#0[%c0] : !stream.resource{%c192} -> !stream.resource{%c4} - // CHECK: %[[RES1:.+]] = stream.resource.subview %[[IF]]#0[%c64] : !stream.resource{%c192} -> !stream.resource{%c8} - // CHECK: %[[RES2:.+]] = stream.resource.subview %[[IF]]#0[%c128] : !stream.resource{%c192} -> !stream.resource{%c48} + // CHECK: %[[RES0:.+]] = stream.resource.subview %[[IF]]#1[%c0] : !stream.resource{%c192} -> !stream.resource{%c4} + // CHECK: %[[RES1:.+]] = stream.resource.subview %[[IF]]#1[%c64] : !stream.resource{%c192} -> !stream.resource{%c8} + // CHECK: %[[RES2:.+]] = stream.resource.subview %[[IF]]#1[%c128] : !stream.resource{%c192} -> !stream.resource{%c48} - // CHECK: return %[[RES0]], %[[RES1]], %[[RES2]], %[[IF]]#1 + // CHECK: return %[[RES0]], %[[RES1]], %[[RES2]], %[[IF]]#0 return %0#0, %0#1, %0#2, %0#3 : !stream.resource, !stream.resource, !stream.resource, !stream.timepoint } // ----- +// Tests variables which always need copies so that they can be mutated. 
+ +// CHECK: #composite_of_1088b = #util.composite<1088xi8, [ +// CHECK: dense<100> : tensor<256xi32>, +// CHECK: dense<[101, 102]> : tensor<2xi32>, +// CHECK: dense<0> : vector<56xi8>, +// CHECK: ]> + +// CHECK-LABEL: @resourceVariables +func.func @resourceVariables() -> (!stream.resource, !stream.resource, !stream.timepoint) { + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + + // CHECK-DAG: %[[IMMEDIATE:.+]] = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[RODATA:.+]] = util.buffer.constant {alignment = 64 : index} : !util.buffer = #composite_of_1088b + // CHECK: %[[ALLOC:.+]] = stream.resource.alloc uninitialized : !stream.resource{%c1088} + // CHECK: %[[FILE:.+]] = stream.file.constant %[[RODATA]][%c0 for %c1088] : !util.buffer{%c1088} -> !stream.file + // CHECK: %[[READ_TIMEPOINT:.+]] = stream.file.read await(%[[IMMEDIATE]]) => %[[FILE]][%c0_i64], %[[ALLOC]][%c0], %c1088 : !stream.file -> !stream.resource{%c1088} => !stream.timepoint + // CHECK: %[[RES0:.+]] = stream.resource.subview %[[ALLOC]][%c0] : !stream.resource{%c1088} -> !stream.resource{%c1024} + // CHECK: %[[RES1:.+]] = stream.resource.subview %[[ALLOC]][%c1024] : !stream.resource{%c1088} -> !stream.resource{%c8} + + %0:3 = stream.resource.constants : + !stream.resource{%c1024} = dense<100> : tensor<256xi32>, + !stream.resource{%c8} = dense<[101, 102]> : tensor<2xi32> + => !stream.timepoint + + // CHECK: return %[[RES0]], %[[RES1]], %[[READ_TIMEPOINT]] + return %0#0, %0#1, %0#2 : !stream.resource, !stream.resource, !stream.timepoint +} + +// ----- + // Tests that if we exceed the maximum allowed allocation size the constants get // partitioned into multiple buckets each within the required bounds. This test // produces the same logic as above but doubled. 
@@ -84,65 +113,27 @@ func.func @splitResourceConstants() -> (!stream.resource, !stream.reso %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index + // CHECK-DAG: %[[IMMEDIATE:.+]] = stream.timepoint.immediate => !stream.timepoint + // CHECK: %[[RODATA0:.+]] = util.buffer.constant {alignment = 16 : index} : !util.buffer = #composite_of_16b + // CHECK: %[[DID_MAP0:.+]], %[[TRY_MAP0:.+]] = stream.resource.try_map %[[RODATA0]] + // CHECK: %[[IF0:.+]]:2 = scf.if %[[DID_MAP0]] + // CHECK: %[[FILE0:.+]] = stream.file.constant %[[RODATA0]] + // CHECK: stream.file.read await(%[[IMMEDIATE]]) => %[[FILE0]] + // CHECK: %[[RES0:.+]] = stream.resource.subview %[[IF0]]#1[%c0] : !stream.resource{%c16} -> !stream.resource{%c4} + // CHECK: %[[RODATA1:.+]] = util.buffer.constant {alignment = 16 : index} : !util.buffer = #composite_of_16b1 + // CHECK: %[[DID_MAP1:.+]], %[[TRY_MAP1:.+]] = stream.resource.try_map %[[RODATA1]] + // CHECK: %[[IF1:.+]]:2 = scf.if %[[DID_MAP1]] + // CHECK: %[[FILE1:.+]] = stream.file.constant %[[RODATA1]] + // CHECK: stream.file.read await(%[[IF0]]#0) => %[[FILE1]] + // CHECK: %[[RES1:.+]] = stream.resource.subview %[[IF1]]#1[%c0] : !stream.resource{%c16} -> !stream.resource{%c8} + %0:3 = stream.resource.constants : !stream.resource{%c4} = dense<100> : tensor<1xi32>, !stream.resource{%c8} = dense<[101, 102]> : tensor<2xi32> => !stream.timepoint - // NOTE: we fall back for all even if only one fails; this is just for - // simplicity in the pass today but we could only fallback for the ones that - // failed if we wanted. 
- // CHECK: %[[DID_MAP0:.+]], %[[TRY_MAP0:.+]] = stream.resource.try_map %[[RODATA0]][%c0] : !util.buffer -> i1, !stream.resource{%c16} - // CHECK: %[[DID_MAP1:.+]], %[[TRY_MAP1:.+]] = stream.resource.try_map %[[RODATA1]][%c0] : !util.buffer -> i1, !stream.resource{%c16} - // CHECK: %[[BOTH_MAPPED:.+]] = arith.andi %[[DID_MAP0]], %[[DID_MAP1]] : i1 - // CHECK: %[[IF:.+]]:3 = scf.if %[[BOTH_MAPPED]] - // CHECK: scf.yield %[[TRY_MAP0]], %[[TRY_MAP1]] - // CHECK: } else { - - // CHECK: stream.resource.map %[[RODATA0]] - // CHECK: stream.resource.alloc - // CHECK: stream.resource.map %[[RODATA1]] - // CHECK: stream.resource.alloc - // CHECK: stream.cmd.execute - // CHECK-NEXT: stream.cmd.copy - // CHECK-NEXT: stream.cmd.copy - - // CHECK: %[[RES0:.+]] = stream.resource.subview %[[IF]]#0[%c0] : !stream.resource{%c16} -> !stream.resource{%c4} - // CHECK: %[[RES1:.+]] = stream.resource.subview %[[IF]]#1[%c0] : !stream.resource{%c16} -> !stream.resource{%c8} - - // CHECK: return %[[RES0]], %[[RES1]], %[[IF]]#2 + // CHECK: return %[[RES0]], %[[RES1]], %[[IF1]]#0 return %0#0, %0#1, %0#2 : !stream.resource, !stream.resource, !stream.timepoint } - -// ----- - -// Tests that resources with varying lifetimes get split and processed -// independently. This allows for fast-path constants while allowing variable -// initializers to go the normal staging route. We expect to end up with two -// constant storage buffers, two uploads, and a join for the final timepoint. 
- -// CHECK-LABEL: @mixedResourceConstants -func.func @mixedResourceConstants() -> (!stream.resource, !stream.resource, !stream.timepoint) { - %c8 = arith.constant 8 : index - %c1024 = arith.constant 1024 : index - - // CHECK: %[[CONSTANT_HOST:.+]] = util.buffer.constant {{.+}} = #composite_of_1024b - // CHECK: %[[CONSTANT_IF:.+]]:2 = scf.if {{.+}} -> (!stream.resource, !stream.timepoint) - // CHECK: %[[CONSTANT_VIEW:.+]] = stream.resource.subview %[[CONSTANT_IF]]#0 - - // CHECK: %[[VARIABLE_HOST:.+]] = util.buffer.constant {{.+}} = #composite_of_64b - // CHECK: %[[VARIABLE_BUFFER:.+]] = stream.resource.alloc {{.+}} : !stream.resource{%c64} - // CHECK: %[[VARIABLE_EXEC:.+]] = stream.cmd.execute {{.+}} %[[VARIABLE_BUFFER]] - // CHECK: %[[VARIABLE_VIEW:.+]] = stream.resource.subview %[[VARIABLE_BUFFER]] - - // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[CONSTANT_IF]]#1, %[[VARIABLE_EXEC]]) - %0:3 = stream.resource.constants : - !stream.resource{%c1024} = dense<100> : tensor<256xi32>, - !stream.resource{%c8} = dense<[101, 102]> : tensor<2xi32> - => !stream.timepoint - - // CHECK: return %[[CONSTANT_VIEW]], %[[VARIABLE_VIEW]], %[[JOIN]] - return %0#0, %0#1, %0#2 : !stream.resource, !stream.resource, !stream.timepoint -} diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td index be9bb48c9c6d..44b9d2af29df 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td @@ -1066,7 +1066,7 @@ def Util_SubrangeOperandOpInterface : OpInterface<"SubrangeOperandOpInterface"> /*desc=*/[{ Returns the subrange operand values for the given flat operand index. 
}], - /*retTy=*/"SubrangeOperand", + /*retTy=*/"IREE::Util::SubrangeOperand", /*methodName=*/"getSubrangeOperand", /*args=*/(ins "unsigned":$operandIndex) >, @@ -1076,7 +1076,10 @@ def Util_SubrangeOperandOpInterface : OpInterface<"SubrangeOperandOpInterface"> }], /*retTy=*/"void", /*methodName=*/"setSubrangeOperand", - /*args=*/(ins "unsigned":$operandIndex, "SubrangeOperand":$operand) + /*args=*/(ins + "unsigned":$operandIndex, + "IREE::Util::SubrangeOperand":$operand + ) >, ]; } diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp index 24226886818c..a90f9816a27c 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp @@ -137,23 +137,6 @@ struct ResourceSizeOpPattern } }; -// The staging buffer returned from this is always a !util.buffer. -// We can thus directly pass along the input buffer that's being mapped -// (after taking a subspan for the defined range). -struct ResourceMapOpPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(IREE::Stream::ResourceMapOp mapOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - mapOp, adaptor.getSource(), - getResourceSize(mapOp.getLoc(), adaptor.getSource(), rewriter), - adaptor.getSourceOffset(), adaptor.getResultSize()); - return success(); - } -}; - // The constant buffer returned from this is always a !util.buffer. // We can thus directly pass along the input buffer that's being mapped // (after taking a subspan for the defined range). 
@@ -703,9 +686,9 @@ void populateStreamToHALInlinePatterns(MLIRContext *context, patterns.insert(typeConverter, context); + ResourceTryMapOpPattern, ResourceLoadOpPattern, + ResourceStoreOpPattern, ResourceSubviewOpPattern>( + typeConverter, context); patterns.insert) -> index { // ----- -// CHECK-LABEL: @resourceMap -// CHECK-SAME: (%[[SOURCE:.+]]: !util.buffer) -func.func @resourceMap(%source: !util.buffer) -> !stream.resource { - // CHECK-DAG: %[[OFFSET:.+]] = arith.constant 100 - %offset = arith.constant 100 : index - // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 128 - %length = arith.constant 128 : index - // CHECK: %[[SOURCE_SIZE:.+]] = util.buffer.size %[[SOURCE]] : !util.buffer - // CHECK: %[[MAPPING:.+]] = util.buffer.subspan %[[SOURCE]][%[[OFFSET]]] : !util.buffer{%[[SOURCE_SIZE]]} -> !util.buffer{%[[LENGTH]]} - %mapping = stream.resource.map %source[%offset] : !util.buffer -> !stream.resource{%length} - // CHECK: return %[[MAPPING]] - return %mapping : !stream.resource -} - -// ----- - // CHECK-LABEL: @resourceTryMap // CHECK-SAME: (%[[SOURCE:.+]]: !util.buffer) func.func @resourceTryMap(%source: !util.buffer) -> (i1, !stream.resource) { diff --git a/experimental/cuda2/CMakeLists.txt b/experimental/cuda2/CMakeLists.txt index 638c4f34b1eb..b1e8a9c7c001 100644 --- a/experimental/cuda2/CMakeLists.txt +++ b/experimental/cuda2/CMakeLists.txt @@ -56,6 +56,8 @@ iree_cc_library( iree::hal iree::hal::utils::buffer_transfer iree::hal::utils::collective_batch + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::schemas::cuda_executable_def_c_fbs diff --git a/experimental/cuda2/cuda_allocator.c b/experimental/cuda2/cuda_allocator.c index e119529525ec..76ea1460d2e5 100644 --- a/experimental/cuda2/cuda_allocator.c +++ b/experimental/cuda2/cuda_allocator.c @@ -41,6 +41,9 @@ typedef struct iree_hal_cuda2_allocator_t { // between GPU and CPU if not. 
bool supports_concurrent_managed_access; + // Whether host memory can be registered with CU_MEMHOSTREGISTER_READ_ONLY. + bool supports_read_only_host_register; + IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;) } iree_hal_cuda2_allocator_t; @@ -77,13 +80,27 @@ iree_status_t iree_hal_cuda2_allocator_create( &supports_concurrent_managed_access, CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, device), "cuDeviceGetAttribute")); - IREE_TRACE_ZONE_APPEND_TEXT( z0, supports_concurrent_managed_access ? "has CONCURRENT_MANAGED_ACCESS" : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on " "device-local + host-visible memory)"); + // We can only provide the CU_MEMHOSTREGISTER_READ_ONLY flag when importing + // host memory if it's supported. + int supports_read_only_host_register = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + CU_RESULT_TO_STATUS( + context->syms, + cuDeviceGetAttribute( + &supports_read_only_host_register, + CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED, device), + "cuDeviceGetAttribute")); + IREE_TRACE_ZONE_APPEND_TEXT(z0, supports_read_only_host_register + ? 
"has READ_ONLY_HOST_REGISTER_SUPPORTED" + : "no READ_ONLY_HOST_REGISTER_SUPPORTED"); + iree_hal_cuda2_allocator_t* allocator = NULL; iree_status_t status = iree_allocator_malloc( host_allocator, sizeof(*allocator), (void**)&allocator); @@ -98,6 +115,8 @@ iree_status_t iree_hal_cuda2_allocator_create( allocator->host_allocator = host_allocator; allocator->supports_concurrent_managed_access = supports_concurrent_managed_access != 0; + allocator->supports_read_only_host_register = + supports_read_only_host_register != 0; *out_allocator = (iree_hal_allocator_t*)allocator; } @@ -434,7 +453,7 @@ static iree_status_t iree_hal_cuda2_allocator_allocate_buffer( &allocator->statistics, compat_params.type, allocation_size)); *out_buffer = buffer; } else { - if (!buffer) { + if (!buffer && (device_ptr || host_ptr)) { iree_hal_cuda2_buffer_free(allocator->symbols, buffer_type, device_ptr, host_ptr); } else { @@ -549,16 +568,10 @@ static iree_status_t iree_hal_cuda2_allocator_import_buffer( } buffer_type = IREE_HAL_CUDA_BUFFER_TYPE_HOST_REGISTERED; host_ptr = external_buffer->handle.host_allocation.ptr; - uint32_t register_flags = 0; - if (compat_params.access == IREE_HAL_MEMORY_ACCESS_READ) { - register_flags = CU_MEMHOSTREGISTER_READ_ONLY; - } - if (iree_any_bit_set(compat_params.usage, - IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS | - IREE_HAL_BUFFER_USAGE_DISPATCH_UNIFORM_READ | - IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | - IREE_HAL_BUFFER_USAGE_DISPATCH_IMAGE)) { - register_flags = CU_MEMHOSTREGISTER_DEVICEMAP; + uint32_t register_flags = CU_MEMHOSTREGISTER_DEVICEMAP; + if (compat_params.access == IREE_HAL_MEMORY_ACCESS_READ && + allocator->supports_read_only_host_register) { + register_flags |= CU_MEMHOSTREGISTER_READ_ONLY; } status = IREE_CURESULT_TO_STATUS( allocator->symbols, @@ -599,7 +612,7 @@ static iree_status_t iree_hal_cuda2_allocator_import_buffer( if (iree_status_is_ok(status)) { *out_buffer = buffer; } else { - if (!buffer) { + if (!buffer && (device_ptr || 
host_ptr)) { iree_hal_cuda2_buffer_free(allocator->symbols, buffer_type, device_ptr, host_ptr); } else { diff --git a/experimental/cuda2/cuda_buffer.c b/experimental/cuda2/cuda_buffer.c index da9a9573a9f0..ff5a254e5d25 100644 --- a/experimental/cuda2/cuda_buffer.c +++ b/experimental/cuda2/cuda_buffer.c @@ -89,9 +89,11 @@ static iree_status_t iree_hal_cuda2_buffer_map_range( IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type( iree_hal_buffer_memory_type(base_buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR( - iree_hal_buffer_validate_usage(iree_hal_buffer_allowed_usage(base_buffer), - IREE_HAL_BUFFER_USAGE_MAPPING)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(base_buffer), + mapping_mode == IREE_HAL_MAPPING_MODE_PERSISTENT + ? IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT + : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset; // If we mapped for discard scribble over the bytes. 
This is not a mandated diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c index c6bc9890a35d..1263f57d3f52 100644 --- a/experimental/cuda2/cuda_device.c +++ b/experimental/cuda2/cuda_device.c @@ -30,6 +30,8 @@ #include "iree/base/internal/math.h" #include "iree/hal/utils/buffer_transfer.h" #include "iree/hal/utils/deferred_command_buffer.h" +#include "iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" //===----------------------------------------------------------------------===// // iree_hal_cuda2_device_t @@ -546,6 +548,23 @@ static iree_status_t iree_hal_cuda2_device_create_executable_cache( device->host_allocator, out_executable_cache); } +static iree_status_t iree_hal_cuda2_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, external_file->handle.host_allocation, + release_callback, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + static iree_status_t iree_hal_cuda2_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -597,7 +616,8 @@ static iree_status_t iree_hal_cuda2_device_queue_alloca( // If pools are not supported we allocate a buffer as normal from whatever // allocator is set on the device. 
iree_status_t status = iree_ok_status(); - if (device->supports_memory_pools) { + if (device->supports_memory_pools && + !iree_any_bit_set(params.access, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { status = iree_hal_cuda2_memory_pools_alloca( &device->memory_pools, device->dispatch_cu_stream, pool, params, allocation_size, out_buffer); @@ -650,6 +670,48 @@ static iree_status_t iree_hal_cuda2_device_queue_dealloca( return status; } +static iree_status_t iree_hal_cuda2_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_cuda2_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. 
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_cuda2_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -721,6 +783,7 @@ static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = { iree_hal_cuda2_device_create_descriptor_set_layout, .create_event = iree_hal_cuda2_device_create_event, .create_executable_cache = iree_hal_cuda2_device_create_executable_cache, + .import_file = iree_hal_cuda2_device_import_file, .create_pipeline_layout = iree_hal_cuda2_device_create_pipeline_layout, .create_semaphore = iree_hal_cuda2_device_create_semaphore, .query_semaphore_compatibility = @@ -728,6 +791,8 @@ static const iree_hal_device_vtable_t iree_hal_cuda2_device_vtable = { .transfer_range = iree_hal_device_submit_transfer_range_and_wait, .queue_alloca = iree_hal_cuda2_device_queue_alloca, .queue_dealloca = iree_hal_cuda2_device_queue_dealloca, + .queue_read = iree_hal_cuda2_device_queue_read, + .queue_write = iree_hal_cuda2_device_queue_write, .queue_execute = iree_hal_cuda2_device_queue_execute, .queue_flush = iree_hal_cuda2_device_queue_flush, .wait_semaphores = iree_hal_cuda2_device_wait_semaphores, diff --git a/experimental/cuda2/event_semaphore.c b/experimental/cuda2/event_semaphore.c index 00ef21523ea3..47efd65e5817 100644 --- a/experimental/cuda2/event_semaphore.c +++ b/experimental/cuda2/event_semaphore.c @@ -13,9 +13,6 @@ #include "iree/base/internal/synchronization.h" 
#include "iree/hal/utils/semaphore_base.h" -// Sentinel to indicate the semaphore has failed and an error status is set. -#define IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE UINT64_MAX - typedef struct iree_hal_cuda2_semaphore_t { // Abstract resource used for injecting reference counting and vtable; // must be at offset 0. @@ -38,7 +35,7 @@ typedef struct iree_hal_cuda2_semaphore_t { // than trying to make the entire structure lock-free. iree_slim_mutex_t mutex; - // Current signaled value. May be IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE to + // Current signaled value. May be IREE_HAL_SEMAPHORE_FAILURE_VALUE to // indicate that the semaphore has been signaled for failure and // |failure_status| contains the error. uint64_t current_value IREE_GUARDED_BY(mutex); @@ -114,7 +111,7 @@ static iree_status_t iree_hal_cuda2_semaphore_query( *out_value = semaphore->current_value; iree_status_t status = iree_ok_status(); - if (*out_value >= IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE) { + if (*out_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { status = iree_status_clone(semaphore->failure_status); } @@ -180,14 +177,14 @@ static void iree_hal_cuda2_semaphore_fail(iree_hal_semaphore_t* base_semaphore, } // Signal to our failure sentinel value. - semaphore->current_value = IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE; + semaphore->current_value = IREE_HAL_SEMAPHORE_FAILURE_VALUE; semaphore->failure_status = status; iree_slim_mutex_unlock(&semaphore->mutex); // Notify timepoints - note that this must happen outside the lock. 
- iree_hal_semaphore_notify(&semaphore->base, - IREE_HAL_CUDA_SEMAPHORE_FAILURE_VALUE, status_code); + iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE, + status_code); IREE_TRACE_ZONE_END(z0); } diff --git a/experimental/rocm/CMakeLists.txt b/experimental/rocm/CMakeLists.txt index 6c501b1c0fa0..d157437ec4fc 100644 --- a/experimental/rocm/CMakeLists.txt +++ b/experimental/rocm/CMakeLists.txt @@ -66,6 +66,8 @@ iree_cc_library( iree::base::internal::synchronization iree::hal iree::hal::utils::buffer_transfer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::hal::utils::semaphore_base iree::schemas::rocm_executable_def_c_fbs COPTS diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c index e6a915cd6911..3026e3d10f9c 100644 --- a/experimental/rocm/rocm_device.c +++ b/experimental/rocm/rocm_device.c @@ -21,6 +21,8 @@ #include "experimental/rocm/status_util.h" #include "iree/base/internal/arena.h" #include "iree/hal/utils/buffer_transfer.h" +#include "iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" //===----------------------------------------------------------------------===// // iree_hal_rocm_device_t @@ -246,6 +248,23 @@ static iree_status_t iree_hal_rocm_device_create_executable_cache( &device->context_wrapper, identifier, out_executable_cache); } +static iree_status_t iree_hal_rocm_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, external_file->handle.host_allocation, + release_callback, 
iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + static iree_status_t iree_hal_rocm_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -300,6 +319,48 @@ static iree_status_t iree_hal_rocm_device_queue_dealloca( return iree_ok_status(); } +static iree_status_t iree_hal_rocm_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_rocm_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. 
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_rocm_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -358,6 +419,7 @@ static const iree_hal_device_vtable_t iree_hal_rocm_device_vtable = { iree_hal_rocm_device_create_descriptor_set_layout, .create_event = iree_hal_rocm_device_create_event, .create_executable_cache = iree_hal_rocm_device_create_executable_cache, + .import_file = iree_hal_rocm_device_import_file, .create_pipeline_layout = iree_hal_rocm_device_create_pipeline_layout, .create_semaphore = iree_hal_rocm_device_create_semaphore, .query_semaphore_compatibility = @@ -365,6 +427,8 @@ static const iree_hal_device_vtable_t iree_hal_rocm_device_vtable = { .transfer_range = iree_hal_device_submit_transfer_range_and_wait, .queue_alloca = iree_hal_rocm_device_queue_alloca, .queue_dealloca = iree_hal_rocm_device_queue_dealloca, + .queue_read = iree_hal_rocm_device_queue_read, + .queue_write = iree_hal_rocm_device_queue_write, .queue_execute = iree_hal_rocm_device_queue_execute, .queue_flush = iree_hal_rocm_device_queue_flush, .wait_semaphores = iree_hal_rocm_device_wait_semaphores, diff --git a/experimental/webgpu/BUILD.bazel b/experimental/webgpu/BUILD.bazel index a64087d15e6a..4e802e6c1de9 100644 --- a/experimental/webgpu/BUILD.bazel +++ b/experimental/webgpu/BUILD.bazel @@ -53,6 +53,8 @@ iree_runtime_cc_library( "//runtime/src/iree/hal/drivers/webgpu/platform", 
"//runtime/src/iree/hal/drivers/webgpu/shaders", "//runtime/src/iree/hal/utils:buffer_transfer", + "//runtime/src/iree/hal/utils:file_transfer", + "//runtime/src/iree/hal/utils:memory_file", "//runtime/src/iree/schemas:wgsl_executable_def_c_fbs", "@webgpu_headers", ], diff --git a/experimental/webgpu/CMakeLists.txt b/experimental/webgpu/CMakeLists.txt index 0c32ff046e0a..2466974c184f 100644 --- a/experimental/webgpu/CMakeLists.txt +++ b/experimental/webgpu/CMakeLists.txt @@ -49,6 +49,8 @@ iree_cc_library( iree::experimental::webgpu::platform iree::experimental::webgpu::shaders iree::hal::utils::buffer_transfer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::schemas::wgsl_executable_def_c_fbs PUBLIC ) diff --git a/experimental/webgpu/webgpu_device.c b/experimental/webgpu/webgpu_device.c index 84d0d499ca83..8e75282625ad 100644 --- a/experimental/webgpu/webgpu_device.c +++ b/experimental/webgpu/webgpu_device.c @@ -21,6 +21,8 @@ #include "experimental/webgpu/staging_buffer.h" #include "iree/base/internal/arena.h" #include "iree/hal/utils/buffer_transfer.h" +#include "iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" //===----------------------------------------------------------------------===// // iree_hal_webgpu_device_t @@ -270,6 +272,23 @@ static iree_status_t iree_hal_webgpu_device_create_executable_cache( out_executable_cache); } +static iree_status_t iree_hal_webgpu_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, 
external_file->handle.host_allocation, + release_callback, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + static iree_status_t iree_hal_webgpu_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -324,6 +343,50 @@ static iree_status_t iree_hal_webgpu_device_queue_dealloca( return iree_ok_status(); } +static iree_status_t iree_hal_webgpu_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + // TODO: stash a loop on the device to allow for async streaming. + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_webgpu_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. 
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + // TODO: stash a loop on the device to allow for async streaming. + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_webgpu_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -391,6 +454,7 @@ const iree_hal_device_vtable_t iree_hal_webgpu_device_vtable = { iree_hal_webgpu_device_create_descriptor_set_layout, .create_event = iree_hal_webgpu_device_create_event, .create_executable_cache = iree_hal_webgpu_device_create_executable_cache, + .import_file = iree_hal_webgpu_device_import_file, .create_pipeline_layout = iree_hal_webgpu_device_create_pipeline_layout, .create_semaphore = iree_hal_webgpu_device_create_semaphore, .query_semaphore_compatibility = @@ -398,6 +462,8 @@ const iree_hal_device_vtable_t iree_hal_webgpu_device_vtable = { .transfer_range = iree_hal_device_submit_transfer_range_and_wait, .queue_alloca = iree_hal_webgpu_device_queue_alloca, .queue_dealloca = iree_hal_webgpu_device_queue_dealloca, + .queue_read = iree_hal_webgpu_device_queue_read, + .queue_write = iree_hal_webgpu_device_queue_write, .queue_execute = iree_hal_webgpu_device_queue_execute, .queue_flush = iree_hal_webgpu_device_queue_flush, .wait_semaphores = iree_hal_webgpu_device_wait_semaphores, diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py index 8067ae05cda3..3b423f176970 100644 --- a/runtime/bindings/python/tests/hal_test.py +++ 
b/runtime/bindings/python/tests/hal_test.py @@ -137,9 +137,10 @@ def testBufferViewConstructor(self): bv = iree.runtime.HalBufferView( buffer, (1, 2), iree.runtime.HalElementType.INT_16 ) + # NOTE: the exact bits set on type/usage/etc is implementation defined. self.assertEqual( repr(bv), - "", + "", ) def testBufferMap(self): @@ -158,9 +159,10 @@ def testAllocateBufferCopy(self): allowed_usage=iree.runtime.BufferUsage.DEFAULT, buffer=ary, ) + # NOTE: the exact bits set on type/usage/etc is implementation defined. self.assertEqual( repr(buffer), - "", + "", ) def testAllocateBufferViewCopy(self): @@ -171,16 +173,18 @@ def testAllocateBufferViewCopy(self): buffer=ary, element_type=iree.runtime.HalElementType.SINT_32, ) + # NOTE: the exact bits set on type/usage/etc is implementation defined. self.assertEqual( repr(buffer), - "", + "", ) def testAllocateHostStagingBufferCopy(self): buffer = self.allocator.allocate_host_staging_buffer_copy(np.int32(0)) + # NOTE: the exact bits set on type/usage/etc is implementation defined. self.assertEqual( repr(buffer), - "", + "", ) def testSemaphore(self): diff --git a/runtime/src/iree/base/loop.h b/runtime/src/iree/base/loop.h index 4036a47ccc64..fa897e9cc1d1 100644 --- a/runtime/src/iree/base/loop.h +++ b/runtime/src/iree/base/loop.h @@ -294,7 +294,7 @@ typedef struct iree_loop_dispatch_params_t { typedef struct iree_loop_wait_until_params_t { // Callback issued after the deadline has passed. iree_loop_callback_t callback; - // Minimum time to wait before issueing the callback. + // Minimum time to wait before issuing the callback. 
iree_time_t deadline_ns; } iree_loop_wait_until_params_t; diff --git a/runtime/src/iree/hal/BUILD.bazel b/runtime/src/iree/hal/BUILD.bazel index 5ab6fd832276..d0b5bec51a26 100644 --- a/runtime/src/iree/hal/BUILD.bazel +++ b/runtime/src/iree/hal/BUILD.bazel @@ -57,6 +57,8 @@ iree_runtime_cc_library( "executable_cache.h", "fence.c", "fence.h", + "file.c", + "file.h", "pipeline_layout.c", "pipeline_layout.h", "resource.h", diff --git a/runtime/src/iree/hal/CMakeLists.txt b/runtime/src/iree/hal/CMakeLists.txt index c0c668876238..0e3b7928d5b5 100644 --- a/runtime/src/iree/hal/CMakeLists.txt +++ b/runtime/src/iree/hal/CMakeLists.txt @@ -50,6 +50,8 @@ iree_cc_library( "executable_cache.h" "fence.c" "fence.h" + "file.c" + "file.h" "pipeline_layout.c" "pipeline_layout.h" "resource.h" diff --git a/runtime/src/iree/hal/allocator_heap.c b/runtime/src/iree/hal/allocator_heap.c index c8cb80d5b015..a8bef41e1636 100644 --- a/runtime/src/iree/hal/allocator_heap.c +++ b/runtime/src/iree/hal/allocator_heap.c @@ -167,8 +167,10 @@ iree_hal_heap_allocator_query_buffer_compatibility( // Host currently uses mapping to copy buffers, which is done a lot. // We could probably remove this mutation by preventing copies in those cases. // TODO(benvanik): check if transfer is still required for DMA copy source. 
- params->usage |= - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER; + params->usage |= IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED | + IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT | + IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM | + IREE_HAL_BUFFER_USAGE_TRANSFER; return compatibility; } diff --git a/runtime/src/iree/hal/api.h b/runtime/src/iree/hal/api.h index cfa5b4cbfdce..a377a78a87f1 100644 --- a/runtime/src/iree/hal/api.h +++ b/runtime/src/iree/hal/api.h @@ -23,6 +23,7 @@ #include "iree/hal/executable.h" // IWYU pragma: export #include "iree/hal/executable_cache.h" // IWYU pragma: export #include "iree/hal/fence.h" // IWYU pragma: export +#include "iree/hal/file.h" // IWYU pragma: export #include "iree/hal/pipeline_layout.h" // IWYU pragma: export #include "iree/hal/resource.h" // IWYU pragma: export #include "iree/hal/semaphore.h" // IWYU pragma: export diff --git a/runtime/src/iree/hal/buffer.c b/runtime/src/iree/hal/buffer.c index 7a682e73af64..3ec2049c910c 100644 --- a/runtime/src/iree/hal/buffer.c +++ b/runtime/src/iree/hal/buffer.c @@ -167,7 +167,7 @@ IREE_API_EXPORT iree_status_t iree_hal_subspan_buffer_create( } IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } static void iree_hal_subspan_buffer_destroy(iree_hal_buffer_t* base_buffer) { @@ -833,9 +833,11 @@ IREE_API_EXPORT iree_status_t iree_hal_buffer_map_range( iree_hal_buffer_memory_type(buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, - iree_hal_buffer_validate_usage(iree_hal_buffer_allowed_usage(buffer), - IREE_HAL_BUFFER_USAGE_MAPPING)); + z0, iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(buffer), + mapping_mode == IREE_HAL_MAPPING_MODE_PERSISTENT + ? 
IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT + : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); } iree_device_size_t local_byte_offset = 0; diff --git a/runtime/src/iree/hal/cts/CMakeLists.txt b/runtime/src/iree/hal/cts/CMakeLists.txt index a503ed8a3a9c..0033920db8e1 100644 --- a/runtime/src/iree/hal/cts/CMakeLists.txt +++ b/runtime/src/iree/hal/cts/CMakeLists.txt @@ -14,6 +14,7 @@ set(IREE_ALL_CTS_TESTS "driver" "event" "executable_cache" + "file" "pipeline_layout" "semaphore" "semaphore_submission" @@ -159,6 +160,18 @@ iree_cc_library( iree::testing::gtest ) +iree_cc_library( + NAME + file_test_library + HDRS + "file_test.h" + DEPS + ::cts_test_base + iree::base + iree::hal + iree::testing::gtest +) + iree_cc_library( NAME pipeline_layout_test_library diff --git a/runtime/src/iree/hal/cts/file_test.h b/runtime/src/iree/hal/cts/file_test.h new file mode 100644 index 000000000000..f98053d7d3ab --- /dev/null +++ b/runtime/src/iree/hal/cts/file_test.h @@ -0,0 +1,139 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_CTS_FILE_TEST_H_ +#define IREE_HAL_CTS_FILE_TEST_H_ + +#include <cstdint> +#include <vector> + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/cts/cts_test_base.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace iree { +namespace hal { +namespace cts { + +using ::testing::ContainerEq; + +namespace { +constexpr iree_device_size_t kMinimumAlignment = 128; +} // namespace + +class file_test : public CtsTestBase { + protected: + void CreatePatternedDeviceBuffer(iree_device_size_t buffer_size, + uint8_t pattern, + iree_hal_buffer_t** out_buffer) { + iree_hal_buffer_params_t params = {0}; + params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; + params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER; + params.min_alignment = kMinimumAlignment; + iree_hal_buffer_t* device_buffer = NULL; + IREE_CHECK_OK(iree_hal_allocator_allocate_buffer( + iree_hal_device_allocator(device_), params, buffer_size, + iree_const_byte_span_empty(), &device_buffer)); + + iree_hal_transfer_command_t transfer_command; + memset(&transfer_command, 0, sizeof(transfer_command)); + transfer_command.type = IREE_HAL_TRANSFER_COMMAND_TYPE_FILL; + transfer_command.fill.target_buffer = device_buffer; + transfer_command.fill.target_offset = 0; + transfer_command.fill.length = buffer_size; + transfer_command.fill.pattern = &pattern; + transfer_command.fill.pattern_length = sizeof(pattern); + iree_hal_command_buffer_t* command_buffer = NULL; + IREE_CHECK_OK(iree_hal_create_transfer_command_buffer( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_QUEUE_AFFINITY_ANY, 1, &transfer_command, &command_buffer)); + IREE_CHECK_OK(SubmitCommandBufferAndWait(command_buffer)); + iree_hal_command_buffer_release(command_buffer); + + *out_buffer = device_buffer; + } + + void CreatePatternedMemoryFile(iree_hal_memory_access_t access, + iree_device_size_t file_size, uint8_t pattern, + 
iree_hal_file_t** out_file) { + void* file_contents = NULL; + IREE_CHECK_OK(iree_allocator_malloc_aligned(iree_allocator_system(), + file_size, kMinimumAlignment, 0, + (void**)&file_contents)); + memset(file_contents, pattern, file_size); + + iree_hal_external_file_t external_file; + memset(&external_file, 0, sizeof(external_file)); + external_file.type = IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION; + external_file.flags = 0; + external_file.handle.host_allocation = + iree_make_byte_span(file_contents, file_size); + iree_hal_file_release_callback_t release_callback = { + +[](void* user_data) { + iree_allocator_free_aligned(iree_allocator_system(), user_data); + }, + file_contents}; + IREE_CHECK_OK(iree_hal_file_import(device_, IREE_HAL_QUEUE_AFFINITY_ANY, + access, &external_file, release_callback, + out_file)); + } +}; + +// Reads the entire file into a buffer and check the contents match. +TEST_P(file_test, ReadEntireFile) { + iree_device_size_t file_size = 128; + iree_hal_file_t* file = NULL; + CreatePatternedMemoryFile(IREE_HAL_MEMORY_ACCESS_READ, file_size, 0xDEu, + &file); + iree_hal_buffer_t* buffer = NULL; + CreatePatternedDeviceBuffer(file_size, 0xCD, &buffer); + + iree_hal_semaphore_t* semaphore = NULL; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &semaphore)); + iree_hal_fence_t* wait_fence = NULL; + IREE_ASSERT_OK(iree_hal_fence_create_at( + semaphore, 1ull, iree_allocator_system(), &wait_fence)); + iree_hal_fence_t* signal_fence = NULL; + IREE_ASSERT_OK(iree_hal_fence_create_at( + semaphore, 2ull, iree_allocator_system(), &signal_fence)); + + // NOTE: synchronously executing here so start with the wait signaled. + // We should be able to make this async in the future. 
+ IREE_ASSERT_OK(iree_hal_fence_signal(wait_fence)); + + IREE_ASSERT_OK(iree_hal_device_queue_read( + device_, IREE_HAL_QUEUE_AFFINITY_ANY, + iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), /*source_file=*/file, + /*source_offset=*/0, /*target_buffer=*/buffer, /*target_offset=*/0, + /*length=*/file_size, /*flags=*/0)); + + IREE_ASSERT_OK(iree_hal_fence_wait(signal_fence, iree_infinite_timeout())); + iree_hal_fence_release(wait_fence); + iree_hal_fence_release(signal_fence); + iree_hal_semaphore_release(semaphore); + + std::vector<uint8_t> reference_buffer(file_size); + memset(reference_buffer.data(), 0xDEu, file_size); + std::vector<uint8_t> actual_data(file_size); + IREE_ASSERT_OK(iree_hal_device_transfer_d2h( + device_, buffer, /*source_offset=*/0, + /*target_buffer=*/actual_data.data(), + /*data_length=*/file_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout())); + EXPECT_THAT(actual_data, ContainerEq(reference_buffer)); + + iree_hal_buffer_release(buffer); + iree_hal_file_release(file); +} + +} // namespace cts +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_CTS_FILE_TEST_H_ diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c index e4d46f40d09c..efd2252e208e 100644 --- a/runtime/src/iree/hal/device.c +++ b/runtime/src/iree/hal/device.c @@ -208,6 +208,103 @@ IREE_API_EXPORT iree_status_t iree_hal_device_queue_dealloca( return status; } +IREE_API_EXPORT iree_status_t iree_hal_device_queue_copy( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + 
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)length); + + // If we are starting execution immediately then we can reduce latency by + // allowing inline command buffer execution. + iree_hal_command_buffer_mode_t command_buffer_mode = + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT; + if (wait_semaphore_list.count == 0) { + command_buffer_mode |= IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION; + } + + iree_hal_transfer_command_t command = { + .type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY, + .copy = + { + .source_buffer = source_buffer, + .source_offset = source_offset, + .target_buffer = target_buffer, + .target_offset = target_offset, + .length = length, + }, + }; + + iree_hal_command_buffer_t* command_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_create_transfer_command_buffer(device, command_buffer_mode, + queue_affinity, 1, &command, + &command_buffer)); + + iree_status_t status = + iree_hal_device_queue_execute(device, queue_affinity, wait_semaphore_list, + signal_semaphore_list, 1, &command_buffer); + + iree_hal_command_buffer_release(command_buffer); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_device_queue_read( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT( + !wait_semaphore_list.count || + (wait_semaphore_list.semaphores && wait_semaphore_list.payload_values)); + IREE_ASSERT_ARGUMENT(!signal_semaphore_list.count || + (signal_semaphore_list.semaphores && + signal_semaphore_list.payload_values)); + IREE_ASSERT_ARGUMENT(source_file); + IREE_ASSERT_ARGUMENT(target_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = 
_VTABLE_DISPATCH(device, queue_read)( + device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_device_queue_write( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT( + !wait_semaphore_list.count || + (wait_semaphore_list.semaphores && wait_semaphore_list.payload_values)); + IREE_ASSERT_ARGUMENT(!signal_semaphore_list.count || + (signal_semaphore_list.semaphores && + signal_semaphore_list.payload_values)); + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_file); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(device, queue_write)( + device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags); + IREE_TRACE_ZONE_END(z0); + return status; +} + IREE_API_EXPORT iree_status_t iree_hal_device_queue_execute( iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, diff --git a/runtime/src/iree/hal/device.h b/runtime/src/iree/hal/device.h index f7ee2d88d7d6..0a6f79087b1b 100644 --- a/runtime/src/iree/hal/device.h +++ b/runtime/src/iree/hal/device.h @@ -19,6 +19,7 @@ #include "iree/hal/event.h" #include "iree/hal/executable_cache.h" #include "iree/hal/fence.h" +#include "iree/hal/file.h" #include "iree/hal/pipeline_layout.h" #include "iree/hal/resource.h" #include "iree/hal/semaphore.h" @@ -404,6 +405,47 @@ IREE_API_EXPORT iree_status_t 
iree_hal_device_queue_dealloca( const iree_hal_semaphore_list_t signal_semaphore_list, iree_hal_buffer_t* buffer); +// Enqueues a single queue-ordered copy operation. +// +// WARNING: individual copies have a high overhead and batching should be +// performed by the caller instead of calling this multiple times. The +// iree_hal_create_transfer_command_buffer utility makes it easy to create +// batches of transfer operations (fill, copy, update) and is only a few lines +// more code. +IREE_API_EXPORT iree_status_t iree_hal_device_queue_copy( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length); + +// Enqueues a file read operation that streams a segment of the |source_file| +// defined by the |source_offset| and |length| into the HAL |target_buffer| at +// the specified |target_offset|. The |queue_affinity| should be set to where +// the target buffer will be consumed. The source file must have read permission +// and the target buffer must have transfer-target usage. +IREE_API_EXPORT iree_status_t iree_hal_device_queue_read( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags); + +// Enqueues a file write operation that streams a segment of the HAL +// |source_buffer| defined by the |source_offset| and |length| into the +// |target_file| at the specified |target_offset|. The |queue_affinity| should +// be set to where the source buffer was produced. 
The source buffer must have +// transfer-source usage and the target file must have write permission. +IREE_API_EXPORT iree_status_t iree_hal_device_queue_write( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags); + // Executes zero or more command buffers on a device queue. // The command buffers are executed in order as if they were recorded as one. // No commands will execute until the wait fence has been reached and the signal @@ -546,6 +588,13 @@ typedef struct iree_hal_device_vtable_t { iree_hal_device_t* device, iree_string_view_t identifier, iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache); + iree_status_t(IREE_API_PTR* import_file)( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file); + iree_status_t(IREE_API_PTR* create_pipeline_layout)( iree_hal_device_t* device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -580,6 +629,22 @@ typedef struct iree_hal_device_vtable_t { const iree_hal_semaphore_list_t signal_semaphore_list, iree_hal_buffer_t* buffer); + iree_status_t(IREE_API_PTR* queue_read)( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags); + + iree_status_t(IREE_API_PTR* queue_write)( + iree_hal_device_t* device, 
iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags); + iree_status_t(IREE_API_PTR* queue_execute)( iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, diff --git a/runtime/src/iree/hal/drivers/cuda/BUILD.bazel b/runtime/src/iree/hal/drivers/cuda/BUILD.bazel index ad31c08756d8..db2b8c7d7fc5 100644 --- a/runtime/src/iree/hal/drivers/cuda/BUILD.bazel +++ b/runtime/src/iree/hal/drivers/cuda/BUILD.bazel @@ -59,6 +59,8 @@ iree_runtime_cc_library( "//runtime/src/iree/hal/utils:buffer_transfer", "//runtime/src/iree/hal/utils:collective_batch", "//runtime/src/iree/hal/utils:deferred_command_buffer", + "//runtime/src/iree/hal/utils:file_transfer", + "//runtime/src/iree/hal/utils:memory_file", "//runtime/src/iree/hal/utils:resource_set", "//runtime/src/iree/hal/utils:semaphore_base", "//runtime/src/iree/schemas:cuda_executable_def_c_fbs", diff --git a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt index eb225bd5185d..55b02f522249 100644 --- a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt @@ -56,6 +56,8 @@ iree_cc_library( iree::hal::utils::buffer_transfer iree::hal::utils::collective_batch iree::hal::utils::deferred_command_buffer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::schemas::cuda_executable_def_c_fbs diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c index 7ab2864021ad..e81d6c3b1e17 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c +++ 
b/runtime/src/iree/hal/drivers/cuda/cuda_allocator.c @@ -25,6 +25,7 @@ typedef struct iree_hal_cuda_allocator_t { CUstream stream; iree_hal_cuda_memory_pools_t* pools; bool supports_concurrent_managed_access; + bool supports_read_only_host_register; IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;) } iree_hal_cuda_allocator_t; @@ -60,13 +61,27 @@ iree_status_t iree_hal_cuda_allocator_create( &supports_concurrent_managed_access, CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, device), "cuDeviceGetAttribute")); - IREE_TRACE_ZONE_APPEND_TEXT( z0, supports_concurrent_managed_access ? "has CONCURRENT_MANAGED_ACCESS" : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on " "device-local + host-visible memory)"); + // We can only provide the CU_MEMHOSTREGISTER_READ_ONLY flag when importing + // host memory if it's supported. + int supports_read_only_host_register = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + CU_RESULT_TO_STATUS( + context->syms, + cuDeviceGetAttribute( + &supports_read_only_host_register, + CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED, device), + "cuDeviceGetAttribute")); + IREE_TRACE_ZONE_APPEND_TEXT(z0, supports_read_only_host_register + ? 
"has READ_ONLY_HOST_REGISTER_SUPPORTED" + : "no READ_ONLY_HOST_REGISTER_SUPPORTED"); + iree_hal_cuda_allocator_t* allocator = NULL; iree_status_t status = iree_allocator_malloc( context->host_allocator, sizeof(*allocator), (void**)&allocator); @@ -80,6 +95,8 @@ iree_status_t iree_hal_cuda_allocator_create( allocator->pools = pools; allocator->supports_concurrent_managed_access = supports_concurrent_managed_access != 0; + allocator->supports_read_only_host_register = + supports_read_only_host_register != 0; *out_allocator = (iree_hal_allocator_t*)allocator; } @@ -409,7 +426,7 @@ static iree_status_t iree_hal_cuda_allocator_allocate_buffer( &allocator->statistics, compat_params.type, allocation_size)); *out_buffer = buffer; } else { - if (!buffer) { + if (!buffer && (device_ptr || host_ptr)) { iree_hal_cuda_buffer_free(allocator->context, buffer_type, device_ptr, host_ptr); } else { @@ -518,16 +535,10 @@ static iree_status_t iree_hal_cuda_allocator_import_buffer( } buffer_type = IREE_HAL_CUDA_BUFFER_TYPE_HOST_REGISTERED; host_ptr = external_buffer->handle.host_allocation.ptr; - uint32_t register_flags = 0; - if (compat_params.access == IREE_HAL_MEMORY_ACCESS_READ) { - register_flags = CU_MEMHOSTREGISTER_READ_ONLY; - } - if (iree_any_bit_set(compat_params.usage, - IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS | - IREE_HAL_BUFFER_USAGE_DISPATCH_UNIFORM_READ | - IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | - IREE_HAL_BUFFER_USAGE_DISPATCH_IMAGE)) { - register_flags = CU_MEMHOSTREGISTER_DEVICEMAP; + uint32_t register_flags = CU_MEMHOSTREGISTER_DEVICEMAP; + if (compat_params.access == IREE_HAL_MEMORY_ACCESS_READ && + allocator->supports_read_only_host_register) { + register_flags |= CU_MEMHOSTREGISTER_READ_ONLY; } status = CU_RESULT_TO_STATUS( allocator->context->syms, @@ -569,7 +580,7 @@ static iree_status_t iree_hal_cuda_allocator_import_buffer( if (iree_status_is_ok(status)) { *out_buffer = buffer; } else { - if (!buffer) { + if (!buffer && (device_ptr || host_ptr)) { 
iree_hal_cuda_buffer_free(allocator->context, buffer_type, device_ptr, host_ptr); } else { diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c b/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c index c39e16046b55..bcb1ad742536 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c @@ -87,9 +87,11 @@ static iree_status_t iree_hal_cuda_buffer_map_range( IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type( iree_hal_buffer_memory_type(base_buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR( - iree_hal_buffer_validate_usage(iree_hal_buffer_allowed_usage(base_buffer), - IREE_HAL_BUFFER_USAGE_MAPPING)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(base_buffer), + mapping_mode == IREE_HAL_MAPPING_MODE_PERSISTENT + ? IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT + : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset; // If we mapped for discard scribble over the bytes. 
This is not a mandated diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 832b6385bd2b..a6c37a79049a 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -28,6 +28,8 @@ #include "iree/hal/drivers/cuda/tracing.h" #include "iree/hal/utils/buffer_transfer.h" #include "iree/hal/utils/deferred_command_buffer.h" +#include "iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" //===----------------------------------------------------------------------===// // iree_hal_cuda_device_t @@ -492,6 +494,23 @@ static iree_status_t iree_hal_cuda_device_create_executable_cache( &device->context_wrapper, identifier, out_executable_cache); } +static iree_status_t iree_hal_cuda_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, external_file->handle.host_allocation, + release_callback, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + static iree_status_t iree_hal_cuda_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -542,7 +561,8 @@ static iree_status_t iree_hal_cuda_device_queue_alloca( // If pools are not supported we allocate a buffer as normal from whatever // allocator is set on the device. 
iree_status_t status = iree_ok_status(); - if (device->supports_memory_pools) { + if (device->supports_memory_pools && + !iree_any_bit_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { status = iree_hal_cuda_memory_pools_alloca(&device->memory_pools, device->stream, pool, params, allocation_size, out_buffer); @@ -595,6 +615,48 @@ static iree_status_t iree_hal_cuda_device_queue_dealloca( return status; } +static iree_status_t iree_hal_cuda_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_cuda_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options.
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_cuda_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -679,6 +741,7 @@ static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = { iree_hal_cuda_device_create_descriptor_set_layout, .create_event = iree_hal_cuda_device_create_event, .create_executable_cache = iree_hal_cuda_device_create_executable_cache, + .import_file = iree_hal_cuda_device_import_file, .create_pipeline_layout = iree_hal_cuda_device_create_pipeline_layout, .create_semaphore = iree_hal_cuda_device_create_semaphore, .query_semaphore_compatibility = @@ -686,6 +749,8 @@ static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = { .transfer_range = iree_hal_device_submit_transfer_range_and_wait, .queue_alloca = iree_hal_cuda_device_queue_alloca, .queue_dealloca = iree_hal_cuda_device_queue_dealloca, + .queue_read = iree_hal_cuda_device_queue_read, + .queue_write = iree_hal_cuda_device_queue_write, .queue_execute = iree_hal_cuda_device_queue_execute, .queue_flush = iree_hal_cuda_device_queue_flush, .wait_semaphores = iree_hal_cuda_device_wait_semaphores, diff --git a/runtime/src/iree/hal/drivers/local_sync/BUILD.bazel b/runtime/src/iree/hal/drivers/local_sync/BUILD.bazel index 8c7bfbd8ff22..923cd61740f4 100644 --- a/runtime/src/iree/hal/drivers/local_sync/BUILD.bazel +++ b/runtime/src/iree/hal/drivers/local_sync/BUILD.bazel @@ -37,6 +37,8 @@ 
iree_runtime_cc_library( "//runtime/src/iree/hal/local:executable_environment", "//runtime/src/iree/hal/utils:buffer_transfer", "//runtime/src/iree/hal/utils:deferred_command_buffer", + "//runtime/src/iree/hal/utils:file_transfer", + "//runtime/src/iree/hal/utils:memory_file", "//runtime/src/iree/hal/utils:semaphore_base", ], ) diff --git a/runtime/src/iree/hal/drivers/local_sync/CMakeLists.txt b/runtime/src/iree/hal/drivers/local_sync/CMakeLists.txt index 7d2165426680..d28f7171b70e 100644 --- a/runtime/src/iree/hal/drivers/local_sync/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/local_sync/CMakeLists.txt @@ -34,6 +34,8 @@ iree_cc_library( iree::hal::local::executable_environment iree::hal::utils::buffer_transfer iree::hal::utils::deferred_command_buffer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::hal::utils::semaphore_base PUBLIC ) diff --git a/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt index 8aaf3ee5d230..6f31a1aa586d 100644 --- a/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt @@ -28,7 +28,7 @@ if(IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF) DEPS iree::hal::drivers::local_sync::registration EXCLUDED_TESTS - "semaphore_submission" # SubmitWithWait hangs? + "semaphore_submission" # SubmitWithWait hangs (requires async) LABELS driver=local-sync ) @@ -51,7 +51,7 @@ if(IREE_HAL_EXECUTABLE_LOADER_VMVX_MODULE) DEPS iree::hal::drivers::local_sync::registration EXCLUDED_TESTS - "semaphore_submission" # SubmitWithWait hangs? 
+ "semaphore_submission" # SubmitWithWait hangs (requires async) LABELS driver=local-sync ) diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c index 4c9ae8a2b541..7827f87ea2ef 100644 --- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c +++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c @@ -20,6 +20,8 @@ #include "iree/hal/local/local_pipeline_layout.h" #include "iree/hal/utils/buffer_transfer.h" #include "iree/hal/utils/deferred_command_buffer.h" +#include "iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" typedef struct iree_hal_sync_device_t { iree_hal_resource_t resource; @@ -262,6 +264,23 @@ static iree_status_t iree_hal_sync_device_create_executable_cache( iree_hal_device_host_allocator(base_device), out_executable_cache); } +static iree_status_t iree_hal_sync_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, external_file->handle.host_allocation, + release_callback, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + static iree_status_t iree_hal_sync_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -358,6 +377,48 @@ static iree_status_t iree_hal_sync_device_apply_deferred_command_buffers( return iree_ok_status(); } +static iree_status_t iree_hal_sync_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + 
const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_sync_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. 
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_sync_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -437,6 +498,7 @@ static const iree_hal_device_vtable_t iree_hal_sync_device_vtable = { iree_hal_sync_device_create_descriptor_set_layout, .create_event = iree_hal_sync_device_create_event, .create_executable_cache = iree_hal_sync_device_create_executable_cache, + .import_file = iree_hal_sync_device_import_file, .create_pipeline_layout = iree_hal_sync_device_create_pipeline_layout, .create_semaphore = iree_hal_sync_device_create_semaphore, .query_semaphore_compatibility = @@ -444,6 +506,8 @@ static const iree_hal_device_vtable_t iree_hal_sync_device_vtable = { .transfer_range = iree_hal_device_transfer_mappable_range, .queue_alloca = iree_hal_sync_device_queue_alloca, .queue_dealloca = iree_hal_sync_device_queue_dealloca, + .queue_read = iree_hal_sync_device_queue_read, + .queue_write = iree_hal_sync_device_queue_write, .queue_execute = iree_hal_sync_device_queue_execute, .queue_flush = iree_hal_sync_device_queue_flush, .wait_semaphores = iree_hal_sync_device_wait_semaphores, diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_semaphore.c b/runtime/src/iree/hal/drivers/local_sync/sync_semaphore.c index a5f177600060..58a7d8e8e2a0 100644 --- a/runtime/src/iree/hal/drivers/local_sync/sync_semaphore.c +++ b/runtime/src/iree/hal/drivers/local_sync/sync_semaphore.c @@ -13,9 
+13,6 @@ #include "iree/hal/utils/semaphore_base.h" -// Sentinel used the semaphore has failed and an error status is set. -#define IREE_HAL_SYNC_SEMAPHORE_FAILURE_VALUE UINT64_MAX - //===----------------------------------------------------------------------===// // iree_hal_sync_semaphore_state_t //===----------------------------------------------------------------------===// @@ -48,7 +45,7 @@ typedef struct iree_hal_sync_semaphore_t { // than trying to make the entire structure lock-free. iree_slim_mutex_t mutex; - // Current signaled value. May be IREE_HAL_SYNC_SEMAPHORE_FAILURE_VALUE to + // Current signaled value. May be IREE_HAL_SEMAPHORE_FAILURE_VALUE to // indicate that the semaphore has been signaled for failure and // |failure_status| contains the error. uint64_t current_value; @@ -119,7 +116,7 @@ static iree_status_t iree_hal_sync_semaphore_query( *out_value = semaphore->current_value; iree_status_t status = iree_ok_status(); - if (*out_value >= IREE_HAL_SYNC_SEMAPHORE_FAILURE_VALUE) { + if (*out_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { status = iree_status_clone(semaphore->failure_status); } @@ -193,14 +190,14 @@ static void iree_hal_sync_semaphore_fail(iree_hal_semaphore_t* base_semaphore, } // Signal to our failure sentinel value. - semaphore->current_value = IREE_HAL_SYNC_SEMAPHORE_FAILURE_VALUE; + semaphore->current_value = IREE_HAL_SEMAPHORE_FAILURE_VALUE; semaphore->failure_status = status; iree_slim_mutex_unlock(&semaphore->mutex); // Notify timepoints of the failure. 
- iree_hal_semaphore_notify(&semaphore->base, - IREE_HAL_SYNC_SEMAPHORE_FAILURE_VALUE, status_code); + iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE, + status_code); iree_notification_post(&semaphore->shared_state->notification, IREE_ALL_WAITERS); diff --git a/runtime/src/iree/hal/drivers/local_task/BUILD.bazel b/runtime/src/iree/hal/drivers/local_task/BUILD.bazel index 4d6a930f6c32..0171955ced70 100644 --- a/runtime/src/iree/hal/drivers/local_task/BUILD.bazel +++ b/runtime/src/iree/hal/drivers/local_task/BUILD.bazel @@ -48,6 +48,8 @@ iree_runtime_cc_library( "//runtime/src/iree/hal/local:executable_environment", "//runtime/src/iree/hal/local:executable_library", "//runtime/src/iree/hal/utils:buffer_transfer", + "//runtime/src/iree/hal/utils:file_transfer", + "//runtime/src/iree/hal/utils:memory_file", "//runtime/src/iree/hal/utils:resource_set", "//runtime/src/iree/hal/utils:semaphore_base", "//runtime/src/iree/task", diff --git a/runtime/src/iree/hal/drivers/local_task/CMakeLists.txt b/runtime/src/iree/hal/drivers/local_task/CMakeLists.txt index 8b4340edfea4..d326169a139b 100644 --- a/runtime/src/iree/hal/drivers/local_task/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/local_task/CMakeLists.txt @@ -42,6 +42,8 @@ iree_cc_library( iree::hal::local::executable_environment iree::hal::local::executable_library iree::hal::utils::buffer_transfer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::task diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c index 5f35d62674a3..ccb5ae053d07 100644 --- a/runtime/src/iree/hal/drivers/local_task/task_device.c +++ b/runtime/src/iree/hal/drivers/local_task/task_device.c @@ -20,6 +20,8 @@ #include "iree/hal/local/local_executable_cache.h" #include "iree/hal/local/local_pipeline_layout.h" #include "iree/hal/utils/buffer_transfer.h" +#include 
"iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" typedef struct iree_hal_task_device_t { iree_hal_resource_t resource; @@ -325,6 +327,23 @@ static iree_status_t iree_hal_task_device_create_executable_cache( iree_hal_device_host_allocator(base_device), out_executable_cache); } +static iree_status_t iree_hal_task_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, external_file->handle.host_allocation, + release_callback, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + static iree_status_t iree_hal_task_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -387,6 +406,48 @@ static iree_status_t iree_hal_task_device_queue_dealloca( return iree_ok_status(); } +static iree_status_t iree_hal_task_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. 
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_task_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_task_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -457,6 +518,7 @@ static const iree_hal_device_vtable_t iree_hal_task_device_vtable = { iree_hal_task_device_create_descriptor_set_layout, .create_event = iree_hal_task_device_create_event, .create_executable_cache = iree_hal_task_device_create_executable_cache, + .import_file = iree_hal_task_device_import_file, 
.create_pipeline_layout = iree_hal_task_device_create_pipeline_layout, .create_semaphore = iree_hal_task_device_create_semaphore, .query_semaphore_compatibility = @@ -464,6 +526,8 @@ static const iree_hal_device_vtable_t iree_hal_task_device_vtable = { .transfer_range = iree_hal_device_transfer_mappable_range, .queue_alloca = iree_hal_task_device_queue_alloca, .queue_dealloca = iree_hal_task_device_queue_dealloca, + .queue_read = iree_hal_task_device_queue_read, + .queue_write = iree_hal_task_device_queue_write, .queue_execute = iree_hal_task_device_queue_execute, .queue_flush = iree_hal_task_device_queue_flush, .wait_semaphores = iree_hal_task_device_wait_semaphores, diff --git a/runtime/src/iree/hal/drivers/local_task/task_semaphore.c b/runtime/src/iree/hal/drivers/local_task/task_semaphore.c index 7a6728fe62ea..89eb38347873 100644 --- a/runtime/src/iree/hal/drivers/local_task/task_semaphore.c +++ b/runtime/src/iree/hal/drivers/local_task/task_semaphore.c @@ -15,9 +15,6 @@ #include "iree/base/internal/wait_handle.h" #include "iree/hal/utils/semaphore_base.h" -// Sentinel used the semaphore has failed and an error status is set. -#define IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE UINT64_MAX - //===----------------------------------------------------------------------===// // iree_hal_task_timepoint_t //===----------------------------------------------------------------------===// @@ -59,7 +56,7 @@ typedef struct iree_hal_task_semaphore_t { // than trying to make the entire structure lock-free. iree_slim_mutex_t mutex; - // Current signaled value. May be IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE to + // Current signaled value. May be IREE_HAL_SEMAPHORE_FAILURE_VALUE to // indicate that the semaphore has been signaled for failure and // |failure_status| contains the error. 
uint64_t current_value; @@ -135,7 +132,7 @@ static iree_status_t iree_hal_task_semaphore_query( *out_value = semaphore->current_value; iree_status_t status = iree_ok_status(); - if (*out_value >= IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE) { + if (*out_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { status = iree_status_clone(semaphore->failure_status); } @@ -189,14 +186,14 @@ static void iree_hal_task_semaphore_fail(iree_hal_semaphore_t* base_semaphore, } // Signal to our failure sentinel value. - semaphore->current_value = IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE; + semaphore->current_value = IREE_HAL_SEMAPHORE_FAILURE_VALUE; semaphore->failure_status = status; iree_slim_mutex_unlock(&semaphore->mutex); // Notify timepoints - note that this must happen outside the lock. - iree_hal_semaphore_notify(&semaphore->base, - IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE, status_code); + iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE, + status_code); } // Acquires a timepoint waiting for the given value. 
diff --git a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt index 384005dff4bf..8e78eb059e62 100644 --- a/runtime/src/iree/hal/drivers/metal/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/metal/CMakeLists.txt @@ -42,6 +42,8 @@ iree_cc_library( iree::hal iree::hal::drivers::metal::builtin iree::hal::utils::buffer_transfer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::hal::utils::resource_set iree::schemas::metal_executable_def_c_fbs "-framework Foundation" diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m index 4deb54df7fb0..da93dc387fe7 100644 --- a/runtime/src/iree/hal/drivers/metal/metal_device.m +++ b/runtime/src/iree/hal/drivers/metal/metal_device.m @@ -18,6 +18,8 @@ #include "iree/hal/drivers/metal/shared_event.h" #include "iree/hal/drivers/metal/staging_buffer.h" #include "iree/hal/utils/buffer_transfer.h" +#include "iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" #include "iree/hal/utils/resource_set.h" typedef struct iree_hal_metal_device_t { @@ -275,6 +277,19 @@ static iree_status_t iree_hal_metal_device_create_executable_cache( device->host_allocator, out_executable_cache); } +static iree_status_t iree_hal_metal_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap(queue_affinity, access, external_file->handle.host_allocation, + release_callback, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + 
static iree_status_t iree_hal_metal_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, iree_hal_descriptor_set_layout_t* const* set_layouts, @@ -329,6 +344,46 @@ static iree_status_t iree_hal_metal_device_queue_dealloca( signal_semaphore_list); } +static iree_status_t iree_hal_metal_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, iree_hal_file_t* source_file, + uint64_t source_offset, iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, source_file, + source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_metal_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, iree_hal_buffer_t* source_buffer, + iree_device_size_t source_offset, iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options.
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + .loop = iree_loop_inline(&loop_status), + .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, source_buffer, + source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_metal_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -494,12 +549,15 @@ static iree_status_t iree_hal_metal_device_profiling_end(iree_hal_device_t* base .create_descriptor_set_layout = iree_hal_metal_device_create_descriptor_set_layout, .create_event = iree_hal_metal_device_create_event, .create_executable_cache = iree_hal_metal_device_create_executable_cache, + .import_file = iree_hal_metal_device_import_file, .create_pipeline_layout = iree_hal_metal_device_create_pipeline_layout, .create_semaphore = iree_hal_metal_device_create_semaphore, .query_semaphore_compatibility = iree_hal_metal_device_query_semaphore_compatibility, .transfer_range = iree_hal_device_submit_transfer_range_and_wait, .queue_alloca = iree_hal_metal_device_queue_alloca, .queue_dealloca = iree_hal_metal_device_queue_dealloca, + .queue_read = iree_hal_metal_device_queue_read, + .queue_write = iree_hal_metal_device_queue_write, .queue_execute = iree_hal_metal_device_queue_execute, .queue_flush = iree_hal_metal_device_queue_flush, .wait_semaphores = iree_hal_metal_device_wait_semaphores, diff --git a/runtime/src/iree/hal/drivers/metal/shared_event.m b/runtime/src/iree/hal/drivers/metal/shared_event.m index fee92a13d3c4..f741f2ea3a63 100644 --- a/runtime/src/iree/hal/drivers/metal/shared_event.m +++
b/runtime/src/iree/hal/drivers/metal/shared_event.m @@ -95,7 +95,7 @@ static iree_status_t iree_hal_metal_shared_event_query(iree_hal_semaphore_t* bas uint64_t* out_value) { iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore); uint64_t value = semaphore->shared_event.signaledValue; - if (IREE_UNLIKELY(value == UINT64_MAX)) { + if (IREE_UNLIKELY(value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE)) { iree_status_t status = iree_ok_status(); iree_slim_mutex_lock(&semaphore->state_mutex); status = semaphore->failure_state; @@ -110,7 +110,7 @@ static iree_status_t iree_hal_metal_shared_event_signal(iree_hal_semaphore_t* ba uint64_t new_value) { iree_hal_metal_shared_event_t* semaphore = iree_hal_metal_shared_event_cast(base_semaphore); uint64_t value = semaphore->shared_event.signaledValue; - if (IREE_UNLIKELY(value == UINT64_MAX)) { + if (IREE_UNLIKELY(value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE)) { iree_status_t status = iree_ok_status(); iree_slim_mutex_lock(&semaphore->state_mutex); status = semaphore->failure_state; @@ -128,7 +128,7 @@ static void iree_hal_metal_shared_event_fail(iree_hal_semaphore_t* base_semaphor iree_slim_mutex_lock(&semaphore->state_mutex); semaphore->failure_state = status; - semaphore->shared_event.signaledValue = UINT64_MAX; + semaphore->shared_event.signaledValue = IREE_HAL_SEMAPHORE_FAILURE_VALUE; iree_slim_mutex_unlock(&semaphore->state_mutex); IREE_TRACE_ZONE_END(z0); @@ -182,7 +182,7 @@ static iree_status_t iree_hal_metal_shared_event_wait(iree_hal_semaphore_t* base [semaphore->shared_event notifyListener:semaphore->event_listener atValue:value block:^(id se, uint64_t v) { - if (v == UINT64_MAX) did_fail = true; + if (v >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) did_fail = true; dispatch_semaphore_signal(work_done); }]; @@ -251,7 +251,7 @@ iree_status_t iree_hal_metal_shared_event_multi_wait( atValue:semaphore_list->payload_values[i] block:^(id se, uint64_t v) { // Fail as a whole if any participating semaphore 
failed. - if (v == UINT64_MAX) did_fail = true; + if (v >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) did_fail = true; int32_t old_value = iree_atomic_fetch_add_int32( &wait_count, 1, iree_memory_order_release); diff --git a/runtime/src/iree/hal/drivers/vulkan/BUILD.bazel b/runtime/src/iree/hal/drivers/vulkan/BUILD.bazel index a7f0139b49e1..280e432eb552 100644 --- a/runtime/src/iree/hal/drivers/vulkan/BUILD.bazel +++ b/runtime/src/iree/hal/drivers/vulkan/BUILD.bazel @@ -81,6 +81,8 @@ iree_runtime_cc_library( "//runtime/src/iree/hal/drivers/vulkan/util:intrusive_list", "//runtime/src/iree/hal/drivers/vulkan/util:ref_ptr", "//runtime/src/iree/hal/utils:buffer_transfer", + "//runtime/src/iree/hal/utils:file_transfer", + "//runtime/src/iree/hal/utils:memory_file", "//runtime/src/iree/hal/utils:resource_set", "//runtime/src/iree/hal/utils:semaphore_base", "//runtime/src/iree/schemas:spirv_executable_def_c_fbs", diff --git a/runtime/src/iree/hal/drivers/vulkan/CMakeLists.txt b/runtime/src/iree/hal/drivers/vulkan/CMakeLists.txt index f9893dbe4e79..86304d57849f 100644 --- a/runtime/src/iree/hal/drivers/vulkan/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/vulkan/CMakeLists.txt @@ -76,6 +76,8 @@ iree_cc_library( iree::hal::drivers::vulkan::util::intrusive_list iree::hal::drivers::vulkan::util::ref_ptr iree::hal::utils::buffer_transfer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::schemas::spirv_executable_def_c_fbs diff --git a/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt index c6b241e18603..4df998e8348c 100644 --- a/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt @@ -17,6 +17,8 @@ iree_hal_cts_test_suite( "\"SPVE\"" DEPS iree::hal::drivers::vulkan::registration + EXCLUDED_TESTS + "semaphore_submission" # SubmitWithWait hangs (requires async) LABELS 
driver=vulkan ) diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc index 48bbd4001dd0..a445760b0136 100644 --- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc +++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc @@ -534,6 +534,9 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_fill_buffer( iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); VkBuffer target_device_buffer = iree_hal_vulkan_buffer_handle(target_buffer); + IREE_VULKAN_TRACE_ZONE_BEGIN(command_buffer->tracing_context, + command_buffer->handle); + IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert( command_buffer->resource_set, 1, &target_buffer)); @@ -579,6 +582,9 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_fill_buffer( length, dword_pattern); } + IREE_VULKAN_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->handle); + return iree_ok_status(); } @@ -590,6 +596,9 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_update_buffer( iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); VkBuffer target_device_buffer = iree_hal_vulkan_buffer_handle(target_buffer); + IREE_VULKAN_TRACE_ZONE_BEGIN(command_buffer->tracing_context, + command_buffer->handle); + IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert( command_buffer->resource_set, 1, &target_buffer)); @@ -612,6 +621,9 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_update_buffer( length -= chunk_length; } + IREE_VULKAN_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->handle); + return iree_ok_status(); } @@ -625,6 +637,9 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_copy_buffer( VkBuffer source_device_buffer = iree_hal_vulkan_buffer_handle(source_buffer); VkBuffer target_device_buffer = iree_hal_vulkan_buffer_handle(target_buffer); + IREE_VULKAN_TRACE_ZONE_BEGIN(command_buffer->tracing_context, + 
command_buffer->handle); + const iree_hal_buffer_t* buffers[2] = {source_buffer, target_buffer}; IREE_RETURN_IF_ERROR( iree_hal_resource_set_insert(command_buffer->resource_set, 2, buffers)); @@ -637,6 +652,9 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_copy_buffer( source_device_buffer, target_device_buffer, 1, &region); + IREE_VULKAN_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->handle); + return iree_ok_status(); } diff --git a/runtime/src/iree/hal/drivers/vulkan/native_allocator.cc b/runtime/src/iree/hal/drivers/vulkan/native_allocator.cc index 70e2ac4b1271..620330b58876 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_allocator.cc +++ b/runtime/src/iree/hal/drivers/vulkan/native_allocator.cc @@ -322,7 +322,7 @@ static void iree_hal_vulkan_native_allocator_deallocate_buffer( iree_hal_allocator_t* IREE_RESTRICT base_allocator, iree_hal_buffer_t* IREE_RESTRICT base_buffer) { iree_hal_vulkan_native_allocator_t* allocator = - iree_hal_vulkan_native_allocator_cast(base_buffer->device_allocator); + iree_hal_vulkan_native_allocator_cast(base_allocator); (void)allocator; iree_hal_allocator_statistics_record_free(&allocator->statistics, base_buffer->memory_type, diff --git a/runtime/src/iree/hal/drivers/vulkan/native_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/native_buffer.cc index 7c58012cd59a..f7ca492ca660 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_buffer.cc +++ b/runtime/src/iree/hal/drivers/vulkan/native_buffer.cc @@ -102,9 +102,11 @@ static iree_status_t iree_hal_vulkan_native_buffer_map_range( IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type( iree_hal_buffer_memory_type(base_buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR( - iree_hal_buffer_validate_usage(iree_hal_buffer_allowed_usage(base_buffer), - IREE_HAL_BUFFER_USAGE_MAPPING)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(base_buffer), + mapping_mode == 
IREE_HAL_MAPPING_MODE_PERSISTENT + ? IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT + : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); // TODO(benvanik): map VK_WHOLE_SIZE and subset ourselves? may need to get // around some minimum mapping alignment rules. diff --git a/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc b/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc index c227c211aaf1..d882cfb47bef 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc +++ b/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc @@ -15,26 +15,6 @@ #include "iree/hal/drivers/vulkan/util/ref_ptr.h" #include "iree/hal/utils/semaphore_base.h" -// The maximum valid payload value of an iree_hal_semaphore_t. -// Payload values larger than this indicate that the semaphore has failed. -// -// This originates from Vulkan having a lower-bound of INT_MAX for -// maxTimelineSemaphoreValueDifference and many Android devices only supporting -// that lower-bound. At ~100 signals per second it'll take 1.5+ years to -// saturate. We may increase this value at some point but so long as there are -// some devices in the wild that may have this limitation we can ensure better -// consistency across the backends by observing this. -// -// The major mitigation here is that in proper usage of IREE there are no -// semaphores that are implicitly referenced by multiple VMs (each creates their -// own internally) and in a multitenant system each session should have its own -// semaphores - so even if the process lives for years it's highly unlikely any -// particular session does. Whatever, 640K is enough for anyone. 
-// -// See: -// https://vulkan.gpuinfo.org/displayextensionproperty.php?name=maxTimelineSemaphoreValueDifference -#define IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE (2147483647ull - 1) - using namespace iree::hal::vulkan; typedef struct iree_hal_vulkan_native_semaphore_t { @@ -146,7 +126,7 @@ static iree_status_t iree_hal_vulkan_native_semaphore_query( "vkGetSemaphoreCounterValue")); // If the semaphore failed then clone the status so we can report it. - if (value > IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE) { + if (value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_status_t failure_status = (iree_status_t)iree_atomic_load_intptr( &semaphore->failure_status, iree_memory_order_acquire); if (iree_status_is_ok(failure_status)) { @@ -211,7 +191,7 @@ static void iree_hal_vulkan_native_semaphore_fail( signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; signal_info.pNext = NULL; signal_info.semaphore = semaphore->handle; - signal_info.value = IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE + 1; + signal_info.value = IREE_HAL_SEMAPHORE_FAILURE_VALUE; // NOTE: we don't care about the result in case of failures as we are // failing and the caller will likely be tearing everything down anyway. semaphore->logical_device->syms()->vkSignalSemaphore( diff --git a/runtime/src/iree/hal/drivers/vulkan/tracing.h b/runtime/src/iree/hal/drivers/vulkan/tracing.h index d2fe101ef3e4..520ca813bfbc 100644 --- a/runtime/src/iree/hal/drivers/vulkan/tracing.h +++ b/runtime/src/iree/hal/drivers/vulkan/tracing.h @@ -116,12 +116,12 @@ void iree_hal_vulkan_tracing_zone_end_impl( iree_hal_vulkan_tracing_context_t* context, VkCommandBuffer command_buffer); // Begins a new zone with the parent function name. 
-#define IREE_VULKAN_TRACE_ZONE_BEGIN(context, command_buffer) \ - static const iree_tracing_location_t TracyConcat( \ - __tracy_source_location, __LINE__) = {name_literal, __FUNCTION__, \ - __FILE__, (uint32_t)__LINE__, 0}; \ - iree_hal_vulkan_tracing_zone_begin_impl( \ - context, command_buffer, \ +#define IREE_VULKAN_TRACE_ZONE_BEGIN(context, command_buffer) \ + static const iree_tracing_location_t TracyConcat( \ + __tracy_source_location, __LINE__) = {NULL, __FUNCTION__, __FILE__, \ + (uint32_t)__LINE__, 0}; \ + iree_hal_vulkan_tracing_zone_begin_impl( \ + context, command_buffer, \ &TracyConcat(__tracy_source_location, __LINE__)); // Begins an externally defined zone with a dynamic source location. diff --git a/runtime/src/iree/hal/drivers/vulkan/vma_allocator.cc b/runtime/src/iree/hal/drivers/vulkan/vma_allocator.cc index 3a62218ee884..3b31640dd5e8 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vma_allocator.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vma_allocator.cc @@ -119,9 +119,11 @@ static iree_status_t iree_hal_vulkan_vma_buffer_map_range( IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type( iree_hal_buffer_memory_type(base_buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR( - iree_hal_buffer_validate_usage(iree_hal_buffer_allowed_usage(base_buffer), - IREE_HAL_BUFFER_USAGE_MAPPING)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(base_buffer), + mapping_mode == IREE_HAL_MAPPING_MODE_PERSISTENT + ? 
IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT + : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); uint8_t* data_ptr = nullptr; VK_RETURN_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc index 10a6d6dff8f6..b2a5ef1578a5 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc @@ -33,6 +33,8 @@ #include "iree/hal/drivers/vulkan/util/ref_ptr.h" #include "iree/hal/drivers/vulkan/vma_allocator.h" #include "iree/hal/utils/buffer_transfer.h" +#include "iree/hal/utils/file_transfer.h" +#include "iree/hal/utils/memory_file.h" using namespace iree::hal::vulkan; @@ -1206,6 +1208,23 @@ static iree_status_t iree_hal_vulkan_device_create_executable_cache( device->logical_device, identifier, out_executable_cache); } +static iree_status_t iree_hal_vulkan_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + if (external_file->type != IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, external_file->handle.host_allocation, + release_callback, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + static iree_status_t iree_hal_vulkan_device_create_pipeline_layout( iree_hal_device_t* base_device, iree_host_size_t push_constants, iree_host_size_t set_layout_count, @@ -1267,6 +1286,48 @@ static iree_status_t iree_hal_vulkan_device_queue_dealloca( return iree_ok_status(); } +static iree_status_t iree_hal_vulkan_device_queue_read( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const 
iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. + iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + /*.loop=*/iree_loop_inline(&loop_status), + /*.chunk_count=*/IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + /*.chunk_size=*/IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_file, source_offset, target_buffer, target_offset, length, flags, + options)); + return loop_status; +} + +static iree_status_t iree_hal_vulkan_device_queue_write( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags) { + // TODO: expose streaming chunk count/size options. 
+ iree_status_t loop_status = iree_ok_status(); + iree_hal_file_transfer_options_t options = { + /*.loop=*/iree_loop_inline(&loop_status), + /*.chunk_count=*/IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, + /*.chunk_size=*/IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, + }; + IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, target_file, target_offset, length, flags, + options)); + return loop_status; +} + static iree_status_t iree_hal_vulkan_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, @@ -1283,7 +1344,10 @@ static iree_status_t iree_hal_vulkan_device_queue_execute( /*.command_buffers=*/command_buffers, /*.signal_semaphores=*/signal_semaphore_list, }; - return queue->Submit(1, &batch); + IREE_RETURN_IF_ERROR(queue->Submit(1, &batch)); + // HACK: we don't track async resource lifetimes so we have to block. 
+ return iree_hal_semaphore_list_wait(signal_semaphore_list, + iree_infinite_timeout()); } static iree_status_t iree_hal_vulkan_device_queue_flush( @@ -1383,6 +1447,7 @@ const iree_hal_device_vtable_t iree_hal_vulkan_device_vtable = { /*.create_event=*/iree_hal_vulkan_device_create_event, /*.create_executable_cache=*/ iree_hal_vulkan_device_create_executable_cache, + /*.import_file=*/iree_hal_vulkan_device_import_file, /*.create_pipeline_layout=*/ iree_hal_vulkan_device_create_pipeline_layout, /*.create_semaphore=*/iree_hal_vulkan_device_create_semaphore, @@ -1391,6 +1456,8 @@ const iree_hal_device_vtable_t iree_hal_vulkan_device_vtable = { /*.transfer_range=*/iree_hal_device_submit_transfer_range_and_wait, /*.queue_alloca=*/iree_hal_vulkan_device_queue_alloca, /*.queue_dealloca=*/iree_hal_vulkan_device_queue_dealloca, + /*.queue_read=*/iree_hal_vulkan_device_queue_read, + /*.queue_write=*/iree_hal_vulkan_device_queue_write, /*.queue_execute=*/iree_hal_vulkan_device_queue_execute, /*.queue_flush=*/iree_hal_vulkan_device_queue_flush, /*.wait_semaphores=*/iree_hal_vulkan_device_wait_semaphores, diff --git a/runtime/src/iree/hal/file.c b/runtime/src/iree/hal/file.c new file mode 100644 index 000000000000..aff4b09c912a --- /dev/null +++ b/runtime/src/iree/hal/file.c @@ -0,0 +1,36 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/file.h" + +#include <stddef.h> + +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(file, method_name) \ + IREE_HAL_VTABLE_DISPATCH(file, iree_hal_file, method_name) + +IREE_HAL_API_RETAIN_RELEASE(file); + +IREE_API_EXPORT iree_status_t iree_hal_file_import( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(external_file); + IREE_ASSERT_ARGUMENT(out_file); + *out_file = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, import_file)( + device, queue_affinity, access, external_file, release_callback, + out_file); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/runtime/src/iree/hal/file.h b/runtime/src/iree/hal/file.h new file mode 100644 index 000000000000..8ce709558031 --- /dev/null +++ b/runtime/src/iree/hal/file.h @@ -0,0 +1,166 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_FILE_H_ +#define IREE_HAL_FILE_H_ + +#include <stdint.h> + +#include "iree/base/api.h" +#include "iree/hal/allocator.h" +#include "iree/hal/buffer.h" +#include "iree/hal/resource.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_t iree_hal_device_t; + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +// A bitfield specifying how a file should be opened and the access allowed. 
+enum iree_hal_file_mode_bits_t { + // Opens the file if it exists on the file system. + IREE_HAL_FILE_MODE_OPEN = 1u << 0, +}; +typedef uint32_t iree_hal_file_mode_t; + +typedef void(IREE_API_PTR* iree_hal_file_release_fn_t)(void* user_data); + +// A callback issued when a file is released. +typedef struct { + // Callback function pointer. + iree_hal_file_release_fn_t fn; + // User data passed to the callback function. Unowned. + void* user_data; +} iree_hal_file_release_callback_t; + +// Returns a no-op file release callback that implies that no cleanup is +// required. +static inline iree_hal_file_release_callback_t +iree_hal_file_release_callback_null(void) { + iree_hal_file_release_callback_t callback = {NULL, NULL}; + return callback; +} + +// Defines the type of an external file handle. +// Each type may only be usable in a subset of implementations and platforms and +// may even vary based on the runtime device properties or file instance. +// +// See the notes on each type for requirements; compatibility often requires +// the handle to check and trying to import/export is the most reliable way to +// check for support. +typedef enum iree_hal_external_file_type_e { + IREE_HAL_EXTERNAL_FILE_TYPE_NONE = 0, + + // A fixed-size range of host memory. + // An imported/exported file does not own a reference to the memory and the + // caller is responsible for ensuring the memory remains live for as long as + // the iree_hal_file_t referencing it. + IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION, + + // TODO(benvanik): file descriptor, FILE*, HANDLE, etc. +} iree_hal_external_file_type_t; + +// Flags for controlling iree_hal_external_file_t implementation details. +enum iree_hal_external_file_flag_bits_t { + IREE_HAL_EXTERNAL_FILE_FLAG_NONE = 0u, +}; +typedef uint32_t iree_hal_external_file_flags_t; + +// Handle to a typed external file. +// This is a non-owning reference and the underlying file contents must remain +// valid for as long as the handle is in use. 
Some file types support internal +// reference counting but in general ownership remains with the caller. +// See the type enum for more information. +typedef struct iree_hal_external_file_t { + // Type of the resource used to interpret the handle. + iree_hal_external_file_type_t type; + // Flags indicating file compatibility. + iree_hal_external_file_flags_t flags; + union { + // IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION + iree_byte_span_t host_allocation; + } handle; +} iree_hal_external_file_t; + +//===----------------------------------------------------------------------===// +// iree_hal_file_t +//===----------------------------------------------------------------------===// + +// A file handle usable with asynchronous device transfer operations. +// Files are used for bulk data upload and download and on some implementations +// may have hardware-optimized transfer paths. +// +// Implementations with support: +// CPU: file descriptors/HANDLEs +// CUDA: cuFile +// https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html +// Direct3D: IDStorageFileX +// https://learn.microsoft.com/en-us/gaming/gdk/_content/gc/system/overviews/directstorage/directstorage-overview +// Metal: MTLIOFileHandle +// https://developer.apple.com/documentation/metal/resource_loading?language=objc +// +// Some implementations may allow additional non-native contents to be wrapped +// in file handles to provide implementation-controlled transfer even if not +// hardware-accelerated. See iree_hal_file_import for more information. +typedef struct iree_hal_file_t iree_hal_file_t; + +// TODO(benvanik): support opening files from paths. 
+// IREE_API_EXPORT iree_status_t iree_hal_file_open( +// iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, +// iree_hal_file_mode_t mode, iree_hal_memory_access_t access, +// iree_string_view_t path, iree_hal_file_t** out_file); + +// Imports an externally-owned |external_file| handle for use on |device|. 
+// +// Access checks will be performed against the provided |access| bits and +// callers must ensure the access is accurate (don't allow writes to read-only +// mapped memory, etc). +// +// The provided |external_file| handle is not owned and callers must either +// ensure it remains valid for the lifetime of the handle or retain it prior +// to calling and release it with the provided optional |release_callback|. +// The release callback allows the caller to listen for when the underlying +// resource is no longer in use by the HAL and can be used to perform lifetime +// management of the external file handle, file system synchronization, etc. +// +// |out_file| must be released by the caller. +// Fails with IREE_STATUS_UNAVAILABLE if the device cannot import the file. +// This may be due to unavailable device/platform capabilities or the properties +// of the external file handle. +IREE_API_EXPORT iree_status_t iree_hal_file_import( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, + iree_hal_external_file_t* IREE_RESTRICT external_file, + iree_hal_file_release_callback_t release_callback, + iree_hal_file_t** out_file); + +// Retains the given |file| for the caller. +IREE_API_EXPORT void iree_hal_file_retain(iree_hal_file_t* file); + +// Releases the given |file| from the caller. 
+IREE_API_EXPORT void iree_hal_file_release(iree_hal_file_t* file); + +//===----------------------------------------------------------------------===// +// iree_hal_file_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_file_vtable_t { + void(IREE_API_PTR* destroy)(iree_hal_file_t* IREE_RESTRICT file); +} iree_hal_file_vtable_t; +IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_file_vtable_t); + +IREE_API_EXPORT void iree_hal_file_destroy(iree_hal_file_t* file); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_FILE_H_ diff --git a/runtime/src/iree/hal/semaphore.c b/runtime/src/iree/hal/semaphore.c index 8cfaef94a6f4..b71caa91ee7b 100644 --- a/runtime/src/iree/hal/semaphore.c +++ b/runtime/src/iree/hal/semaphore.c @@ -10,7 +10,6 @@ #include "iree/hal/detail.h" #include "iree/hal/device.h" -#include "iree/hal/resource.h" //===----------------------------------------------------------------------===// // iree_hal_semaphore_t diff --git a/runtime/src/iree/hal/semaphore.h b/runtime/src/iree/hal/semaphore.h index 47c58cf70648..b9263b2fa136 100644 --- a/runtime/src/iree/hal/semaphore.h +++ b/runtime/src/iree/hal/semaphore.h @@ -23,6 +23,30 @@ typedef struct iree_hal_device_t iree_hal_device_t; // iree_hal_semaphore_t //===----------------------------------------------------------------------===// +// The maximum valid payload value of an iree_hal_semaphore_t. +// Payload values larger than this indicate that the semaphore has failed. +// +// This originates from Vulkan having a lower-bound of INT_MAX for +// maxTimelineSemaphoreValueDifference and many Android devices only supporting +// that lower-bound. At ~100 signals per second it'll take 1.5+ years to +// saturate. 
We may increase this value at some point but so long as there are +// some devices in the wild that may have this limitation we can ensure better +// consistency across the backends by observing this. +// +// The major mitigation here is that in proper usage of IREE there are no +// semaphores that are implicitly referenced by multiple VMs (each creates their +// own internally) and in a multitenant system each session should have its own +// semaphores - so even if the process lives for years it's highly unlikely any +// particular session does. Whatever, 640K is enough for anyone. +// +// In the future we may try to back this out and go back to UINT64_MAX. +// +// See: +// https://vulkan.gpuinfo.org/displayextensionproperty.php?name=maxTimelineSemaphoreValueDifference +#define IREE_HAL_SEMAPHORE_MAX_VALUE (2147483647ull - 1) + +#define IREE_HAL_SEMAPHORE_FAILURE_VALUE (IREE_HAL_SEMAPHORE_MAX_VALUE + 1) + // Synchronization mechanism for host->device, device->host, host->host, // and device->device notification. 
Semaphores behave like Vulkan timeline // semaphores (or D3D12 fences) and contain a monotonically increasing diff --git a/runtime/src/iree/hal/utils/BUILD.bazel b/runtime/src/iree/hal/utils/BUILD.bazel index 16c3970edaa9..f8960d8a096c 100644 --- a/runtime/src/iree/hal/utils/BUILD.bazel +++ b/runtime/src/iree/hal/utils/BUILD.bazel @@ -80,6 +80,18 @@ iree_runtime_cc_library( ], ) +iree_runtime_cc_library( + name = "file_transfer", + srcs = ["file_transfer.c"], + hdrs = ["file_transfer.h"], + deps = [ + ":memory_file", + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal", + "//runtime/src/iree/hal", + ], +) + iree_runtime_cc_library( name = "libmpi", srcs = ["libmpi.c"], @@ -105,6 +117,16 @@ iree_runtime_cc_test( ], ) +iree_runtime_cc_library( + name = "memory_file", + srcs = ["memory_file.c"], + hdrs = ["memory_file.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/hal", + ], +) + iree_runtime_cc_library( name = "mpi_channel_provider", srcs = ["mpi_channel_provider.c"], diff --git a/runtime/src/iree/hal/utils/CMakeLists.txt b/runtime/src/iree/hal/utils/CMakeLists.txt index 1075b3d1ff03..de31d3bfdb90 100644 --- a/runtime/src/iree/hal/utils/CMakeLists.txt +++ b/runtime/src/iree/hal/utils/CMakeLists.txt @@ -95,6 +95,21 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + file_transfer + HDRS + "file_transfer.h" + SRCS + "file_transfer.c" + DEPS + ::memory_file + iree::base + iree::base::internal + iree::hal + PUBLIC +) + iree_cc_library( NAME libmpi @@ -122,6 +137,19 @@ iree_cc_test( iree::testing::gtest_main ) +iree_cc_library( + NAME + memory_file + HDRS + "memory_file.h" + SRCS + "memory_file.c" + DEPS + iree::base + iree::hal + PUBLIC +) + iree_cc_library( NAME mpi_channel_provider diff --git a/runtime/src/iree/hal/utils/buffer_transfer.c b/runtime/src/iree/hal/utils/buffer_transfer.c index 179a11b64b57..a7b61a892eb2 100644 --- a/runtime/src/iree/hal/utils/buffer_transfer.c +++ b/runtime/src/iree/hal/utils/buffer_transfer.c @@ 
-112,13 +112,13 @@ IREE_API_EXPORT iree_status_t iree_hal_device_submit_transfer_range_and_wait( (iree_all_bits_set(iree_hal_buffer_memory_type(source.device_buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && iree_all_bits_set(iree_hal_buffer_allowed_usage(source.device_buffer), - IREE_HAL_BUFFER_USAGE_MAPPING)); + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); bool is_target_mappable = !target.device_buffer || (iree_all_bits_set(iree_hal_buffer_memory_type(target.device_buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && iree_all_bits_set(iree_hal_buffer_allowed_usage(target.device_buffer), - IREE_HAL_BUFFER_USAGE_MAPPING)); + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); if (is_source_mappable && is_target_mappable) { return iree_hal_device_transfer_mappable_range( device, source, source_offset, target, target_offset, data_length, diff --git a/runtime/src/iree/hal/utils/file_transfer.c b/runtime/src/iree/hal/utils/file_transfer.c new file mode 100644 index 000000000000..14cb8a2b4479 --- /dev/null +++ b/runtime/src/iree/hal/utils/file_transfer.c @@ -0,0 +1,935 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/utils/file_transfer.h" + +#include "iree/base/internal/math.h" +#include "iree/hal/utils/memory_file.h" + +//===----------------------------------------------------------------------===// +// Configuration +//===----------------------------------------------------------------------===// + +// TODO(benvanik): make these either compile-time configuration options so we +// can prune code paths or flags (somehow). + +#if !defined(IREE_HAL_TRANSFER_WORKER_LIMIT) +// Maximum number of workers that will be used. This is something we can derive +// from the transfer size and the loop; small transfers or synchronous loops +// should have 1 and we can measure to see how many others we need. 
+#define IREE_HAL_TRANSFER_WORKER_LIMIT 1 +#endif // !IREE_HAL_TRANSFER_WORKER_LIMIT + +#if !defined(IREE_HAL_TRANSFER_CHUNK_SIZE) +// Bytes per worker to stage chunks of data. Larger chunks will result in less +// overhead as fewer copy operations are required. +#define IREE_HAL_TRANSFER_CHUNK_SIZE (64 * 1024 * 1024) +#endif // !IREE_HAL_TRANSFER_CHUNK_SIZE + +#if !defined(IREE_HAL_TRANSFER_CHUNKS_PER_WORKER) +// Estimated number of chunks each worker should process used to determine how +// many workers are needed as part of a transfer operation. Larger numbers will +// reduce memory overhead at the cost of latency reductions. +#define IREE_HAL_TRANSFER_CHUNKS_PER_WORKER 8 +#endif // IREE_HAL_TRANSFER_CHUNKS_PER_WORKER + +//===----------------------------------------------------------------------===// +// iree_hal_transfer_operation_t +//===----------------------------------------------------------------------===// + +// TODO(benvanik): move to utils/ without relying on iree_hal_memory_file_t. + +// Maximum number of transfer workers that can be used; common usage should be +// 1-4 but on very large systems with lots of bandwidth we may be able to +// use more. +#define IREE_HAL_TRANSFER_WORKER_MAX_COUNT 64 + +// Each bit represents one worker within a transfer matching its ordinal in +// the operation workers array. +typedef uint64_t iree_hal_transfer_worker_bitmask_t; + +// Counts the total number of workers indicated by the given worker bitmask. +#define iree_hal_transfer_worker_live_count(bitmask) \ + iree_math_count_ones_u64(bitmask) + +// Describes the direction of a transfer operation. +typedef enum { + // Transferring from the file to the buffer (read). + IREE_HAL_TRANSFER_READ_FILE_TO_BUFFER = 0, + // Transferring from the buffer to the file (write). 
+ IREE_HAL_TRANSFER_WRITE_BUFFER_TO_FILE, +} iree_hal_transfer_direction_t; + +typedef struct iree_hal_transfer_operation_t iree_hal_transfer_operation_t; + +// A worker greedily processing subranges of a larger transfer operation. +// Since transfers are 99% IO bound we avoid real threads and use workers as +// coroutines (or something like them): workers submit operations and schedule +// an async wait on a loop for when the operation completes on the device - when +// woken the worker will try to grab another subrange of the transfer and +// continue running. When there are no remaining subranges the workers will +// exit and when the last does the transfer is marked complete. +typedef struct iree_hal_transfer_worker_t { + // Parent operation this worker is a part of. + iree_hal_transfer_operation_t* operation; + // Used to associate tracing events with this worker. + IREE_TRACE(int32_t trace_id;) + // Aligned offset into the staging buffer of the worker storage. + iree_device_size_t staging_buffer_offset; + // Aligned length of the staging buffer storage reserved for the worker. + iree_device_size_t staging_buffer_length; + // Semaphore representing the timeline of the worker. The payload is a + // monotonically increasing operation count. + iree_hal_semaphore_t* semaphore; + // Pending timepoint representing an in-flight operation. Upon completion of + // the operation the semaphore will reach this value. + uint64_t pending_timepoint; + // Offset into the transfer operation this worker is currently processing. + iree_device_size_t pending_transfer_offset; + // Length of the current worker transfer; usually staging_buffer_length but + // may be less if this worker is processing the end of the file. + iree_device_size_t pending_transfer_length; +} iree_hal_transfer_worker_t; + +// Manages an asynchronous transfer operation. 
+typedef struct iree_hal_transfer_operation_t { + // Some loop implementations are re-entrant and we need to be able to handle + // the operation completing immediately upon allocation instead of + // asynchronously and the ref count lets us have the top-level call clean up. + iree_atomic_ref_count_t ref_count; + // Device this transfer operation is acting on. + iree_hal_device_t* device; + // Queue affinity all operations should be assigned. + iree_hal_queue_affinity_t queue_affinity; + // Used to associate tracing events with this operation. + IREE_TRACE(int32_t trace_id;) + + // Direction of the operation (read file->buffer or write buffer->file). + iree_hal_transfer_direction_t direction; + // Retained file resource. + iree_hal_file_t* file; + // Offset into the file where the operation begins. + uint64_t file_offset; + // Retained buffer resource. + iree_hal_buffer_t* buffer; + // Offset into the buffer where the operation begins. + iree_device_size_t buffer_offset; + // Total length of the operation. + iree_device_size_t length; + + // Sticky error status; when any worker fails this will be set to non-OK and + // when all workers end the staging buffer will be deallocated and the signal + // semaphores will be marked as failing. + iree_status_t error_status; + // Original user semaphores to signal at the end of the transfer operation. + // Contents are stored at the end of the struct. + iree_hal_semaphore_list_t signal_semaphore_list; + + // Shared staging buffer; contains storage for all workers. + // We avoid a subspan buffer here to reduce overheads. + iree_hal_buffer_t* staging_buffer; + iree_device_size_t staging_buffer_size; + + // Offset to where the transfer head is in the operation. + // Ranges from 0 at the start and length at the end. + // Workers use this to consume chunks of the operation. + iree_device_size_t transfer_head; + // Total number of chunks remaining in the transfer. 
+ iree_host_size_t remaining_chunks; + + // Total number of workers participating in the operation. + iree_host_size_t worker_count; + // State for each worker in the operation. + // Stored at the end of the struct. + iree_hal_transfer_worker_t* workers; + // One bit per worker indicating whether they are live and ticking the loop. + // A worker exits by not rescheduling itself and clearing this bit. When all + // workers have exited the operation has completed. + // When reading workers exit after enqueuing their final transfer such that + // the final staging buffer dealloca can be asynchronously chained. + // When writing workers exit after flushing their final chunk to the file. + iree_hal_transfer_worker_bitmask_t live_workers; +} iree_hal_transfer_operation_t; + +static void iree_hal_transfer_operation_release( + iree_hal_transfer_operation_t* operation); +static void iree_hal_transfer_operation_destroy( + iree_hal_transfer_operation_t* operation); + +static iree_status_t iree_hal_transfer_operation_create( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_transfer_direction_t direction, iree_hal_file_t* file, + uint64_t file_offset, iree_hal_buffer_t* buffer, + iree_device_size_t buffer_offset, iree_device_size_t length, + iree_hal_file_transfer_options_t options, + iree_hal_transfer_operation_t** out_operation) { + IREE_ASSERT_ARGUMENT(out_operation); + *out_operation = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_t host_allocator = iree_hal_device_host_allocator(device); + + // Determine how many workers are required and their staging reservation. 
+ iree_device_size_t worker_chunk_size = options.chunk_size; + if (worker_chunk_size == IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT) { + worker_chunk_size = iree_min(IREE_HAL_TRANSFER_CHUNK_SIZE, length); + } + iree_device_size_t total_chunk_count = + iree_device_size_ceil_div(length, worker_chunk_size); + iree_host_size_t worker_count = options.chunk_count; + if (worker_count == IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT) { + // Try to give each worker a couple chunks. + worker_count = (iree_host_size_t)iree_device_size_ceil_div( + total_chunk_count, IREE_HAL_TRANSFER_CHUNKS_PER_WORKER); + } + worker_count = + iree_min(worker_count, iree_min(IREE_HAL_TRANSFER_WORKER_LIMIT, + IREE_HAL_TRANSFER_WORKER_MAX_COUNT)); + + // Calculate total size of the structure with all its associated data. + iree_hal_transfer_operation_t* operation = NULL; + iree_host_size_t total_size = sizeof(*operation); + iree_host_size_t semaphores_offset = + iree_host_align(total_size, iree_max_align_t); + total_size = semaphores_offset + sizeof(signal_semaphore_list.semaphores[0]) * + signal_semaphore_list.count; + iree_host_size_t payload_values_offset = + iree_host_align(total_size, iree_max_align_t); + total_size = + payload_values_offset + sizeof(signal_semaphore_list.payload_values[0]) * + signal_semaphore_list.count; + iree_host_size_t worker_offset = + iree_host_align(total_size, iree_max_align_t); + total_size = worker_offset + sizeof(operation->workers[0]) * worker_count; + + // Allocate and initialize the struct. 
+ IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, total_size, (void**)&operation)); + iree_atomic_ref_count_init(&operation->ref_count); + operation->device = device; + iree_hal_device_retain(device); + operation->queue_affinity = queue_affinity; + operation->direction = direction; + operation->file = file; + iree_hal_file_retain(file); + operation->file_offset = file_offset; + operation->buffer = buffer; + iree_hal_buffer_retain(buffer); + operation->buffer_offset = buffer_offset; + operation->length = length; + operation->staging_buffer_size = worker_count * worker_chunk_size; + operation->transfer_head = 0; + operation->remaining_chunks = (iree_host_size_t)total_chunk_count; + operation->worker_count = worker_count; + + // Assign all pointers to the struct suffix storage. + // We do this first so that if we have to free the struct we have valid + // pointers. + operation->signal_semaphore_list.count = signal_semaphore_list.count; + operation->signal_semaphore_list.semaphores = + (iree_hal_semaphore_t**)((uintptr_t)operation + semaphores_offset); + operation->signal_semaphore_list.payload_values = + (uint64_t*)((uintptr_t)operation + payload_values_offset); + operation->workers = + (iree_hal_transfer_worker_t*)((uintptr_t)operation + worker_offset); + + // Assign a unique ID we'll use to make it easier to track what individual + // steps are part of this transfer. + IREE_TRACE({ + static iree_atomic_int32_t next_trace_id = IREE_ATOMIC_VAR_INIT(0); + operation->trace_id = iree_atomic_fetch_add_int32( + &next_trace_id, 1, iree_memory_order_seq_cst); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, operation->trace_id); + }); + + // Retain each signal semaphore for ourselves as we don't know if the caller + // will hold them for the lifetime of the operation. 
+  // The lists are copied element-wise rather than with memcpy: an empty list
+  // may carry NULL storage pointers and passing NULL to memcpy is undefined
+  // behavior even when the byte count is zero. Each semaphore is retained as
+  // it is copied.
+  for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
+    iree_hal_semaphore_t* semaphore = signal_semaphore_list.semaphores[i];
+    operation->signal_semaphore_list.semaphores[i] = semaphore;
+    operation->signal_semaphore_list.payload_values[i] =
+        signal_semaphore_list.payload_values[i];
+    iree_hal_semaphore_retain(semaphore);
+  }
+
+  // Initialize all workers.
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < worker_count; ++i) {
+    iree_hal_transfer_worker_t* worker = &operation->workers[i];
+    worker->operation = operation;
+
+    // Assign a unique ID we'll use to make it easier to track what individual
+    // steps are part of this worker. It only needs to be unique within the
+    // operation. The cast matches the int32_t declaration of trace_id.
+    IREE_TRACE(worker->trace_id = (int32_t)i);
+
+    // View into the staging buffer where the worker keeps its memory.
+    worker->staging_buffer_offset = worker_chunk_size * i;
+    worker->staging_buffer_length = worker_chunk_size;
+
+    // Create semaphore for tracking worker progress.
+ worker->pending_timepoint = 0ull; + status = iree_hal_semaphore_create(device, worker->pending_timepoint, + &worker->semaphore); + if (!iree_status_is_ok(status)) break; + } + + if (iree_status_is_ok(status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "worker count: "); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)worker_count); + IREE_TRACE_ZONE_APPEND_TEXT(z0, "worker chunk size: "); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)worker_chunk_size); + *out_operation = operation; + } else { + iree_hal_transfer_operation_release(operation); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_transfer_operation_retain( + iree_hal_transfer_operation_t* operation) { + if (IREE_LIKELY(operation)) { + iree_atomic_ref_count_inc(&operation->ref_count); + } +} + +static void iree_hal_transfer_operation_release( + iree_hal_transfer_operation_t* operation) { + if (IREE_LIKELY(operation) && + iree_atomic_ref_count_dec(&operation->ref_count) == 1) { + iree_hal_transfer_operation_destroy(operation); + } +} + +static void iree_hal_transfer_operation_destroy( + iree_hal_transfer_operation_t* operation) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, operation->trace_id); + iree_allocator_t host_allocator = + iree_hal_device_host_allocator(operation->device); + + // We don't want any pending loop operations when freeing as the loop event + // handlers will try to access the memory. 
+  IREE_ASSERT(operation->live_workers == 0, "all workers must have exited");
+
+  for (iree_host_size_t i = 0; i < operation->worker_count; ++i) {
+    iree_hal_semaphore_release(operation->workers[i].semaphore);
+  }
+  iree_hal_buffer_release(operation->staging_buffer);
+  for (iree_host_size_t i = 0; i < operation->signal_semaphore_list.count;
+       ++i) {
+    iree_hal_semaphore_release(operation->signal_semaphore_list.semaphores[i]);
+  }
+  iree_hal_buffer_release(operation->buffer);
+  iree_hal_file_release(operation->file);
+  iree_hal_device_release(operation->device);
+  iree_status_ignore(operation->error_status);
+
+  iree_allocator_free(host_allocator, operation);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Notifies listeners that the operation has completed and queues the release
+// of the staging buffer. If this was a read then the staging buffer dealloca
+// will be chained to the last asynchronous copies. In writes the last flush to
+// the file happened synchronously so the dealloca happens immediately.
+//
+// Pre-condition: all workers have exited (device copies may be in flight).
+// Post-condition: dealloca queued; the caller still releases the operation.
+static void iree_hal_transfer_operation_notify_completion(
+    iree_hal_transfer_operation_t* operation) {
+  IREE_ASSERT_ARGUMENT(operation);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, operation->trace_id);
+
+  // Completion may only be signaled once no worker is live on the loop.
+  IREE_ASSERT(operation->live_workers == 0, "no workers can be live");
+
+  // Deallocating the staging buffer can only happen after all workers have
+  // completed copies into/out-of it. In reads it's expected there are copies
+  // in-flight and we can wait on all worker semaphores. In writes the last
+  // flush to the file happened synchronously so we don't need to wait at all.
+ iree_hal_semaphore_list_t wait_semaphore_list = + iree_hal_semaphore_list_empty(); + if (operation->direction == IREE_HAL_TRANSFER_READ_FILE_TO_BUFFER) { + wait_semaphore_list.count = operation->worker_count; + wait_semaphore_list.semaphores = (iree_hal_semaphore_t**)iree_alloca( + wait_semaphore_list.count * sizeof(wait_semaphore_list.semaphores[0])); + wait_semaphore_list.payload_values = + (uint64_t*)iree_alloca(wait_semaphore_list.count * + sizeof(wait_semaphore_list.payload_values[0])); + for (iree_host_size_t i = 0; i < operation->worker_count; ++i) { + iree_hal_transfer_worker_t* worker = &operation->workers[i]; + wait_semaphore_list.semaphores[i] = worker->semaphore; + wait_semaphore_list.payload_values[i] = worker->pending_timepoint; + } + } + + // When the dealloca completes signal the original semaphores passed in to the + // operation. If the transfer failed then we need to signal them all to + // failure after the dealloca so we use the same semaphores but set the + // failure payload. + iree_hal_semaphore_list_t signal_semaphore_list = + operation->signal_semaphore_list; + if (!iree_status_is_ok(operation->error_status)) { + for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { + signal_semaphore_list.payload_values[i] = + IREE_HAL_SEMAPHORE_FAILURE_VALUE; + } + } + + // Schedule staging buffer deallocation. + // Note that we need to do this even if the operation failed and we want it to + // be scheduled after any copies that may be in-flight (say worker 4 failed + // on a chunk but workers 0-3 succeeded). + iree_status_t status = iree_hal_device_queue_dealloca( + operation->device, operation->queue_affinity, wait_semaphore_list, + signal_semaphore_list, operation->staging_buffer); + + // If the dealloca failed we don't have a great way of letting anyone know. + // We'll just drop it on the floor for now and let the buffer be freed by + // reference counting. 
+  iree_status_ignore(status);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Exits the |worker| indicating it will process no more chunks.
+// If this is the last worker to exit the transfer is considered completed.
+// The first non-OK |status| provided will be set as a sticky error on the
+// |operation| and all workers will check it and exit themselves asynchronously.
+//
+// NOTE: this may end the entire operation and free the operation (and worker)
+// memory. Callers must not touch either |operation| or |worker| after calling
+// this method.
+static iree_status_t iree_hal_transfer_worker_exit(
+    iree_hal_transfer_operation_t* operation,
+    iree_hal_transfer_worker_t* worker, iree_status_t status) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)operation->trace_id);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)worker->trace_id);
+
+  // Check if this is the first failure of an operation and record it as the
+  // sticky error. Otherwise we ignore the error here as it's probably just
+  // telling us the worker has aborted.
+  if (iree_status_is_ok(operation->error_status) &&
+      !iree_status_is_ok(status)) {
+    operation->error_status = status;
+  } else {
+    iree_status_ignore(status);
+  }
+
+  // Clear the worker live bit and see if there are any more workers live. So
+  // long as there is at least one we need to keep the operation running.
+  iree_host_size_t worker_index =
+      (iree_host_size_t)(worker - operation->workers);
+  operation->live_workers &= ~(1ull << worker_index);
+  if (operation->live_workers > 0) {
+    // Other workers are still live - this is just one worker exiting by not
+    // rescheduling itself.
+    iree_hal_transfer_operation_release(operation);
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+
+  // This was the last worker; the operation has completed!
+ iree_hal_transfer_operation_notify_completion(operation); + + // Free the operation - the dealloca (and maybe even some copies) may still be + // in-flight but all on the device side and not using any resources on the + // operation. + iree_hal_transfer_operation_release(operation); + // NOTE: at this point the worker may have freed itself and its memory can + // no longer be used! + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); // always ok; just for convenience. +} + +static iree_status_t iree_hal_transfer_worker_copy_file_to_buffer( + void* user_data, iree_loop_t loop, iree_status_t status) { + iree_hal_transfer_worker_t* worker = (iree_hal_transfer_worker_t*)user_data; + iree_hal_transfer_operation_t* operation = worker->operation; + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)operation->trace_id); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)worker->trace_id); + + // Bail immediately if the operation has failed. + if (!iree_status_is_ok(status) || + !iree_status_is_ok(operation->error_status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "bail: loop error"); + IREE_TRACE_ZONE_END(z0); + return iree_hal_transfer_worker_exit(operation, worker, status); + } + + // Early-exit if we're out of chunks to process. + // This can happen with some loop implementations that run things in batches. + if (operation->remaining_chunks == 0) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "exit: no remaining chunks"); + IREE_TRACE_ZONE_END(z0); + return iree_hal_transfer_worker_exit(operation, worker, iree_ok_status()); + } + + // Grab a piece of the transfer to operate on. 
+ --operation->remaining_chunks; + iree_device_size_t transfer_offset = operation->transfer_head; + iree_device_size_t transfer_length = iree_min( + operation->length - transfer_offset, worker->staging_buffer_length); + IREE_ASSERT(transfer_length > 0, + "should not have ticked if there was no work to do"); + operation->transfer_head += transfer_length; + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)transfer_offset); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)transfer_length); + + // Timeline increments by one. + uint64_t wait_timepoint = worker->pending_timepoint; + iree_hal_semaphore_list_t wait_semaphore_list = { + .count = 1, + .semaphores = &worker->semaphore, + .payload_values = &wait_timepoint, + }; + uint64_t signal_timepoint = ++worker->pending_timepoint; + iree_hal_semaphore_list_t signal_semaphore_list = { + .count = 1, + .semaphores = &worker->semaphore, + .payload_values = &signal_timepoint, + }; + + // Track the pending copy operation so we know where to place it in the + // buffer. + worker->pending_transfer_offset = transfer_offset; + worker->pending_transfer_length = transfer_length; + + // Synchronously copy the contents from the file to the staging buffer. + status = iree_hal_file_read( + operation->file, operation->file_offset + worker->pending_transfer_offset, + operation->staging_buffer, worker->staging_buffer_offset, + worker->pending_transfer_length); + + // Issue asynchronous copy from the staging buffer into the target buffer. + if (iree_status_is_ok(status)) { + status = iree_hal_device_queue_copy( + operation->device, operation->queue_affinity, wait_semaphore_list, + signal_semaphore_list, operation->staging_buffer, + worker->staging_buffer_offset, operation->buffer, + operation->buffer_offset + transfer_offset, transfer_length); + } + + // Wait for the copy to complete and tick again if we expect there to be more + // work. 
If there are no more chunks to copy (or they are spoken for by other + // live workers) we can avoid the loop wait and exit such that the dealloca + // can chain on to the copy operations. + if (iree_status_is_ok(status)) { + if (iree_hal_transfer_worker_live_count(operation->live_workers) > + operation->remaining_chunks) { + // Enough workers to cover all remaining chunks so we can exit now and + // avoid an additional host wake (+ latency) by the loop event. + IREE_TRACE_ZONE_APPEND_TEXT(z0, + "exit: remaining chunks covered by workers"); + status = iree_hal_transfer_worker_exit(operation, worker, status); + } else { + status = iree_loop_wait_one( + loop, + iree_hal_semaphore_await(worker->semaphore, + worker->pending_timepoint), + iree_infinite_timeout(), iree_hal_transfer_worker_copy_file_to_buffer, + worker); + } + } + + if (!iree_status_is_ok(status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "bail: copy/wait failure"); + status = iree_hal_transfer_worker_exit(operation, worker, status); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Begins the transfer operation after |wait_semaphore_list| is satisfied. +// Note that if this fails then the transfer never started and it's safe to +// immediately tear down. +static iree_status_t iree_hal_transfer_operation_launch_read( + iree_hal_transfer_operation_t* operation, + iree_hal_semaphore_list_t wait_semaphore_list, iree_loop_t loop) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)operation->trace_id); + IREE_ASSERT(operation->direction == IREE_HAL_TRANSFER_READ_FILE_TO_BUFFER); + + // Staging buffers get allocated based on the direction we are transferring. + // This optimizes for access patterns such as sequential writes from the host + // when staging into the buffer and sequential cached reads from the host when + // staging out of the buffer. 
+ iree_hal_buffer_params_t staging_buffer_params = { + .access = IREE_HAL_MEMORY_ACCESS_ALL, + // TODO(benvanik): make staging alignment an option/device query? + .min_alignment = 64, + .queue_affinity = operation->queue_affinity, + .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED | + IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE, + }; + + // Queue the staging buffer allocation. + // When it completes we'll do the first host->device copy via mapping. + iree_hal_semaphore_list_t alloca_semaphore_list = { + .count = operation->worker_count, + .semaphores = + iree_alloca(sizeof(iree_hal_semaphore_t*) * operation->worker_count), + .payload_values = iree_alloca(sizeof(uint64_t) * operation->worker_count), + }; + for (iree_host_size_t i = 0; i < operation->worker_count; ++i) { + iree_hal_transfer_worker_t* worker = &operation->workers[i]; + alloca_semaphore_list.semaphores[i] = worker->semaphore; + alloca_semaphore_list.payload_values[i] = ++worker->pending_timepoint; + } + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_device_queue_alloca( + operation->device, operation->queue_affinity, wait_semaphore_list, + alloca_semaphore_list, IREE_HAL_ALLOCATOR_POOL_DEFAULT, + staging_buffer_params, operation->staging_buffer_size, + &operation->staging_buffer)); + + // After the alloca completes each worker will be at the same starting point. + // We'll wait on each and start the worker-specific coroutines. 
+ iree_status_t status = iree_ok_status(); + for (iree_host_size_t worker_index = 0; + worker_index < operation->worker_count; ++worker_index) { + iree_hal_transfer_worker_t* worker = &operation->workers[worker_index]; + operation->live_workers |= 1ull << worker_index; + iree_hal_transfer_operation_retain(operation); + status = iree_loop_wait_one( + loop, + iree_hal_semaphore_await(worker->semaphore, worker->pending_timepoint), + iree_infinite_timeout(), iree_hal_transfer_worker_copy_file_to_buffer, + worker); + if (!iree_status_is_ok(status)) { + operation->live_workers &= ~(1ull << worker_index); + iree_hal_transfer_operation_release(operation); + break; + } + + // It's possible that the entire operation completed inline. + if (operation->remaining_chunks == 0) break; + } + if (!iree_status_is_ok(status)) { + // Failed to wait on one of the workers. This is a fatal error but we may + // have already waited on some workers and need to instead set the sticky + // error flag so that when any complete they stop processing. + operation->error_status = status; + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); // return ok as loop is fine but operation is not +} + +static iree_status_t iree_hal_transfer_worker_copy_staging_to_file( + void* user_data, iree_loop_t loop, iree_status_t status); + +static iree_status_t iree_hal_transfer_worker_copy_buffer_to_staging( + iree_hal_transfer_operation_t* operation, + iree_hal_transfer_worker_t* worker, iree_loop_t loop) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)operation->trace_id); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)worker->trace_id); + + // If there's been an error we bail. + if (!iree_status_is_ok(operation->error_status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "bail: error bit set"); + IREE_TRACE_ZONE_END(z0); + return iree_hal_transfer_worker_exit(operation, worker, iree_ok_status()); + } + + // Grab a piece of the transfer to operate on. 
+  IREE_ASSERT(operation->remaining_chunks > 0,
+              "should not have ticked if there was no work to do");
+  --operation->remaining_chunks;
+  iree_device_size_t transfer_offset = operation->transfer_head;
+  iree_device_size_t transfer_length = iree_min(
+      operation->length - transfer_offset, worker->staging_buffer_length);
+  IREE_ASSERT(transfer_length > 0,
+              "should not have ticked if there was no work to do");
+  operation->transfer_head += transfer_length;
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)transfer_offset);
+  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)transfer_length);
+
+  // Timeline increments by one; snapshot the wait value before the increment.
+  uint64_t wait_timepoint = worker->pending_timepoint++;
+  iree_hal_semaphore_list_t wait_semaphore_list = {
+      .count = 1,
+      .semaphores = &worker->semaphore,
+      .payload_values = &wait_timepoint,
+  };
+  iree_hal_semaphore_list_t signal_semaphore_list = {
+      .count = 1,
+      .semaphores = &worker->semaphore,
+      .payload_values = &worker->pending_timepoint,
+  };
+
+  // Track the pending copy operation so we know where to place it in the file.
+  worker->pending_transfer_offset = transfer_offset;
+  worker->pending_transfer_length = transfer_length;
+
+  // Issue an asynchronous copy from the source buffer to the staging buffer.
+  iree_status_t status = iree_hal_device_queue_copy(
+      operation->device, operation->queue_affinity, wait_semaphore_list,
+      signal_semaphore_list, operation->buffer,
+      operation->buffer_offset + transfer_offset, operation->staging_buffer,
+      worker->staging_buffer_offset, transfer_length);
+
+  // Wait for the copy to complete so we can write it to the file.
+ if (iree_status_is_ok(status)) { + status = iree_loop_wait_one( + loop, + iree_hal_semaphore_await(worker->semaphore, worker->pending_timepoint), + iree_infinite_timeout(), iree_hal_transfer_worker_copy_staging_to_file, + worker); + } + + if (!iree_status_is_ok(status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "bail: copy/wait failure"); + status = iree_hal_transfer_worker_exit(operation, worker, status); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_transfer_worker_copy_staging_to_file( + void* user_data, iree_loop_t loop, iree_status_t status) { + iree_hal_transfer_worker_t* worker = (iree_hal_transfer_worker_t*)user_data; + iree_hal_transfer_operation_t* operation = worker->operation; + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)operation->trace_id); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)worker->trace_id); + + // Bail immediately if the operation has failed. + if (!iree_status_is_ok(status) || + !iree_status_is_ok(operation->error_status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "bail: loop error"); + IREE_TRACE_ZONE_END(z0); + return iree_hal_transfer_worker_exit(operation, worker, status); + } + + // Synchronously copy the contents from the staging buffer to the file. 
+ status = iree_hal_file_write( + operation->file, operation->file_offset + worker->pending_transfer_offset, + operation->staging_buffer, worker->staging_buffer_offset, + worker->pending_transfer_length); + + if (iree_status_is_ok(status) && operation->remaining_chunks == 0) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "exit: no more chunks remaining to write"); + IREE_TRACE_ZONE_END(z0); + return iree_hal_transfer_worker_exit(operation, worker, iree_ok_status()); + } + + if (!iree_status_is_ok(status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "bail: file write error"); + IREE_TRACE_ZONE_END(z0); + return iree_hal_transfer_worker_exit(operation, worker, status); + } + + IREE_TRACE_ZONE_END(z0); + + // Tail call: tick the worker so that it transfers another chunk. + return iree_hal_transfer_worker_copy_buffer_to_staging(operation, worker, + loop); +} + +// Begins the transfer operation after |wait_semaphore_list| is satisfied. +// Note that if this fails then the transfer never started and it's safe to +// immediately tear down. +static iree_status_t iree_hal_transfer_operation_launch_write( + iree_hal_transfer_operation_t* operation, + iree_hal_semaphore_list_t wait_semaphore_list, iree_loop_t loop) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)operation->trace_id); + IREE_ASSERT(operation->direction == IREE_HAL_TRANSFER_WRITE_BUFFER_TO_FILE); + + // Staging buffers get allocated based on the direction we are transferring. + // This optimizes for access patterns such as sequential writes from the host + // when staging into the buffer and sequential cached reads from the host when + // staging out of the buffer. + iree_hal_buffer_params_t staging_buffer_params = { + .access = IREE_HAL_MEMORY_ACCESS_ALL, + // TODO(benvanik): make staging alignment an option/device query? 
+ .min_alignment = 64, + .queue_affinity = operation->queue_affinity, + .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST | + IREE_HAL_MEMORY_TYPE_HOST_CACHED | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED | + IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_RANDOM, + }; + + // Queue the staging buffer allocation. + // When it completes we'll signal each worker to start its first transfer. + iree_hal_semaphore_list_t alloca_semaphore_list = { + .count = operation->worker_count, + .semaphores = + iree_alloca(sizeof(iree_hal_semaphore_t*) * operation->worker_count), + .payload_values = iree_alloca(sizeof(uint64_t) * operation->worker_count), + }; + for (iree_host_size_t i = 0; i < operation->worker_count; ++i) { + iree_hal_transfer_worker_t* worker = &operation->workers[i]; + alloca_semaphore_list.semaphores[i] = worker->semaphore; + alloca_semaphore_list.payload_values[i] = ++worker->pending_timepoint; + } + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_device_queue_alloca( + operation->device, operation->queue_affinity, wait_semaphore_list, + alloca_semaphore_list, IREE_HAL_ALLOCATOR_POOL_DEFAULT, + staging_buffer_params, operation->staging_buffer_size, + &operation->staging_buffer)); + + // After the alloca completes each worker will be at the same starting point. + // We'll wait on each and start the worker-specific coroutines. + iree_status_t status = iree_ok_status(); + for (iree_host_size_t worker_index = 0; + worker_index < operation->worker_count; ++worker_index) { + iree_hal_transfer_worker_t* worker = &operation->workers[worker_index]; + operation->live_workers |= 1ull << worker_index; + iree_hal_transfer_operation_retain(operation); + + // Issue the initial asynchronous copy from the source buffer to the worker + // chunk. This will wait for the alloca to complete so that the staging + // buffer is available for use. 
After the copy completes the worker will + // tick itself so long as there are chunks remaining to write. + status = iree_hal_transfer_worker_copy_buffer_to_staging(operation, worker, + loop); + if (!iree_status_is_ok(status)) break; + + // It's possible that the entire operation completed inline. + if (operation->remaining_chunks == 0) break; + } + if (!iree_status_is_ok(status)) { + // Failed to wait on one of the workers. This is a fatal error but we may + // have already waited on some workers and need to instead set the sticky + // error flag so that when any complete they stop processing. + operation->error_status = status; + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); // return ok as loop is fine but operation is not +} + +//===----------------------------------------------------------------------===// +// Memory file IO API +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_file_validate_access( + iree_hal_file_t* file, iree_hal_memory_access_t required_access) { + const iree_hal_memory_access_t allowed_access = + iree_hal_file_allowed_access(file); + if (IREE_LIKELY(iree_all_bits_set(allowed_access, required_access))) { + return iree_ok_status(); + } +#if IREE_STATUS_MODE + iree_bitfield_string_temp_t temp0, temp1; + iree_string_view_t allowed_access_str = + iree_hal_memory_access_format(allowed_access, &temp0); + iree_string_view_t required_access_str = + iree_hal_memory_access_format(required_access, &temp1); + return iree_make_status( + IREE_STATUS_PERMISSION_DENIED, + "file operation cannot be performed; file allows %.*s, operation " + "requires %.*s", + (int)allowed_access_str.size, allowed_access_str.data, + (int)required_access_str.size, required_access_str.data); +#else + return iree_make_status(IREE_STATUS_PERMISSION_DENIED); +#endif // IREE_STATUS_MODE +} + +IREE_API_EXPORT iree_status_t iree_hal_device_queue_read_streaming( + iree_hal_device_t* device, 
iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags, + iree_hal_file_transfer_options_t options) { + IREE_RETURN_IF_ERROR( + iree_hal_file_validate_access(source_file, IREE_HAL_MEMORY_ACCESS_READ)); + + // If the file implicitly supports device transfer then we can simply issue a + // device copy. + iree_hal_buffer_t* storage_buffer = iree_hal_file_storage_buffer(source_file); + if (storage_buffer) { + return iree_hal_device_queue_copy( + device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + storage_buffer, (iree_device_size_t)source_offset, target_buffer, + target_offset, length); + } + + // Allocate full transfer operation. + iree_hal_transfer_operation_t* operation = NULL; + IREE_RETURN_IF_ERROR(iree_hal_transfer_operation_create( + device, queue_affinity, signal_semaphore_list, + IREE_HAL_TRANSFER_READ_FILE_TO_BUFFER, source_file, source_offset, + target_buffer, target_offset, length, options, &operation)); + + // Kick off the streaming transfer. + // This will queue allocation of the staging buffer and then issue one or more + // copy commands. The operation will manage its own lifetime and emit errors + // as part of signal semaphore failures. 
+ iree_status_t status = iree_hal_transfer_operation_launch_read( + operation, wait_semaphore_list, options.loop); + + iree_hal_transfer_operation_release(operation); + + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_device_queue_write_streaming( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags, + iree_hal_file_transfer_options_t options) { + // EXPERIMENTAL: assume memory files only today (as that's all we have). + IREE_RETURN_IF_ERROR( + iree_hal_file_validate_access(target_file, IREE_HAL_MEMORY_ACCESS_WRITE)); + + // If the file implicitly supports device transfer then we can simply issue a + // device copy. + iree_hal_buffer_t* storage_buffer = iree_hal_file_storage_buffer(target_file); + if (storage_buffer) { + return iree_hal_device_queue_copy( + device, queue_affinity, wait_semaphore_list, signal_semaphore_list, + source_buffer, source_offset, storage_buffer, + (iree_device_size_t)target_offset, length); + } + + // Allocate full transfer operation. + iree_hal_transfer_operation_t* operation = NULL; + IREE_RETURN_IF_ERROR(iree_hal_transfer_operation_create( + device, queue_affinity, signal_semaphore_list, + IREE_HAL_TRANSFER_WRITE_BUFFER_TO_FILE, target_file, target_offset, + source_buffer, source_offset, length, options, &operation)); + + // Kick off the streaming transfer. + // This will queue allocation of the staging buffer and then issue one or more + // copy commands. The operation will manage its own lifetime and emit errors + // as part of signal semaphore failures. 
+ iree_status_t status = iree_hal_transfer_operation_launch_write( + operation, wait_semaphore_list, options.loop); + + iree_hal_transfer_operation_release(operation); + + return status; +} diff --git a/runtime/src/iree/hal/utils/file_transfer.h b/runtime/src/iree/hal/utils/file_transfer.h new file mode 100644 index 000000000000..cf8509913762 --- /dev/null +++ b/runtime/src/iree/hal/utils/file_transfer.h @@ -0,0 +1,93 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_UTILS_FILE_TRANSFER_H_ +#define IREE_HAL_UTILS_FILE_TRANSFER_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Generic file transfer IO implementation +//===----------------------------------------------------------------------===// + +#define IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT 0 +#define IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT 0 + +// Options for file-based transfer operations. +typedef struct iree_hal_file_transfer_options_t { + // Loop to use for asynchronous host operations. If inline then the transfer + // will run synchronously with the caller. + iree_loop_t loop; + // Total number of staging buffer chunks to allocate. + // Setting to >1 will allow for overlapped staging and transfer at the cost + // of additional staging buffer memory consumption. + // IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT can be used to have the + // implementation select a chunk count based on whether the device can benefit + // from overlapping staging. + iree_device_size_t chunk_count; + // Maximum size of chunks in bytes. The size may be adjusted to meet alignment + // requirements of the implementation.
+ // IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT can be used to have the + // implementation select a chunk size based on the size of the transfer. + iree_device_size_t chunk_size; +} iree_hal_file_transfer_options_t; + +// EXPERIMENTAL: eventually we'll focus this only on emulating support where +// otherwise unavailable. For now no HAL targets support files and all use this. +// +// Performs a streaming read of |source_file| into |target_buffer| using +// host-based staging buffers. This implementation may require staging buffers +// in which case |options.chunk_size| specifies the maximum size in bytes of +// each chunk and |options.chunk_count| specifies how many chunks will be +// allocated at once. +// +// The provided |options.loop| is used for any asynchronous host operations +// performed as part of the transfer. +// +// WARNING: this only works with memory files as created via +// iree_hal_memory_file_wrap. +IREE_API_EXPORT iree_status_t iree_hal_device_queue_read_streaming( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_file_t* source_file, uint64_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, uint32_t flags, + iree_hal_file_transfer_options_t options); + +// EXPERIMENTAL: eventually we'll focus this only on emulating support where +// otherwise unavailable. For now no HAL targets support files and all use this. +// +// Performs a streaming write of |source_buffer| into |target_file| using +// host-based staging buffers. This implementation may require staging buffers +// in which case |options.chunk_size| specifies the maximum size in bytes of +// each chunk and |options.chunk_count| specifies how many chunks will be +// allocated at once. 
+// +// The provided |options.loop| is used for any asynchronous host operations +// performed as part of the transfer. +// +// WARNING: this only works with memory files as created via +// iree_hal_memory_file_wrap. +IREE_API_EXPORT iree_status_t iree_hal_device_queue_write_streaming( + iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_file_t* target_file, uint64_t target_offset, + iree_device_size_t length, uint32_t flags, + iree_hal_file_transfer_options_t options); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_UTILS_FILE_TRANSFER_H_ diff --git a/runtime/src/iree/hal/utils/memory_file.c b/runtime/src/iree/hal/utils/memory_file.c new file mode 100644 index 000000000000..fd6b7d2e98a2 --- /dev/null +++ b/runtime/src/iree/hal/utils/memory_file.c @@ -0,0 +1,354 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/utils/memory_file.h" + +//===----------------------------------------------------------------------===// +// Configuration +//===----------------------------------------------------------------------===// + +// TODO(benvanik): make these either compile-time configuration options so we +// can prune code paths or flags (somehow). + +// When 1 a fast-path for importable memory will be used to avoid staging. 
+#if !defined(IREE_HAL_MEMORY_FILE_CAN_IMPORT) +#define IREE_HAL_MEMORY_FILE_CAN_IMPORT 1 +#endif // !IREE_HAL_MEMORY_FILE_CAN_IMPORT + +//===----------------------------------------------------------------------===// +// iree_hal_memory_file_storage_t +//===----------------------------------------------------------------------===// + +// Reference-counted storage for memory file contents. +// This allows both the memory file and any intermediate/staging buffers that +// may reference it to keep the underlying storage live and not create cycles. +typedef struct iree_hal_memory_file_storage_t { + // Reference count for this storage instance. + iree_atomic_ref_count_t ref_count; + // Used to allocate this structure. + iree_allocator_t host_allocator; + // Host memory contents, unowned. + iree_byte_span_t contents; + // Called on destruction to allow for creators to manage lifetime. + iree_hal_file_release_callback_t release_callback; +} iree_hal_memory_file_storage_t; + +static iree_status_t iree_hal_memory_file_storage_create( + iree_byte_span_t contents, + iree_hal_file_release_callback_t release_callback, + iree_allocator_t host_allocator, + iree_hal_memory_file_storage_t** out_storage) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_memory_file_storage_t* storage = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*storage), + (void**)&storage)); + iree_atomic_ref_count_init(&storage->ref_count); + storage->host_allocator = host_allocator; + storage->contents = contents; + storage->release_callback = release_callback; + + *out_storage = storage; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_memory_file_storage_destroy( + iree_hal_memory_file_storage_t* storage) { + IREE_ASSERT_ARGUMENT(storage); + IREE_TRACE_ZONE_BEGIN(z0); + iree_allocator_t host_allocator = storage->host_allocator; + + if (storage->release_callback.fn) { + 
storage->release_callback.fn(storage->release_callback.user_data); + } + + iree_allocator_free(host_allocator, storage); + + IREE_TRACE_ZONE_END(z0); +} + +static void iree_hal_memory_file_storage_retain( + iree_hal_memory_file_storage_t* storage) { + if (IREE_LIKELY(storage)) { + iree_atomic_ref_count_inc(&storage->ref_count); + } +} + +static void iree_hal_memory_file_storage_release( + iree_hal_memory_file_storage_t* storage) { + if (IREE_LIKELY(storage) && + iree_atomic_ref_count_dec(&storage->ref_count) == 1) { + iree_hal_memory_file_storage_destroy(storage); + } +} + +//===----------------------------------------------------------------------===// +// iree_hal_memory_file_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_memory_file_t { + iree_hal_resource_t resource; + // Used to allocate this structure. + iree_allocator_t host_allocator; + // Allowed access bits. + iree_hal_memory_access_t access; + // Underlying storage container, retained. + iree_hal_memory_file_storage_t* storage; + // Optional imported buffer if it was possible to do so. + // Not all implementations and not all buffers can be imported. 
+ iree_hal_buffer_t* imported_buffer; +} iree_hal_memory_file_t; + +static const iree_hal_file_vtable_t iree_hal_memory_file_vtable; + +static iree_hal_memory_file_t* iree_hal_memory_file_cast( + iree_hal_file_t* IREE_RESTRICT base_value) { + return (iree_hal_memory_file_t*)base_value; +} + +static void iree_hal_memory_file_try_import_buffer( + iree_hal_memory_file_t* file, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, iree_byte_span_t contents, + iree_hal_allocator_t* device_allocator); + +IREE_API_EXPORT iree_status_t iree_hal_memory_file_wrap( + iree_hal_queue_affinity_t queue_affinity, iree_hal_memory_access_t access, + iree_byte_span_t contents, + iree_hal_file_release_callback_t release_callback, + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator, + iree_hal_file_t** out_file) { + IREE_ASSERT_ARGUMENT(out_file); + *out_file = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)contents.data_length); + + // Note that iree_device_size_t (for device offsets/sizes) may be smaller than + // iree_host_size_t (for host offsets/sizes) - if so we need to ensure the + // bytes passed in will still fit in iree_device_size_t. + if (contents.data_length > IREE_DEVICE_SIZE_MAX) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "device size too small to represent host contents"); + } + + // Allocate file handle; this just holds a reference to the storage and + // (optionally) the imported buffer. + iree_hal_memory_file_t* file = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*file), (void**)&file)); + iree_hal_resource_initialize(&iree_hal_memory_file_vtable, &file->resource); + file->host_allocator = host_allocator; + file->access = access; + + // Create the underlying storage container that we use to manage the storage + // lifetime independently from the file lifetime. 
+ iree_status_t status = iree_hal_memory_file_storage_create( + contents, release_callback, host_allocator, &file->storage); + +#if !IREE_HAL_MEMORY_FILE_CAN_IMPORT + // Importing disabled; useful for testing the slow path. + device_allocator = NULL; +#endif // !IREE_HAL_MEMORY_FILE_CAN_IMPORT + + // Try importing the buffer as a host-local staging buffer. + // This won't always succeed due to device, platform, HAL implementation, or + // buffer limitations but if it does we can avoid staging ourselves during + // streaming and directly read/write the memory via transfer commands. + if (iree_status_is_ok(status) && device_allocator) { + iree_hal_memory_file_try_import_buffer(file, queue_affinity, access, + contents, device_allocator); + } + + if (iree_status_is_ok(status)) { + *out_file = (iree_hal_file_t*)file; + } else { + iree_hal_file_release((iree_hal_file_t*)file); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_memory_file_destroy( + iree_hal_file_t* IREE_RESTRICT base_file) { + iree_hal_memory_file_t* file = iree_hal_memory_file_cast(base_file); + iree_allocator_t host_allocator = file->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + if (file->imported_buffer) { + iree_hal_buffer_release(file->imported_buffer); + file->imported_buffer = NULL; + } + + iree_hal_memory_file_storage_release(file->storage); + + iree_allocator_free(host_allocator, file); + + IREE_TRACE_ZONE_END(z0); +} + +// Releases the underlying file storage after the buffer using it is released. +static void iree_hal_memory_file_buffer_release(void* user_data, + iree_hal_buffer_t* buffer) { + iree_hal_memory_file_storage_release( + (iree_hal_memory_file_storage_t*)user_data); +} + +// Tries to import |contents| as a device-accessible HAL buffer. +// If this succeeds we can fast-path copies without needing to allocate any +// staging buffers and directly make use of DMA resources. If it fails we fall +// back to staging from host memory ourselves.
+static void iree_hal_memory_file_try_import_buffer( + iree_hal_memory_file_t* file, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, iree_byte_span_t contents, + iree_hal_allocator_t* device_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_buffer_params_t staging_buffer_params = { + .access = access, + .queue_affinity = queue_affinity, + .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .usage = IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED | + IREE_HAL_BUFFER_USAGE_MAPPING_ACCESS_SEQUENTIAL_WRITE | + (iree_any_bit_set(access, IREE_HAL_MEMORY_ACCESS_READ) + ? IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE + : 0) | + (iree_any_bit_set(access, IREE_HAL_MEMORY_ACCESS_WRITE) + ? IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET + : 0), + }; + + iree_hal_external_buffer_t external_buffer = { + .type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION, + .flags = 0, + .size = (iree_device_size_t)contents.data_length, + .handle = + { + .host_allocation = + { + .ptr = contents.data, + }, + }, + }; + + // NOTE: we make the buffer retain the underlying storage. + // We have to handle the case where the import fails and we need to balance + // the retain we did below. 
+ iree_hal_buffer_release_callback_t imported_release_callback = { + .fn = iree_hal_memory_file_buffer_release, + .user_data = file->storage, + }; + iree_hal_memory_file_storage_retain(file->storage); + iree_status_t status = iree_hal_allocator_import_buffer( + device_allocator, staging_buffer_params, &external_buffer, + imported_release_callback, &file->imported_buffer); + if (!iree_status_is_ok(status)) { + iree_hal_memory_file_storage_release(file->storage); + } + + IREE_TRACE({ + if (iree_status_is_ok(status)) { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "import success"); + } else { + IREE_TRACE_ZONE_APPEND_TEXT(z0, "import failure"); + IREE_TRACE_ZONE_APPEND_TEXT( + z0, iree_status_code_string(iree_status_code(status))); + } + }); + + IREE_TRACE_ZONE_END(z0); + iree_status_ignore(status); +} + +static const iree_hal_file_vtable_t iree_hal_memory_file_vtable = { + .destroy = iree_hal_memory_file_destroy, +}; + +//===----------------------------------------------------------------------===// +// EXPERIMENTAL: synchronous file read/write API +//===----------------------------------------------------------------------===// +// This is incomplete and may not appear like this on the iree_hal_file_t +// vtable; this does work for memory files though. + +IREE_API_EXPORT iree_hal_memory_access_t +iree_hal_file_allowed_access(iree_hal_file_t* base_file) { + IREE_ASSERT_ARGUMENT(base_file); + + // EXPERIMENTAL: today only memory files. This should be on the file vtable + // (if supported - not all implementations need to support it). + iree_hal_memory_file_t* file = (iree_hal_memory_file_t*)base_file; + + return file->access; +} + +IREE_API_EXPORT uint64_t iree_hal_file_length(iree_hal_file_t* base_file) { + IREE_ASSERT_ARGUMENT(base_file); + + // EXPERIMENTAL: today only memory files. This should be on the file vtable + // (if supported - not all implementations need to support it). 
+ iree_hal_memory_file_t* file = (iree_hal_memory_file_t*)base_file; + + return file->storage->contents.data_length; +} + +IREE_API_EXPORT iree_hal_buffer_t* iree_hal_file_storage_buffer( + iree_hal_file_t* base_file) { + IREE_ASSERT_ARGUMENT(base_file); + + // EXPERIMENTAL: today only memory files. This should be on the file vtable + // (if supported - not all implementations need to support it). + iree_hal_memory_file_t* file = (iree_hal_memory_file_t*)base_file; + + return file->imported_buffer; +} + +IREE_API_EXPORT iree_status_t iree_hal_file_read( + iree_hal_file_t* base_file, uint64_t file_offset, iree_hal_buffer_t* buffer, + iree_device_size_t buffer_offset, iree_device_size_t length) { + IREE_ASSERT_ARGUMENT(base_file); + IREE_ASSERT_ARGUMENT(buffer); + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, file_offset); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)buffer_offset); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)length); + + // EXPERIMENTAL: today only memory files. This should be on the file vtable + // (if supported - not all implementations need to support it). + iree_hal_memory_file_t* file = (iree_hal_memory_file_t*)base_file; + + // Copy from the file contents to the staging buffer. NOTE(review): |file_offset| and |length| are not checked against file_contents.data_length here; confirm callers bound the range. + iree_byte_span_t file_contents = file->storage->contents; + iree_status_t status = iree_hal_buffer_map_write( + buffer, buffer_offset, file_contents.data + file_offset, length); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_file_write( + iree_hal_file_t* base_file, uint64_t file_offset, iree_hal_buffer_t* buffer, + iree_device_size_t buffer_offset, iree_device_size_t length) { + IREE_ASSERT_ARGUMENT(base_file); + IREE_ASSERT_ARGUMENT(buffer); + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, file_offset); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)buffer_offset); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)length); + + // EXPERIMENTAL: today only memory files.
This should be on the file vtable + // (if supported - not all implementations need to support it). + iree_hal_memory_file_t* file = (iree_hal_memory_file_t*)base_file; + + // Copy from the staging buffer to the file contents. NOTE(review): |file_offset| and |length| are not checked against file_contents.data_length here; confirm callers bound the range. + iree_byte_span_t file_contents = file->storage->contents; + iree_status_t status = iree_hal_buffer_map_read( + buffer, buffer_offset, file_contents.data + file_offset, length); + + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/runtime/src/iree/hal/utils/memory_file.h b/runtime/src/iree/hal/utils/memory_file.h new file mode 100644 index 000000000000..967d31cc5d2d --- /dev/null +++ b/runtime/src/iree/hal/utils/memory_file.h @@ -0,0 +1,73 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_UTILS_MEMORY_FILE_H_ +#define IREE_HAL_UTILS_MEMORY_FILE_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_hal_memory_file_t +//===----------------------------------------------------------------------===// + +// Creates a file handle backed by |contents| without copying the data. +// |release_callback| will be called once the file is destroyed and any buffer imported from |contents| has been released. +// If the memory can be imported into a usable staging buffer |device_allocator| +// will be used to do so.
+IREE_API_EXPORT iree_status_t iree_hal_memory_file_wrap( + iree_hal_queue_affinity_t queue_affinity, iree_hal_memory_access_t access, + iree_byte_span_t contents, + iree_hal_file_release_callback_t release_callback, + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator, + iree_hal_file_t** out_file); + +//===----------------------------------------------------------------------===// +// EXPERIMENTAL: synchronous file read/write API +//===----------------------------------------------------------------------===// +// This is incomplete and may not appear like this on the iree_hal_file_t +// vtable; this does work for memory files though. + +// Returns the memory access allowed to the file. +// This may be more strict than the original file handle backing the resource +// if for example we want to prevent particular users from mutating the file. +IREE_API_EXPORT iree_hal_memory_access_t +iree_hal_file_allowed_access(iree_hal_file_t* file); + +// Returns the total accessible range of the file. +// This may be a portion of the original file backing this handle. +IREE_API_EXPORT uint64_t iree_hal_file_length(iree_hal_file_t* file); + +// Returns an optional device-accessible storage buffer representing the file. +// Available if the implementation is able to perform import/address-space +// mapping/etc such that device-side transfers can directly access the resources +// as if they were a normal device buffer. +IREE_API_EXPORT iree_hal_buffer_t* iree_hal_file_storage_buffer( + iree_hal_file_t* file); + +// TODO(benvanik): truncate/extend? (both can be tricky with async) + +// Synchronously reads a segment of |file| into |buffer|. +// Blocks the caller until completed. Buffers are always host mappable. +IREE_API_EXPORT iree_status_t iree_hal_file_read( + iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer, + iree_device_size_t buffer_offset, iree_device_size_t length); + +// Synchronously writes a segment of |buffer| into |file|. 
+// Blocks the caller until completed. Buffers are always host mappable. +IREE_API_EXPORT iree_status_t iree_hal_file_write( + iree_hal_file_t* file, uint64_t file_offset, iree_hal_buffer_t* buffer, + iree_device_size_t buffer_offset, iree_device_size_t length); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_UTILS_MEMORY_FILE_H_ diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index e06e9227e1ad..c9eae4bf468c 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl @@ -24,9 +24,8 @@ // clang-format off -EXPORT_FN("allocator.allocate", iree_hal_module_allocator_allocate, riiI, r) -EXPORT_FN("allocator.allocate.initialized", iree_hal_module_allocator_allocate_initialized, riirII, r) -EXPORT_FN("allocator.map.byte_buffer", iree_hal_module_allocator_map_byte_buffer, riiirII, r) +EXPORT_FN("allocator.allocate", iree_hal_module_allocator_allocate, rIiiI, r) +EXPORT_FN("allocator.import", iree_hal_module_allocator_import, riIiirII, r) EXPORT_FN("buffer.assert", iree_hal_module_buffer_assert, rrrIii, v) EXPORT_FN("buffer.length", iree_hal_module_buffer_length, r, I) @@ -69,7 +68,10 @@ EXPORT_FN("device.queue.alloca", iree_hal_module_device_queue_alloca, rIrriiiI, EXPORT_FN("device.queue.dealloca", iree_hal_module_device_queue_dealloca, rIrrr, v) EXPORT_FN("device.queue.execute", iree_hal_module_device_queue_execute, rIrrCrD, v) EXPORT_FN("device.queue.flush", iree_hal_module_device_queue_flush, rI, v) +EXPORT_FN("device.queue.read", iree_hal_module_device_queue_read, rIrrrIrIIi, v) +EXPORT_FN("device.queue.write", iree_hal_module_device_queue_write, rIrrrIrIIi, v) +EXPORT_FN("ex.file.from_memory", iree_hal_module_ex_file_from_memory, rIirIIi, r) EXPORT_FN("ex.shared_device", iree_hal_module_ex_shared_device, v, r) EXPORT_FN("executable.create", iree_hal_module_executable_create, rrrrCrD, r) diff --git a/runtime/src/iree/modules/hal/module.c 
b/runtime/src/iree/modules/hal/module.c index d778ef1c1f5c..243eb8c6244b 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -16,8 +16,9 @@ #include "iree/modules/hal/utils/buffer_diagnostics.h" #include "iree/vm/api.h" -#define IREE_HAL_MODULE_VERSION_0_0 0x00000000u -#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_0 +//===----------------------------------------------------------------------===// +// Limits imposed by the module (and not the HAL) +//===----------------------------------------------------------------------===// // Limit the number of bindings we pass down through the HAL. This can be tuned // in the future but right now guards the stack from blowing up during calls. @@ -35,6 +36,9 @@ // Module type definitions //===----------------------------------------------------------------------===// +#define IREE_HAL_MODULE_VERSION_0_1 0x00000001u +#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_1 + typedef struct iree_hal_module_t { iree_allocator_t host_allocator; iree_hal_module_flags_t flags; @@ -127,19 +131,6 @@ static iree_status_t IREE_API_PTR iree_hal_module_notify( } } -//===----------------------------------------------------------------------===// -// Experimental APIs -//===----------------------------------------------------------------------===// -// NOTE: Ex* APIs are experimental and likely to be removed soon. Modules -// using these APIs are not forward compatible. 
- -IREE_VM_ABI_EXPORT(iree_hal_module_ex_shared_device, // - iree_hal_module_state_t, // - v, r) { - rets->r0 = iree_hal_device_retain_ref(state->shared_device); - return iree_ok_status(); -} - //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// @@ -159,88 +150,129 @@ static iree_device_size_t iree_hal_cast_device_size(int64_t value) { } //===----------------------------------------------------------------------===// -// iree_hal_allocator_t +// Experimental APIs //===----------------------------------------------------------------------===// +// NOTE: Ex* APIs are experimental and likely to be removed soon. Modules +// using these APIs are not forward compatible. -IREE_VM_ABI_EXPORT(iree_hal_module_allocator_allocate, // - iree_hal_module_state_t, // - riiI, r) { - iree_hal_allocator_t* allocator = NULL; - IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator)); - iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i1; - iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i2; - iree_device_size_t allocation_size = iree_hal_cast_device_size(args->i3); - - const iree_hal_buffer_params_t params = { - .type = memory_types, - .usage = buffer_usage, - }; - iree_hal_buffer_t* buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( - allocator, params, allocation_size, iree_const_byte_span_empty(), - &buffer)); - rets->r0 = iree_hal_buffer_move_ref(buffer); +IREE_VM_ABI_EXPORT(iree_hal_module_ex_shared_device, // + iree_hal_module_state_t, // + v, r) { + rets->r0 = iree_hal_device_retain_ref(state->shared_device); return iree_ok_status(); } -IREE_VM_ABI_EXPORT(iree_hal_module_allocator_allocate_initialized, // - iree_hal_module_state_t, // - riirII, r) { - iree_hal_allocator_t* allocator = NULL; - IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator)); - 
iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i1; - iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i2; - iree_vm_buffer_t* source = NULL; - IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r3, &source)); - iree_device_size_t offset = iree_hal_cast_device_size(args->i4); - iree_device_size_t length = iree_hal_cast_device_size(args->i5); +static void iree_hal_module_file_buffer_release(void* user_data) { + iree_vm_buffer_t* backing_buffer = (iree_vm_buffer_t*)user_data; + iree_vm_buffer_release(backing_buffer); +} - iree_host_size_t buffer_length = source->data.data_length; - if (length == -1) { - length = buffer_length; +IREE_VM_ABI_EXPORT(iree_hal_module_ex_file_from_memory, // + iree_hal_module_state_t, // + rIirIIi, r) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_memory_access_t access = (iree_hal_memory_access_t)args->i2; + iree_vm_buffer_t* buffer = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r3, &buffer)); + iree_host_size_t offset = iree_hal_cast_host_size(args->i4); + iree_host_size_t length = iree_hal_cast_host_size(args->i5); + uint32_t flags = (uint32_t)args->i6; + + // Only allow read-only access right now while experimental. + // The contents here are almost always from mapped file memory today. 
+ if (iree_any_bit_set(access, ~IREE_HAL_MEMORY_ACCESS_READ)) { + return iree_make_status( + IREE_STATUS_PERMISSION_DENIED, + "only read-only memory can be accessed via a file handle (today)"); } - if (length < 0 || offset < 0 || offset > buffer_length || - offset + length > buffer_length) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "byte range out of bounds (requested %" PRIdsz - "-%" PRIdsz " of available %" PRIhsz ")", - offset, (offset + length - 1), buffer_length); + + // Verify the provided range and get the host pointer. + iree_const_byte_span_t span = iree_const_byte_span_empty(); + IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(buffer, offset, length, 1, &span)); + + // Retain the buffer until the file is destroyed. + iree_hal_file_release_callback_t release_callback = { + .fn = iree_hal_module_file_buffer_release, + .user_data = buffer, + }; + iree_vm_buffer_retain(buffer); + + // Attempt to import the memory as a file. + // Memory files are always supported (even if via emulation) so this should + // always succeed. 
+ iree_hal_external_file_t external_file = { + .type = IREE_HAL_EXTERNAL_FILE_TYPE_HOST_ALLOCATION, + .flags = flags, + .handle = + { + .host_allocation = + iree_make_byte_span((void*)span.data, span.data_length), + }, + }; + iree_hal_file_t* file = NULL; + iree_status_t status = iree_hal_file_import( + device, queue_affinity, access, &external_file, release_callback, &file); + if (!iree_status_is_ok(status)) { + iree_vm_buffer_release(buffer); } + rets->r0 = iree_hal_file_move_ref(file); + return status; +} + +//===----------------------------------------------------------------------===// +// iree_hal_allocator_t +//===----------------------------------------------------------------------===// + +IREE_VM_ABI_EXPORT(iree_hal_module_allocator_allocate, // + iree_hal_module_state_t, // + rIiiI, r) { + iree_hal_allocator_t* allocator = NULL; + IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i2; + iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i3; + iree_device_size_t allocation_size = iree_hal_cast_device_size(args->i4); + const iree_hal_buffer_params_t params = { .type = memory_types, .usage = buffer_usage, + .queue_affinity = queue_affinity, }; iree_hal_buffer_t* buffer = NULL; IREE_RETURN_IF_ERROR( - iree_hal_allocator_allocate_buffer( - allocator, params, length, - iree_make_const_byte_span(source->data.data + offset, length), - &buffer), - "failed to allocate buffer of length %" PRIdsz, length); + iree_hal_allocator_allocate_buffer(allocator, params, allocation_size, + iree_const_byte_span_empty(), &buffer), + "failed to allocate buffer of length %" PRIdsz, allocation_size); rets->r0 = iree_hal_buffer_move_ref(buffer); return iree_ok_status(); } -static void iree_hal_module_mapped_buffer_release(void* user_data, - iree_hal_buffer_t* buffer) { +static void 
iree_hal_module_imported_buffer_release(void* user_data, + iree_hal_buffer_t* buffer) { iree_vm_buffer_t* backing_buffer = (iree_vm_buffer_t*)user_data; iree_vm_buffer_release(backing_buffer); } -IREE_VM_ABI_EXPORT(iree_hal_module_allocator_map_byte_buffer, // - iree_hal_module_state_t, // - riiirII, r) { +IREE_VM_ABI_EXPORT(iree_hal_module_allocator_import, // + iree_hal_module_state_t, // + riIiirII, r) { iree_hal_allocator_t* allocator = NULL; IREE_RETURN_IF_ERROR(iree_hal_allocator_check_deref(args->r0, &allocator)); bool is_try = args->i1 != 0; - iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i2; - iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i3; + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i2; + iree_hal_memory_type_t memory_types = (iree_hal_memory_type_t)args->i3; + iree_hal_buffer_usage_t buffer_usage = (iree_hal_buffer_usage_t)args->i4; iree_vm_buffer_t* source = NULL; - IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r4, &source)); - iree_device_size_t offset = iree_hal_cast_device_size(args->i5); - iree_device_size_t length = iree_hal_cast_device_size(args->i6); + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r5, &source)); + iree_device_size_t offset = iree_hal_cast_device_size(args->i6); + iree_device_size_t length = iree_hal_cast_device_size(args->i7); iree_host_size_t buffer_length = source->data.data_length; if (length == -1) { @@ -261,7 +293,7 @@ IREE_VM_ABI_EXPORT(iree_hal_module_allocator_map_byte_buffer, // IREE_HAL_BUFFER_USAGE_SHARING_IMMUTABLE)) { return iree_make_status(IREE_STATUS_PERMISSION_DENIED, "source buffer is immutable and can only be " - "mapped for constant usage"); + "imported for constant usage"); } // NOTE: if we wanted to lock things down for when there's no MMU to ensure @@ -283,6 +315,7 @@ IREE_VM_ABI_EXPORT(iree_hal_module_allocator_map_byte_buffer, // .type = memory_types, .usage = buffer_usage, .access = allowed_access, + 
.queue_affinity = queue_affinity, }; iree_hal_external_buffer_t external_buffer = { .type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION, @@ -291,21 +324,21 @@ IREE_VM_ABI_EXPORT(iree_hal_module_allocator_map_byte_buffer, // .handle.host_allocation.ptr = source->data.data + offset, }; iree_hal_buffer_release_callback_t release_callback = { - .fn = iree_hal_module_mapped_buffer_release, + .fn = iree_hal_module_imported_buffer_release, .user_data = source, }; iree_hal_buffer_t* buffer = NULL; iree_status_t status = iree_hal_allocator_import_buffer( allocator, params, &external_buffer, release_callback, &buffer); if (iree_status_is_ok(status)) { - // Mapping succeeded - retain the source buffer that'll be released by + // Import succeeded - retain the source buffer that'll be released by // iree_hal_module_map_data_ctl when the mapping is no longer used. iree_vm_buffer_retain(source); rets->r0 = iree_hal_buffer_move_ref(buffer); return iree_ok_status(); } - // Failed to map - if this was a try then don't fail and just rely on the + // Failed to import - if this was a try then don't fail and just rely on the // result being nullptr to indicate to the caller that things failed. 
memset(&rets->r0, 0, sizeof(rets->r0)); if (is_try) { @@ -998,6 +1031,52 @@ IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_dealloca, // iree_hal_fence_semaphore_list(signal_fence), buffer); } +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_read, // + iree_hal_module_state_t, // + rIrrrIrIIi, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + iree_hal_file_t* source_file = NULL; + IREE_RETURN_IF_ERROR(iree_hal_file_check_deref(args->r4, &source_file)); + uint64_t source_offset = (uint64_t)args->i5; + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r6, &target_buffer)); + iree_device_size_t target_offset = iree_hal_cast_device_size(args->i7); + iree_device_size_t length = iree_hal_cast_device_size(args->i8); + uint32_t flags = (uint32_t)args->i9; + return iree_hal_device_queue_read( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), source_file, source_offset, + target_buffer, target_offset, length, flags); +} + +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_write, // + iree_hal_module_state_t, // + rIrrrIrIIi, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + iree_hal_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r4, &source_buffer)); + iree_device_size_t source_offset = iree_hal_cast_device_size(args->i5); + iree_hal_file_t* target_file = NULL; + 
IREE_RETURN_IF_ERROR(iree_hal_file_check_deref(args->r6, &target_file)); + uint64_t target_offset = (uint64_t)args->i7; + iree_device_size_t length = iree_hal_cast_device_size(args->i8); + uint32_t flags = (uint32_t)args->i9; + return iree_hal_device_queue_write( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), source_buffer, source_offset, + target_file, target_offset, length, flags); +} + IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_execute, // iree_hal_module_state_t, // rIrrCrD, v) { diff --git a/runtime/src/iree/modules/hal/types.c b/runtime/src/iree/modules/hal/types.c index 1ecbc166146d..0c7e0d7900f9 100644 --- a/runtime/src/iree/modules/hal/types.c +++ b/runtime/src/iree/modules/hal/types.c @@ -24,6 +24,7 @@ IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t); IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_pipeline_layout, iree_hal_pipeline_layout_t); IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_fence, iree_hal_fence_t); +IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_file, iree_hal_file_t); IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_semaphore, iree_hal_semaphore_t); //===----------------------------------------------------------------------===// @@ -99,6 +100,9 @@ iree_hal_module_register_all_types(iree_vm_instance_t* instance) { IREE_VM_REGISTER_HAL_C_TYPE(instance, iree_hal_fence_t, "hal.fence", iree_hal_fence_destroy, iree_hal_fence_registration); + IREE_VM_REGISTER_HAL_C_TYPE(instance, iree_hal_file_t, "hal.file", + iree_hal_file_destroy, + iree_hal_file_registration); IREE_VM_REGISTER_HAL_C_TYPE( instance, iree_hal_pipeline_layout_t, "hal.pipeline_layout", iree_hal_pipeline_layout_destroy, iree_hal_pipeline_layout_registration); @@ -171,6 +175,8 @@ iree_hal_module_resolve_all_types(iree_vm_instance_t* instance) { iree_hal_event_registration); IREE_VM_RESOLVE_HAL_C_TYPE(instance, iree_hal_fence_t, "hal.fence", iree_hal_fence_registration); + IREE_VM_RESOLVE_HAL_C_TYPE(instance, 
iree_hal_file_t, "hal.file", + iree_hal_file_registration); IREE_VM_RESOLVE_HAL_C_TYPE(instance, iree_hal_pipeline_layout_t, "hal.pipeline_layout", iree_hal_pipeline_layout_registration); diff --git a/runtime/src/iree/modules/hal/types.h b/runtime/src/iree/modules/hal/types.h index 601c284d3bf7..f54690126d1f 100644 --- a/runtime/src/iree/modules/hal/types.h +++ b/runtime/src/iree/modules/hal/types.h @@ -27,6 +27,7 @@ IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t); IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_executable_cache, iree_hal_executable_cache_t); IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_fence, iree_hal_fence_t); +IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_file, iree_hal_file_t); IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_pipeline_layout, iree_hal_pipeline_layout_t); IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_semaphore, iree_hal_semaphore_t); diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index bc11d71767e0..c7f18a909b55 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c @@ -33,6 +33,7 @@ IREE_VM_ABI_DEFINE_SHIM(riCiiiD, r); IREE_VM_ABI_DEFINE_SHIM(riCrD, r); IREE_VM_ABI_DEFINE_SHIM(rIi, i); IREE_VM_ABI_DEFINE_SHIM(rIirrii, r); +IREE_VM_ABI_DEFINE_SHIM(rIirIIi, r); IREE_VM_ABI_DEFINE_SHIM(rii, r); IREE_VM_ABI_DEFINE_SHIM(rII, r); IREE_VM_ABI_DEFINE_SHIM(rii, v); @@ -40,8 +41,8 @@ IREE_VM_ABI_DEFINE_SHIM(rif, v); IREE_VM_ABI_DEFINE_SHIM(riii, r); IREE_VM_ABI_DEFINE_SHIM(riiI, r); IREE_VM_ABI_DEFINE_SHIM(riii, v); -IREE_VM_ABI_DEFINE_SHIM(riirII, r); -IREE_VM_ABI_DEFINE_SHIM(riiirII, r); +IREE_VM_ABI_DEFINE_SHIM(rIiiI, r); +IREE_VM_ABI_DEFINE_SHIM(riIiirII, r); IREE_VM_ABI_DEFINE_SHIM(rriirIIrIII, v); IREE_VM_ABI_DEFINE_SHIM(rrrrCrD, r); IREE_VM_ABI_DEFINE_SHIM(ririi, v); @@ -64,6 +65,7 @@ IREE_VM_ABI_DEFINE_SHIM(rrIrII, v); IREE_VM_ABI_DEFINE_SHIM(rrIii, v); IREE_VM_ABI_DEFINE_SHIM(rrrIii, v); IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r); +IREE_VM_ABI_DEFINE_SHIM(rIrrrIrIIi, v); IREE_VM_ABI_DEFINE_SHIM(rIrrr, 
v); IREE_VM_ABI_DEFINE_SHIM(rIrrCrD, v); IREE_VM_ABI_DEFINE_SHIM(CrID, r); diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index 7109935d727a..2a1e60c2aa03 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h @@ -285,6 +285,16 @@ IREE_VM_ABI_FIXED_STRUCT(rIirrii, { int32_t i6; }); +IREE_VM_ABI_FIXED_STRUCT(rIirIIi, { + iree_vm_ref_t r0; + int64_t i1; + int32_t i2; + iree_vm_ref_t r3; + int64_t i4; + int64_t i5; + int32_t i6; +}); + IREE_VM_ABI_FIXED_STRUCT(rII, { iree_vm_ref_t r0; int64_t i1; @@ -327,23 +337,23 @@ IREE_VM_ABI_FIXED_STRUCT(iirII, { int64_t i4; }); -IREE_VM_ABI_FIXED_STRUCT(riirII, { +IREE_VM_ABI_FIXED_STRUCT(rIiiI, { iree_vm_ref_t r0; - int32_t i1; + int64_t i1; int32_t i2; - iree_vm_ref_t r3; + int32_t i3; int64_t i4; - int64_t i5; }); -IREE_VM_ABI_FIXED_STRUCT(riiirII, { +IREE_VM_ABI_FIXED_STRUCT(riIiirII, { iree_vm_ref_t r0; int32_t i1; - int32_t i2; + int64_t i2; int32_t i3; - iree_vm_ref_t r4; - int64_t i5; + int32_t i4; + iree_vm_ref_t r5; int64_t i6; + int64_t i7; }); IREE_VM_ABI_FIXED_STRUCT(rriirIIrIII, { @@ -423,6 +433,19 @@ IREE_VM_ABI_FIXED_STRUCT(rIrriiiI, { int64_t i7; }); +IREE_VM_ABI_FIXED_STRUCT(rIrrrIrIIi, { + iree_vm_ref_t r0; + int64_t i1; + iree_vm_ref_t r2; + iree_vm_ref_t r3; + iree_vm_ref_t r4; + int64_t i5; + iree_vm_ref_t r6; + int64_t i7; + int64_t i8; + int32_t i9; +}); + IREE_VM_ABI_FIXED_STRUCT(rIrrr, { iree_vm_ref_t r0; int64_t i1; @@ -600,6 +623,7 @@ IREE_VM_ABI_DECLARE_SHIM(riCiiiD, r); IREE_VM_ABI_DECLARE_SHIM(riCrD, r); IREE_VM_ABI_DECLARE_SHIM(rIi, i); IREE_VM_ABI_DECLARE_SHIM(rIirrii, r); +IREE_VM_ABI_DECLARE_SHIM(rIirIIi, r); IREE_VM_ABI_DECLARE_SHIM(rii, r); IREE_VM_ABI_DECLARE_SHIM(rII, r); IREE_VM_ABI_DECLARE_SHIM(rii, v); @@ -607,8 +631,8 @@ IREE_VM_ABI_DECLARE_SHIM(rif, v); IREE_VM_ABI_DECLARE_SHIM(riii, r); IREE_VM_ABI_DECLARE_SHIM(riiI, r); IREE_VM_ABI_DECLARE_SHIM(riii, v); -IREE_VM_ABI_DECLARE_SHIM(riirII, r); -IREE_VM_ABI_DECLARE_SHIM(riiirII, r); 
+IREE_VM_ABI_DECLARE_SHIM(rIiiI, r); +IREE_VM_ABI_DECLARE_SHIM(riIiirII, r); IREE_VM_ABI_DECLARE_SHIM(rriirIIrIII, v); IREE_VM_ABI_DECLARE_SHIM(rrrrCrD, r); IREE_VM_ABI_DECLARE_SHIM(ririi, v); @@ -631,6 +655,7 @@ IREE_VM_ABI_DECLARE_SHIM(rrIrII, v); IREE_VM_ABI_DECLARE_SHIM(rrIii, v); IREE_VM_ABI_DECLARE_SHIM(rrrIii, v); IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r); +IREE_VM_ABI_DECLARE_SHIM(rIrrrIrIIi, v); IREE_VM_ABI_DECLARE_SHIM(rIrrr, v); IREE_VM_ABI_DECLARE_SHIM(rIrrCrD, v); IREE_VM_ABI_DECLARE_SHIM(CrID, r); diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c index 1f4b0e33a19d..8ef99cadd1cc 100644 --- a/tools/iree-e2e-matmul-test.c +++ b/tools/iree-e2e-matmul-test.c @@ -196,8 +196,8 @@ static iree_status_t map_host_local_row_major_data( "buffer_view is not dense row major"); } IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( - iree_hal_buffer_view_buffer(buffer_view), - IREE_HAL_MAPPING_MODE_PERSISTENT, access, 0, IREE_WHOLE_BUFFER, mapping)); + iree_hal_buffer_view_buffer(buffer_view), IREE_HAL_MAPPING_MODE_SCOPED, + access, 0, IREE_WHOLE_BUFFER, mapping)); return iree_ok_status(); }