Skip to content

Commit

Permalink
Adding iree_hal_channel_t and the iree_hal_command_buffer_collective …
Browse files Browse the repository at this point in the history
…API.
  • Loading branch information
benvanik committed Dec 6, 2022
1 parent 044017f commit 9a1ab32
Show file tree
Hide file tree
Showing 31 changed files with 1,408 additions and 12 deletions.
10 changes: 10 additions & 0 deletions experimental/rocm/direct_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,15 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_copy_buffer(
return iree_ok_status();
}

static iree_status_t iree_hal_rocm_direct_command_buffer_collective(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
iree_hal_collective_op_t op, uint32_t param,
iree_hal_buffer_binding_t send_binding,
iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"need rocm implementation");
}

static iree_status_t iree_hal_rocm_direct_command_buffer_push_constants(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
Expand Down Expand Up @@ -398,6 +407,7 @@ static const iree_hal_command_buffer_vtable_t
.fill_buffer = iree_hal_rocm_direct_command_buffer_fill_buffer,
.update_buffer = iree_hal_rocm_direct_command_buffer_update_buffer,
.copy_buffer = iree_hal_rocm_direct_command_buffer_copy_buffer,
.collective = iree_hal_rocm_direct_command_buffer_collective,
.push_constants = iree_hal_rocm_direct_command_buffer_push_constants,
.push_descriptor_set =
iree_hal_rocm_direct_command_buffer_push_descriptor_set,
Expand Down
12 changes: 10 additions & 2 deletions experimental/rocm/rocm_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ static iree_status_t iree_hal_rocm_device_trim(iree_hal_device_t* base_device) {
return iree_hal_allocator_trim(device->device_allocator);
}

static iree_status_t iree_hal_rocm_device_create_channel(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"collectives not implemented");
}

static iree_status_t iree_hal_rocm_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
Expand Down Expand Up @@ -303,14 +310,14 @@ static iree_status_t iree_hal_rocm_device_wait_semaphores(
}

static iree_status_t iree_hal_rocm_device_profiling_begin(
iree_hal_device_t* device,
iree_hal_device_t* base_device,
const iree_hal_device_profiling_options_t* options) {
// Unimplemented (and that's ok).
return iree_ok_status();
}

static iree_status_t iree_hal_rocm_device_profiling_end(
iree_hal_device_t* device) {
iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
Expand All @@ -322,6 +329,7 @@ static const iree_hal_device_vtable_t iree_hal_rocm_device_vtable = {
.device_allocator = iree_hal_rocm_device_allocator,
.trim = iree_hal_rocm_device_trim,
.query_i64 = iree_hal_rocm_device_query_i64,
.create_channel = iree_hal_rocm_device_create_channel,
.create_command_buffer = iree_hal_rocm_device_create_command_buffer,
.create_descriptor_set_layout =
iree_hal_rocm_device_create_descriptor_set_layout,
Expand Down
2 changes: 2 additions & 0 deletions runtime/iree.natvis
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@
<DisplayString Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.allocator&quot;)==0">{(iree_hal_allocator_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.buffer&quot;)==0">{(iree_hal_buffer_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.buffer_view&quot;)==0">{(iree_hal_buffer_view_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.channel&quot;)==0">{(iree_hal_channel_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.command_buffer&quot;)==0">{(iree_hal_command_buffer_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.descriptor_set_layout&quot;)==0">{(iree_hal_descriptor_set_layout_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.device&quot;)==0">{(iree_hal_device_t*)ptr}</DisplayString>
Expand All @@ -609,6 +610,7 @@
<ExpandedItem Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.allocator&quot;)==0">(iree_hal_allocator_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.buffer&quot;)==0">(iree_hal_buffer_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.buffer_view&quot;)==0">(iree_hal_buffer_view_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.channel&quot;)==0">(iree_hal_channel_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.command_buffer&quot;)==0">(iree_hal_command_buffer_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.descriptor_set_layout&quot;)==0">(iree_hal_descriptor_set_layout_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 &amp;&amp; strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, &quot;hal.device&quot;)==0">(iree_hal_device_t*)ptr</ExpandedItem>
Expand Down
2 changes: 2 additions & 0 deletions runtime/src/iree/hal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ iree_runtime_cc_library(
"buffer_view.h",
"buffer_view_util.c",
"buffer_view_util.h",
"channel.c",
"channel.h",
"command_buffer.c",
"command_buffer.h",
"command_buffer_validation.c",
Expand Down
2 changes: 2 additions & 0 deletions runtime/src/iree/hal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ iree_cc_library(
"buffer_view.h"
"buffer_view_util.c"
"buffer_view_util.h"
"channel.c"
"channel.h"
"command_buffer.c"
"command_buffer.h"
"command_buffer_validation.c"
Expand Down
1 change: 1 addition & 0 deletions runtime/src/iree/hal/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "iree/hal/buffer.h" // IWYU pragma: export
#include "iree/hal/buffer_view.h" // IWYU pragma: export
#include "iree/hal/buffer_view_util.h" // IWYU pragma: export
#include "iree/hal/channel.h" // IWYU pragma: export
#include "iree/hal/command_buffer.h" // IWYU pragma: export
#include "iree/hal/device.h" // IWYU pragma: export
#include "iree/hal/driver.h" // IWYU pragma: export
Expand Down
57 changes: 57 additions & 0 deletions runtime/src/iree/hal/channel.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/hal/channel.h"

#include <stddef.h>

#include "iree/base/tracing.h"
#include "iree/hal/detail.h"
#include "iree/hal/device.h"
#include "iree/hal/resource.h"

#define _VTABLE_DISPATCH(channel, method_name) \
IREE_HAL_VTABLE_DISPATCH(channel, iree_hal_channel, method_name)

IREE_HAL_API_RETAIN_RELEASE(channel);

IREE_API_EXPORT iree_status_t iree_hal_channel_create(
iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(out_channel);
*out_channel = NULL;
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status =
IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_channel)(
device, queue_affinity, params, out_channel);
IREE_TRACE_ZONE_END(z0);
return status;
}

IREE_API_EXPORT void iree_hal_channel_query_rank_and_count(
const iree_hal_channel_t* channel, int32_t* out_rank, int32_t* out_count) {
IREE_ASSERT_ARGUMENT(channel);
int32_t rank = 0;
int32_t count = 0;
_VTABLE_DISPATCH(channel, query_rank_and_count)(channel, &rank, &count);
if (out_rank) *out_rank = rank;
if (out_count) *out_count = count;
}

IREE_API_EXPORT int32_t
iree_hal_channel_rank(const iree_hal_channel_t* channel) {
int32_t rank = 0;
iree_hal_channel_query_rank_and_count(channel, &rank, NULL);
return rank;
}

IREE_API_EXPORT int32_t
iree_hal_channel_count(const iree_hal_channel_t* channel) {
int32_t count = 0;
iree_hal_channel_query_rank_and_count(channel, NULL, &count);
return count;
}
112 changes: 112 additions & 0 deletions runtime/src/iree/hal/channel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_HAL_CHANNEL_H_
#define IREE_HAL_CHANNEL_H_

#include <stdbool.h>
#include <stdint.h>

#include "iree/base/api.h"
#include "iree/hal/allocator.h"
#include "iree/hal/resource.h"

#ifdef __cplusplus
extern "C" {
#endif // __cplusplus

typedef struct iree_hal_device_t iree_hal_device_t;

//===----------------------------------------------------------------------===//
// iree_hal_channel_t
//===----------------------------------------------------------------------===//

enum iree_hal_channel_flag_bits_t {
IREE_HAL_CHANNEL_FLAG_NONE = 0u,
};
typedef uint32_t iree_hal_channel_flags_t;

// Specifies that the channel should use environment settings if available.
#define IREE_HAL_CHANNEL_RANK_DEFAULT ((int32_t)-1)
#define IREE_HAL_CHANNEL_COUNT_DEFAULT ((int32_t)-1)

// Parameters defining how a channel should be configured.
typedef struct {
// Flags controlling channel behavior.
iree_hal_channel_flags_t flags;
// Implementation-defined identifier for the channel.
// May be empty to indicate that the environment should be used to populate
// the identifier.
//
// Equivalent to:
// ncclUniqueId
iree_const_byte_span_t id;
// Rank of the participant within the collective group.
// May be IREE_HAL_CHANNEL_RANK_DEFAULT to indicate that the environment
// should be used to populate the rank.
int32_t rank;
// Total number of participants within the collective group.
// May be IREE_HAL_CHANNEL_COUNT_DEFAULT to indicate that the environment
// should be used to populate the count.
int32_t count;
} iree_hal_channel_params_t;

// A collective communication channel representing a single rank.
//
// Equivalent to:
// MPI_Comm
// ncclComm_t
// ccl::communicator
typedef struct iree_hal_channel_t iree_hal_channel_t;

// Creates a channel on |device| for use by all queues defined in
// |queue_affinity|. |params| may specify the channel parameters or leave its
// fields as default to indicate that the value should be sourced from the
// environment.
IREE_API_EXPORT iree_status_t iree_hal_channel_create(
iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
iree_hal_channel_params_t params, iree_hal_channel_t** out_channel);

// Retains the given |channel| for the caller.
IREE_API_EXPORT void iree_hal_channel_retain(iree_hal_channel_t* channel);

// Releases the given |channel| from the caller.
IREE_API_EXPORT void iree_hal_channel_release(iree_hal_channel_t* channel);

// Returns the rank the channel represents as a participant in a collective
// group in `[0, count)` and the total participant count.
IREE_API_EXPORT void iree_hal_channel_query_rank_and_count(
const iree_hal_channel_t* channel, int32_t* out_rank, int32_t* out_count);

// Returns the rank the channel represents as a participant in a collective
// group in `[0, count)`.
IREE_API_EXPORT int32_t
iree_hal_channel_rank(const iree_hal_channel_t* channel);

// Returns the total participant count in a collective group.
IREE_API_EXPORT int32_t
iree_hal_channel_count(const iree_hal_channel_t* channel);

//===----------------------------------------------------------------------===//
// iree_hal_channel_t implementation details
//===----------------------------------------------------------------------===//

typedef struct iree_hal_channel_vtable_t {
void(IREE_API_PTR* destroy)(iree_hal_channel_t* channel);

void(IREE_API_PTR* query_rank_and_count)(const iree_hal_channel_t* channel,
int32_t* out_rank,
int32_t* out_count);
} iree_hal_channel_vtable_t;
IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_channel_vtable_t);

IREE_API_EXPORT void iree_hal_channel_destroy(iree_hal_channel_t* channel);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

#endif // IREE_HAL_CHANNEL_H_
21 changes: 21 additions & 0 deletions runtime/src/iree/hal/command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,27 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_copy_buffer(
return status;
}

IREE_API_EXPORT iree_status_t iree_hal_command_buffer_collective(
iree_hal_command_buffer_t* command_buffer, iree_hal_channel_t* channel,
iree_hal_collective_op_t op, uint32_t param,
iree_hal_buffer_binding_t send_binding,
iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(channel);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_collective_validation(
command_buffer, VALIDATION_STATE(command_buffer), channel, op,
param, send_binding, recv_binding, element_count));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, collective)(
command_buffer, channel, op, param, send_binding, recv_binding,
element_count);
IREE_TRACE_ZONE_END(z0);
return status;
}

IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_constants(
iree_hal_command_buffer_t* command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
Expand Down
Loading

0 comments on commit 9a1ab32

Please sign in to comment.