Skip to content

Commit

Permalink
[SYCL] Fix pi2ur conversion of bool info queries (intel#9672)
Browse files Browse the repository at this point in the history
Also fix:
* Uninitialized stype values in UR queries
* Missing PI_KERNEL_INFO_NUM_REGS mapping
  • Loading branch information
callumfare authored Jun 12, 2023
1 parent 395aa8a commit 9a4a2f4
Showing 1 changed file with 52 additions and 27 deletions.
79 changes: 52 additions & 27 deletions sycl/plugins/unified_runtime/pi2ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,14 +458,16 @@ inline pi_result ur2piUSMAllocInfoValue(ur_usm_alloc_info_t ParamName,
}

// Handle mismatched PI and UR type return sizes for info queries
inline pi_result fixupInfoValueTypes(size_t ParamValueSizeUR,
inline pi_result fixupInfoValueTypes(size_t ParamValueSizeRetUR,
size_t *ParamValueSizeRetPI,
void *ParamValue) {
if (ParamValueSizeUR == 1) {
size_t ParamValueSize, void *ParamValue) {
if (ParamValueSizeRetUR == 1 && ParamValueSize == 4) {
// extend bool to pi_bool (uint32_t)
auto *ValIn = static_cast<bool *>(ParamValue);
auto *ValOut = static_cast<pi_bool *>(ParamValue);
*ValOut = static_cast<pi_bool>(*ValIn);
if (ParamValue) {
auto *ValIn = static_cast<bool *>(ParamValue);
auto *ValOut = static_cast<pi_bool *>(ParamValue);
*ValOut = static_cast<pi_bool>(*ValIn);
}
if (ParamValueSizeRetPI) {
*ParamValueSizeRetPI = sizeof(pi_bool);
}
Expand Down Expand Up @@ -591,13 +593,18 @@ inline pi_result piPlatformGetInfo(pi_platform Platform,
die("urGetContextInfo: unsuppported ParamName.");
}

size_t SizeInOut = ParamValueSize;
size_t UrParamValueSizeRet;
auto UrPlatform = reinterpret_cast<ur_platform_handle_t>(Platform);
HANDLE_ERRORS(urPlatformGetInfo(UrPlatform, UrParamName, SizeInOut,
ParamValue, ParamValueSizeRet));
HANDLE_ERRORS(urPlatformGetInfo(UrPlatform, UrParamName, ParamValueSize,
ParamValue, &UrParamValueSizeRet));

ur2piPlatformInfoValue(UrParamName, ParamValueSize, &SizeInOut, ParamValue);
fixupInfoValueTypes(SizeInOut, ParamValueSizeRet, ParamValue);
if (ParamValueSizeRet) {
*ParamValueSizeRet = UrParamValueSizeRet;
}
ur2piPlatformInfoValue(UrParamName, ParamValueSize, &ParamValueSize,
ParamValue);
fixupInfoValueTypes(UrParamValueSizeRet, ParamValueSizeRet, ParamValueSize,
ParamValue);

return PI_SUCCESS;
}
Expand Down Expand Up @@ -1016,14 +1023,18 @@ inline pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,

PI_ASSERT(Device, PI_ERROR_INVALID_DEVICE);

size_t SizeInOut = ParamValueSize;
size_t UrParamValueSizeRet;
auto UrDevice = reinterpret_cast<ur_device_handle_t>(Device);

HANDLE_ERRORS(urDeviceGetInfo(UrDevice, InfoType, SizeInOut, ParamValue,
ParamValueSizeRet));
HANDLE_ERRORS(urDeviceGetInfo(UrDevice, InfoType, ParamValueSize, ParamValue,
&UrParamValueSizeRet));

ur2piDeviceInfoValue(InfoType, ParamValueSize, &SizeInOut, ParamValue);
fixupInfoValueTypes(SizeInOut, ParamValueSizeRet, ParamValue);
if (ParamValueSizeRet) {
*ParamValueSizeRet = UrParamValueSizeRet;
}
ur2piDeviceInfoValue(InfoType, ParamValueSize, &ParamValueSize, ParamValue);
fixupInfoValueTypes(UrParamValueSizeRet, ParamValueSizeRet, ParamValueSize,
ParamValue);

return PI_SUCCESS;
}
Expand Down Expand Up @@ -1167,6 +1178,9 @@ piextDeviceSelectBinary(pi_device Device, // TODO: does this need to be context?
__SYCL_PI_DEVICE_BINARY_TARGET_AMDGCN) == 0)
UrBinaries[BinaryCount].pDeviceTargetSpec =
UR_DEVICE_BINARY_TARGET_AMDGCN;
else
UrBinaries[BinaryCount].pDeviceTargetSpec =
UR_DEVICE_BINARY_TARGET_UNKNOWN;
}

HANDLE_ERRORS(urDeviceSelectBinary(UrDevice, UrBinaries.data(), NumBinaries,
Expand Down Expand Up @@ -1286,10 +1300,14 @@ inline pi_result piContextGetInfo(pi_context Context, pi_context_info ParamName,
}
}

size_t UrParamValueSizeRet;
HANDLE_ERRORS(urContextGetInfo(hContext, ContextInfoType, ParamValueSize,
ParamValue, ParamValueSizeRet));
fixupInfoValueTypes(ParamValueSize, ParamValueSizeRet, ParamValue);

ParamValue, &UrParamValueSizeRet));
if (ParamValueSizeRet) {
*ParamValueSizeRet = UrParamValueSizeRet;
}
fixupInfoValueTypes(UrParamValueSizeRet, ParamValueSizeRet, ParamValueSize,
ParamValue);
return PI_SUCCESS;
}

Expand Down Expand Up @@ -2122,9 +2140,10 @@ inline pi_result piKernelGetGroupInfo(pi_kernel Kernel, pi_device Device,
}
// The number of registers used by the compiled kernel (device specific)
case PI_KERNEL_GROUP_INFO_NUM_REGS: {
die("PI_KERNEL_GROUP_INFO_NUM_REGS in piKernelGetGroupInfo not "
"implemented\n");
break;
HANDLE_ERRORS(urKernelGetInfo(UrKernel, UR_KERNEL_INFO_NUM_REGS,
ParamValueSize, ParamValue,
ParamValueSizeRet));
return PI_SUCCESS;
}
default: {
die("Unknown ParamName in piKernelGetGroupInfo");
Expand Down Expand Up @@ -2376,6 +2395,7 @@ inline pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags,
}

ur_buffer_properties_t UrProps{};
UrProps.stype = UR_STRUCTURE_TYPE_BUFFER_PROPERTIES;
UrProps.pHost = HostPtr;
ur_mem_handle_t *UrBuffer = reinterpret_cast<ur_mem_handle_t *>(RetMem);
HANDLE_ERRORS(
Expand Down Expand Up @@ -2560,6 +2580,7 @@ static void pi2urImageDesc(const pi_image_format *ImageFormat,
}
}

UrDesc->stype = UR_STRUCTURE_TYPE_IMAGE_DESC;
UrDesc->arraySize = ImageDesc->image_array_size;
UrDesc->depth = ImageDesc->image_depth;
UrDesc->height = ImageDesc->image_height;
Expand Down Expand Up @@ -3792,6 +3813,7 @@ inline pi_result piSamplerCreate(pi_context Context,
ur_context_handle_t UrContext =
reinterpret_cast<ur_context_handle_t>(Context);
ur_sampler_desc_t UrProps{};
UrProps.stype = UR_STRUCTURE_TYPE_SAMPLER_DESC;
const pi_sampler_properties *CurProperty = SamplerProperties;
while (*CurProperty != 0) {
switch (*CurProperty) {
Expand Down Expand Up @@ -3865,13 +3887,16 @@ inline pi_result piSamplerGetInfo(pi_sampler Sampler, pi_sampler_info ParamName,
return PI_ERROR_UNKNOWN;
}

size_t SizeInOut = ParamValueSize;
size_t UrParamValueSizeRet;
auto hSampler = reinterpret_cast<ur_sampler_handle_t>(Sampler);
HANDLE_ERRORS(urSamplerGetInfo(hSampler, InfoType, SizeInOut, ParamValue,
ParamValueSizeRet));
HANDLE_ERRORS(urSamplerGetInfo(hSampler, InfoType, ParamValueSize, ParamValue,
&UrParamValueSizeRet));
if (ParamValueSizeRet) {
*ParamValueSizeRet = UrParamValueSizeRet;
}
ur2piSamplerInfoValue(InfoType, ParamValueSize, &ParamValueSize, ParamValue);
fixupInfoValueTypes(SizeInOut, ParamValueSizeRet, ParamValue);

fixupInfoValueTypes(UrParamValueSizeRet, ParamValueSizeRet, ParamValueSize,
ParamValue);
return PI_SUCCESS;
}

Expand Down

0 comments on commit 9a4a2f4

Please sign in to comment.