Skip to content

[DevSAN] Cache internal queue to avoid repeated create/destroy #19540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
assert((void *)Device != nullptr && "Device cannot be nullptr");

std::scoped_lock<ur_shared_mutex> Guard(Mutex);
auto CI = getAsanInterceptor()->getContextInfo(Context);
auto &Allocation = Allocations[Device];
ur_result_t URes = UR_RESULT_SUCCESS;
if (!Allocation) {
Expand All @@ -106,9 +107,9 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
}

if (HostPtr) {
ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostPtr, Size, 0, nullptr, nullptr);
InternalQueue, true, Allocation, HostPtr, Size, 0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
UR_LOG_L(
getContext()->logger, ERR,
Expand Down Expand Up @@ -147,10 +148,10 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Copy data from last synced device to host
{
ManagedQueue Queue(Context, LastSyncedDevice.hDevice);
// The data being read lives in LastSyncedDevice.MemHandle, so the copy to the
// host must be enqueued on the cached queue of LastSyncedDevice.hDevice (the
// device the replaced ManagedQueue was created on), not on the queue of the
// destination Device.
ur_queue_handle_t InternalQueue =
    CI->getInternalQueue(LastSyncedDevice.hDevice);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, HostAllocation, LastSyncedDevice.MemHandle, Size, 0,
nullptr, nullptr);
InternalQueue, true, HostAllocation, LastSyncedDevice.MemHandle, Size,
0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR,
"Failed to migrate memory buffer data");
Expand All @@ -160,9 +161,10 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Sync data back to device
{
ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostAllocation, Size, 0, nullptr, nullptr);
InternalQueue, true, Allocation, HostAllocation, Size, 0, nullptr,
nullptr);
if (URes != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR,
"Failed to migrate memory buffer data");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate(
std::shared_ptr<ContextInfo> CtxInfo =
getAsanInterceptor()->getContextInfo(hContext);
for (const auto &hDevice : CtxInfo->DeviceList) {
ManagedQueue InternalQueue(hContext, hDevice);
ur_queue_handle_t InternalQueue = CtxInfo->getInternalQueue(hDevice);
char *Handle = nullptr;
UR_CALL(pMemBuffer->getHandle(hDevice, Handle));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,15 @@ ur_result_t AsanInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel,
auto ContextInfo = getContextInfo(Context);
auto DeviceInfo = getDeviceInfo(Device);

ManagedQueue InternalQueue(Context, Device);
if (!InternalQueue) {
UR_LOG_L(getContext()->logger, ERR, "Failed to create internal queue");
return UR_RESULT_ERROR_INVALID_QUEUE;
}
ur_queue_handle_t InternalQueue = ContextInfo->getInternalQueue(Device);

UR_CALL(prepareLaunch(ContextInfo, DeviceInfo, InternalQueue, Kernel,
LaunchInfo));

UR_CALL(updateShadowMemory(ContextInfo, DeviceInfo, InternalQueue));

UR_CALL(getContext()->urDdiTable.Queue.pfnFinish(InternalQueue));

return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -467,6 +465,7 @@ ur_result_t AsanInterceptor::unregisterProgram(ur_program_handle_t Program) {

ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
auto Context = GetContext(Program);
auto CI = getContextInfo(Context);
std::vector<ur_device_handle_t> Devices = GetDevices(Program);

for (auto Device : Devices) {
Expand All @@ -484,11 +483,11 @@ ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
assert((MetadataSize % sizeof(SpirKernelInfo) == 0) &&
"SpirKernelMetadata size is not correct");

ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = CI->getInternalQueue(Device);

std::vector<SpirKernelInfo> SKInfo(NumOfSpirKernel);
Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, &SKInfo[0], MetadataPtr,
InternalQueue, true, &SKInfo[0], MetadataPtr,
sizeof(SpirKernelInfo) * NumOfSpirKernel, 0, nullptr, nullptr);
if (Result != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR, "Can't read the value of <{}>: {}",
Expand All @@ -504,7 +503,7 @@ ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
}
std::vector<char> KernelNameV(SKI.Size);
Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, KernelNameV.data(), (void *)SKI.KernelName,
InternalQueue, true, KernelNameV.data(), (void *)SKI.KernelName,
sizeof(char) * SKI.Size, 0, nullptr, nullptr);
if (Result != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR, "Can't read kernel name: {}",
Expand Down Expand Up @@ -537,7 +536,7 @@ AsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) {
assert(ProgramInfo != nullptr && "unregistered program!");

for (auto Device : Devices) {
ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = ContextInfo->getInternalQueue(Device);

size_t MetadataSize;
void *MetadataPtr;
Expand All @@ -554,7 +553,7 @@ AsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) {
"DeviceGlobal metadata size is not correct");
std::vector<DeviceGlobalInfo> GVInfos(NumOfDeviceGlobal);
Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, &GVInfos[0], MetadataPtr,
InternalQueue, true, &GVInfos[0], MetadataPtr,
sizeof(DeviceGlobalInfo) * NumOfDeviceGlobal, 0, nullptr, nullptr);
if (Result != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR, "Device Global[{}] Read Failed: {}",
Expand Down Expand Up @@ -932,6 +931,8 @@ bool ProgramInfo::isKernelInstrumented(ur_kernel_handle_t Kernel) const {
ContextInfo::~ContextInfo() {
Stats.Print(Handle);

InternalQueueMap.clear();

[[maybe_unused]] ur_result_t URes;
if (USMPool) {
URes = getContext()->urDdiTable.USM.pfnPoolRelease(USMPool);
Expand Down Expand Up @@ -971,6 +972,13 @@ ur_usm_pool_handle_t ContextInfo::getUSMPool() {
return USMPool;
}

// Returns the internal queue cached for Device, constructing the ManagedQueue
// from (Handle, Device) the first time Device is requested. The lookup and the
// lazy construction both run while InternalQueueMapMutex is held.
ur_queue_handle_t ContextInfo::getInternalQueue(ur_device_handle_t Device) {
  std::scoped_lock<ur_shared_mutex> Guard(InternalQueueMapMutex);
  // operator[] default-inserts an empty optional for a device not seen yet.
  auto &CachedQueue = InternalQueueMap[Device];
  if (!CachedQueue.has_value()) {
    CachedQueue.emplace(Handle, Device);
  }
  return *CachedQueue;
}

AsanRuntimeDataWrapper::~AsanRuntimeDataWrapper() {
[[maybe_unused]] ur_result_t Result;
if (Host.LocalArgs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "asan_statistics.hpp"
#include "sanitizer_common/sanitizer_common.hpp"
#include "sanitizer_common/sanitizer_options.hpp"
#include "sanitizer_common/sanitizer_utils.hpp"
#include "ur_sanitizer_layer.hpp"

#include <memory>
Expand Down Expand Up @@ -143,6 +144,10 @@ struct ContextInfo {
std::vector<ur_device_handle_t> DeviceList;
std::unordered_map<ur_device_handle_t, AllocInfoList> AllocInfosMap;

ur_shared_mutex InternalQueueMapMutex;
std::unordered_map<ur_device_handle_t, std::optional<ManagedQueue>>
InternalQueueMap;

AsanStatsWrapper Stats;

explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
Expand All @@ -163,6 +168,8 @@ struct ContextInfo {
}

ur_usm_pool_handle_t getUSMPool();

ur_queue_handle_t getInternalQueue(ur_device_handle_t);
};

struct AsanRuntimeDataWrapper {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ ur_usm_type_t GetUSMType(ur_context_handle_t Context, const void *MemPtr) {
} // namespace

// Creates a UR queue for (Context, Device) through pfnCreate and stores the
// resulting handle in Handle. When IsOutOfOrder is true the queue is created
// with UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE; otherwise nullptr is passed
// as the properties argument.
// NOTE(review): the creation result is only checked by assert, so a failed
// create is not reported to the caller when asserts are compiled out.
ManagedQueue::ManagedQueue(ur_context_handle_t Context,
ur_device_handle_t Device) {
ur_device_handle_t Device, bool IsOutOfOrder) {
ur_queue_properties_t Prop{UR_STRUCTURE_TYPE_QUEUE_PROPERTIES, nullptr,
UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE};
[[maybe_unused]] auto Result = getContext()->urDdiTable.Queue.pfnCreate(
Context, Device, nullptr, &Handle);
Context, Device, IsOutOfOrder ? &Prop : nullptr, &Handle);
assert(Result == UR_RESULT_SUCCESS && "Failed to create ManagedQueue");
UR_LOG_L(getContext()->logger, DEBUG, ">>> ManagedQueue {}", (void *)Handle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
namespace ur_sanitizer_layer {

struct ManagedQueue {
ManagedQueue(ur_context_handle_t Context, ur_device_handle_t Device);
ManagedQueue(ur_context_handle_t Context, ur_device_handle_t Device,
bool IsOutOfOrder = false);
~ManagedQueue();

// Disable copy semantics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

std::scoped_lock<ur_shared_mutex> Guard(Mutex);
auto &Allocation = Allocations[Device];
auto CI = getTsanInterceptor()->getContextInfo(Context);
ur_result_t URes = UR_RESULT_SUCCESS;
if (!Allocation) {
ur_usm_desc_t USMDesc{};
Expand All @@ -114,7 +115,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
}

if (HostPtr) {
ManagedQueue Queue(Context, Device);
ur_queue_handle_t Queue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostPtr, Size, 0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
Expand Down Expand Up @@ -155,7 +156,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Copy data from last synced device to host
{
ManagedQueue Queue(Context, LastSyncedDevice.hDevice);
// The data being read lives in LastSyncedDevice.MemHandle, so the copy to the
// host must be enqueued on the cached queue of LastSyncedDevice.hDevice (the
// device the replaced ManagedQueue was created on), not on the queue of the
// destination Device.
ur_queue_handle_t Queue = CI->getInternalQueue(LastSyncedDevice.hDevice);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, HostAllocation, LastSyncedDevice.MemHandle, Size, 0,
nullptr, nullptr);
Expand All @@ -168,7 +169,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Sync data back to device
{
ManagedQueue Queue(Context, Device);
ur_queue_handle_t Queue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostAllocation, Size, 0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ ur_result_t urMemBufferCreate(
std::shared_ptr<ContextInfo> CtxInfo =
getTsanInterceptor()->getContextInfo(hContext);
for (const auto &hDevice : CtxInfo->DeviceList) {
ManagedQueue InternalQueue(hContext, hDevice);
ur_queue_handle_t InternalQueue = CtxInfo->getInternalQueue(hDevice);
char *Handle = nullptr;
UR_CALL(pMemBuffer->getHandle(hDevice, Handle));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/

#include "tsan_interceptor.hpp"
#include "sanitizer_common/sanitizer_utils.hpp"
#include "tsan_report.hpp"

namespace ur_sanitizer_layer {
Expand Down Expand Up @@ -107,6 +106,13 @@ void ContextInfo::insertAllocInfo(TsanAllocInfo AI) {
AllocInfos.insert(std::move(AI));
}

// Returns the internal queue cached for Device, constructing the ManagedQueue
// from (Handle, Device, true) the first time Device is requested. The lookup
// and the lazy construction both run while InternalQueueMapMutex is held.
ur_queue_handle_t ContextInfo::getInternalQueue(ur_device_handle_t Device) {
  std::scoped_lock<ur_shared_mutex> Guard(InternalQueueMapMutex);
  // operator[] default-inserts an empty optional for a device not seen yet.
  auto &CachedQueue = InternalQueueMap[Device];
  if (!CachedQueue.has_value()) {
    CachedQueue.emplace(Handle, Device, true);
  }
  return *CachedQueue;
}

TsanInterceptor::~TsanInterceptor() {
// We must release these objects before releasing adapters, since
// they may use the adapter in their destructor
Expand Down Expand Up @@ -190,7 +196,7 @@ TsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) {
auto &ProgramInfo = getProgramInfo(Program);

for (auto Device : Devices) {
ManagedQueue Queue(Context, Device);
ur_queue_handle_t Queue = ContextInfo->getInternalQueue(Device);

size_t MetadataSize;
void *MetadataPtr;
Expand Down Expand Up @@ -333,16 +339,14 @@ ur_result_t TsanInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel,
auto CI = getContextInfo(GetContext(Queue));
auto DI = getDeviceInfo(GetDevice(Queue));

ManagedQueue InternalQueue(CI->Handle, DI->Handle);
if (!InternalQueue) {
UR_LOG_L(getContext()->logger, ERR, "Failed to create internal queue");
return UR_RESULT_ERROR_INVALID_QUEUE;
}
ur_queue_handle_t InternalQueue = CI->getInternalQueue(DI->Handle);

UR_CALL(prepareLaunch(CI, DI, InternalQueue, Kernel, LaunchInfo));

UR_CALL(updateShadowMemory(CI, DI, Kernel, InternalQueue));

UR_CALL(getContext()->urDdiTable.Queue.pfnFinish(InternalQueue));

return UR_RESULT_SUCCESS;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "sanitizer_common/sanitizer_allocator.hpp"
#include "sanitizer_common/sanitizer_common.hpp"
#include "sanitizer_common/sanitizer_utils.hpp"
#include "tsan_buffer.hpp"
#include "tsan_libdevice.hpp"
#include "tsan_shadow.hpp"
Expand Down Expand Up @@ -58,13 +59,18 @@ struct ContextInfo {
ur_shared_mutex AllocInfosMutex;
std::set<TsanAllocInfo> AllocInfos;

ur_shared_mutex InternalQueueMapMutex;
std::unordered_map<ur_device_handle_t, std::optional<ManagedQueue>>
InternalQueueMap;

// Stores Context in Handle and calls pfnRetain on it; the result is only
// checked by assert.
explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRetain(Context);
assert(Result == UR_RESULT_SUCCESS);
}

~ContextInfo() {
InternalQueueMap.clear();
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRelease(Handle);
assert(Result == UR_RESULT_SUCCESS);
Expand All @@ -75,6 +81,8 @@ struct ContextInfo {
ContextInfo &operator=(const ContextInfo &) = delete;

void insertAllocInfo(TsanAllocInfo AI);

ur_queue_handle_t getInternalQueue(ur_device_handle_t);
};

struct DeviceGlobalInfo {
Expand Down