Skip to content

Commit

Permalink
[metal] NFC: Make code in buffer fill less branchy
Browse files Browse the repository at this point in the history
  • Loading branch information
antiagainst committed Jun 14, 2023
1 parent ec93093 commit e8679ad
Showing 1 changed file with 34 additions and 45 deletions.
79 changes: 34 additions & 45 deletions experimental/metal/direct_command_buffer.m
Original file line number Diff line number Diff line change
Expand Up @@ -572,33 +572,26 @@ static bool iree_hal_metal_get_duplicated_single_byte_value(const void* pattern,
return false;
}

// Fills |value| by duplicating the given |pattern| into 4-bytes.
static iree_status_t iree_hal_metal_duplicate_to_four_byte_value(const void* pattern,
size_t pattern_length,
uint32_t* value) {
switch (pattern_length) {
case 1: {
uint8_t single_byte = *(uint8_t*)pattern;
*value = (uint32_t)single_byte;
*value |= (*value << 8u);
*value |= (*value << 16u);
return iree_ok_status();
}
case 2: {
uint16_t two_bytes = *(uint16_t*)pattern;
*value = (uint32_t)two_bytes;
*value |= (*value << 16u);
return iree_ok_status();
}
case 4: {
*value = *(uint32_t*)pattern;
return iree_ok_status();
}
// Duplicates the given |pattern| into 4-bytes and returns the value.
static uint32_t iree_hal_metal_duplicate_to_four_byte_value(const void* pattern,
size_t pattern_length) {
if (pattern_length == 1) {
uint8_t single_byte = *(uint8_t*)pattern;
uint32_t value = (uint32_t)single_byte;
value |= (value << 8u);
value |= (value << 16u);
return value;
}

default:
break;
if (pattern_length == 2) {
uint16_t two_bytes = *(uint16_t*)pattern;
uint32_t value = (uint32_t)two_bytes;
value |= (value << 16u);
return value;
}
return iree_make_status(IREE_STATUS_INTERNAL, "fill pattern should have 1/2/4 bytes");

IREE_ASSERT(pattern_length == 4);
return *(uint32_t*)pattern;
}

static iree_status_t iree_hal_metal_command_buffer_prepare_fill_buffer(
Expand Down Expand Up @@ -647,42 +640,38 @@ static iree_status_t iree_hal_metal_command_segment_record_fill_buffer(
iree_hal_metal_fill_buffer_segment_t* segment) {
IREE_TRACE_ZONE_BEGIN(z0);

// Note that fillBuffer:range:value: only accepts a single byte as the pattern but FillBuffer
// can accept 1/2/4 bytes. If the pattern itself contains repeated bytes, we can call into
// fillBuffer:range:value:. Otherwise we need to emulate the support.
uint8_t pattern_1byte = 0u;

// Per the spec for fillBuffer:range:value: "The alignment and length of the range must both be a
// multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS."
#if defined(IREE_PLATFORM_MACOS)
bool can_use_metal_api = segment->target_offset % 4 == 0 && segment->length % 4 == 0;
const bool can_use_metal_api = segment->target_offset % 4 == 0 && segment->length % 4 == 0 &&
iree_hal_metal_get_duplicated_single_byte_value(
segment->pattern, segment->pattern_length, &pattern_1byte);
#else
bool can_use_metal_api = true;
const bool can_use_metal_api = iree_hal_metal_get_duplicated_single_byte_value(
segment->pattern, segment->pattern_length, &pattern_1byte);
#endif

// Note that fillBuffer:range:value: only accepts a single byte as the pattern but FillBuffer
// can accept 1/2/4 bytes. If the pattern itself contains repeated bytes, we can call into
// fillBuffer:range:value:. Otherwise we need to emulate the support.
uint8_t single_byte_value = 0u;
if (can_use_metal_api) {
can_use_metal_api &= iree_hal_metal_get_duplicated_single_byte_value(
segment->pattern, segment->pattern_length, &single_byte_value);
}

if (can_use_metal_api) {
id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer);
[encoder fillBuffer:segment->target_buffer
range:NSMakeRange(segment->target_offset, segment->length)
value:single_byte_value];
value:pattern_1byte];
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}

id<MTLComputeCommandEncoder> compute_encoder =
iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
uint32_t pattern_4byte = 0;
iree_status_t status = iree_hal_metal_duplicate_to_four_byte_value(
segment->pattern, segment->pattern_length, &pattern_4byte);
if (iree_status_is_ok(status)) {
status = iree_hal_metal_builtin_executable_fill_buffer(
command_buffer->builtin_executable, compute_encoder, segment->target_buffer,
segment->target_offset, segment->length, pattern_4byte);
}
uint32_t pattern_4byte =
iree_hal_metal_duplicate_to_four_byte_value(segment->pattern, segment->pattern_length);
iree_status_t status = iree_hal_metal_builtin_executable_fill_buffer(
command_buffer->builtin_executable, compute_encoder, segment->target_buffer,
segment->target_offset, segment->length, pattern_4byte);

IREE_TRACE_ZONE_END(z0);
return status;
Expand Down

0 comments on commit e8679ad

Please sign in to comment.