Skip to content

[SYCL] Fix adjusted kernel name handling in preview RT build #19582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sycl/source/detail/device_image_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
!((getOriginMask() & ImageOriginSYCLBIN) && hasKernelName(Name)))
return nullptr;

std::string AdjustedName = adjustKernelName(Name);
std::string_view AdjustedName = getAdjustedKernelNameStrView(Name);
if (MRTCBinInfo && MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
auto &PM = ProgramManager::getInstance();
for (const std::string &Prefix : MRTCBinInfo->MPrefixes) {
auto KID = PM.tryGetSYCLKernelID(Prefix + AdjustedName);
auto KID = PM.tryGetSYCLKernelID(Prefix + std::string(AdjustedName));

if (!KID || !has_kernel(*KID))
continue;

auto UrProgram = get_ur_program();
auto [UrKernel, CacheMutex, ArgMask] =
PM.getOrCreateKernel(Context, AdjustedName,
PM.getOrCreateKernel(Context, KernelNameStrT(AdjustedName),
/*PropList=*/{}, UrProgram);
return std::make_shared<kernel_impl>(
std::move(UrKernel), *getSyclObjImpl(Context), shared_from_this(),
Expand All @@ -44,7 +44,7 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
ur_program_handle_t UrProgram = get_ur_program();
detail::adapter_impl &Adapter = getSyclObjImpl(Context)->getAdapter();
Managed<ur_kernel_handle_t> UrKernel{Adapter};
Adapter.call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),
Adapter.call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.data(),
&UrKernel);

const KernelArgMask *ArgMask = nullptr;
Expand Down
63 changes: 43 additions & 20 deletions sycl/source/detail/device_image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,12 @@ class ManagedDeviceBinaries {
sycl_device_binaries MBinaries;
};

// Using ordered containers for heterogeneous lookup.
// TODO: change to unordered containers after switching to C++20.
using MangledKernelNameMapT = std::map<std::string, std::string, std::less<>>;
using KernelNameSetT = std::set<std::string, std::less<>>;
using KernelNameToArgMaskMap = std::unordered_map<std::string, KernelArgMask>;
using KernelNameToArgMaskMap =
std::map<std::string, KernelArgMask, std::less<>>;

// Information unique to images compiled at runtime through the
// ext_oneapi_kernel_compiler extension.
Expand Down Expand Up @@ -619,32 +622,21 @@ class device_image_impl
#pragma warning(pop)
#endif

std::string adjustKernelName(std::string_view Name) const {
if (MOrigins & ImageOriginSYCLBIN) {
constexpr std::string_view KernelPrefix = "__sycl_kernel_";
if (Name.size() > KernelPrefix.size() &&
Name.substr(0, KernelPrefix.size()) == KernelPrefix)
return Name.data();
return std::string{KernelPrefix} + Name.data();
}

if (!MRTCBinInfo.has_value())
return Name.data();

if (MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
auto It = MRTCBinInfo->MMangledKernelNames.find(Name);
if (It != MRTCBinInfo->MMangledKernelNames.end())
return It->second;
}
// Returns the adjusted form of the kernel name `Name` as a non-owning
// std::string_view by forwarding to
// getAdjustedKernelNameImpl<std::string_view>.
//
// Assumes the kernel is contained within this image.
//
// NOTE(review): the result is a view, not a copy; it presumably must not
// outlive the storage it refers to (the caller's `Name` or data held by this
// image) — confirm against getAdjustedKernelNameImpl before storing it.
std::string_view getAdjustedKernelNameStrView(std::string_view Name) const {
return getAdjustedKernelNameImpl<std::string_view>(Name);
}

return Name.data();
// Returns the adjusted form of the kernel name `Name` as an owning
// std::string by forwarding to getAdjustedKernelNameImpl<std::string>.
std::string getAdjustedKernelNameStr(std::string_view Name) const {
return getAdjustedKernelNameImpl<std::string>(Name);
}

bool hasKernelName(std::string_view Name) const {
return (getOriginMask() &
(ImageOriginKernelCompiler | ImageOriginSYCLBIN)) &&
!Name.empty() &&
MKernelNames.find(adjustKernelName(Name)) != MKernelNames.end();
MKernelNames.find(getAdjustedKernelNameStr(Name)) !=
MKernelNames.end();
}

std::shared_ptr<kernel_impl>
Expand Down Expand Up @@ -840,6 +832,37 @@ class device_image_impl
}

private:
// Shared implementation behind the adjusted-kernel-name getters.
//
// RetT selects the result type and must be std::string (owning copy) or
// std::string_view (non-owning view).
//
// Resolution rules:
//  * Image without the SYCLBIN origin bit:
//      - if runtime-compilation info is present and its language is SYCL,
//        and `Name` has an entry in MMangledKernelNames, the mapped name is
//        returned;
//      - otherwise `Name` is returned as-is.
//  * Image with the SYCLBIN origin bit:
//      - a name longer than, and starting with, "__sycl_kernel_" is returned
//        as-is;
//      - otherwise the prefix is prepended. For std::string the new string is
//        returned. For std::string_view the matching entry of MKernelNames is
//        looked up and a view of that stored entry is returned; the entry is
//        asserted to exist.
//
// When RetT is std::string_view the result may refer to `Name` itself, to an
// entry of MKernelNames, or to a value of MMangledKernelNames.
template <typename RetT>
RetT getAdjustedKernelNameImpl(std::string_view Name) const {
  if (!(MOrigins & ImageOriginSYCLBIN)) {
    // Only SYCL-language runtime-compiled images carry a mangled-name table.
    if (MRTCBinInfo.has_value() &&
        MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
      const auto &MangledNames = MRTCBinInfo->MMangledKernelNames;
      if (auto Pos = MangledNames.find(Name); Pos != MangledNames.end())
        return Pos->second;
    }
    return RetT(Name);
  }

  constexpr std::string_view KernelPrefix = "__sycl_kernel_";
  const bool AlreadyPrefixed =
      Name.size() > KernelPrefix.size() &&
      Name.compare(0, KernelPrefix.size(), KernelPrefix) == 0;
  if (AlreadyPrefixed)
    return RetT(Name);

  std::string PrefixedName(KernelPrefix);
  PrefixedName += Name;

  if constexpr (std::is_same_v<RetT, std::string>) {
    return PrefixedName;
  } else {
    static_assert(std::is_same_v<RetT, std::string_view>);
    // A view cannot refer to the local string, so hand out the copy that is
    // stored in the image's kernel-name set instead.
    auto Stored = MKernelNames.find(PrefixedName);
    assert(Stored != MKernelNames.end() && "Adjusted name not found");
    return *Stored;
  }
}

bool hasRTDeviceBinaryImage() const noexcept {
return std::holds_alternative<const RTDeviceBinaryImage *>(MBinImage) &&
get_bin_image_ref() != nullptr;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ class kernel_bundle_impl
throw sycl::exception(make_error_code(errc::invalid),
"kernel '" + Name + "' not found in kernel_bundle");

return It->adjustKernelName(Name);
return It->getAdjustedKernelNameStr(Name);
}

bool ext_oneapi_has_device_global(const std::string &Name) const {
Expand Down
8 changes: 5 additions & 3 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2816,7 +2816,7 @@ ProgramManager::compile(const DevImgPlainWithDeps &ImgWithDeps,
setSpecializationConstants(InputImpl, Prog, Adapter);

KernelNameSetT KernelNames = InputImpl.getKernelNames();
std::unordered_map<std::string, KernelArgMask> EliminatedKernelArgMasks =
std::map<std::string, KernelArgMask, std::less<>> EliminatedKernelArgMasks =
InputImpl.getEliminatedKernelArgMasks();

std::optional<detail::KernelCompilerBinaryInfo> RTCInfo =
Expand Down Expand Up @@ -3006,7 +3006,8 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
RTCInfoPtrs;
RTCInfoPtrs.reserve(Imgs.size());
KernelNameSetT MergedKernelNames;
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
std::map<std::string, KernelArgMask, std::less<>>
MergedEliminatedKernelArgMasks;
for (const device_image_plain &DevImg : Imgs) {
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
CombinedOrigins |= DevImgImpl.getOriginMask();
Expand Down Expand Up @@ -3088,7 +3089,8 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
RTCInfoPtrs;
RTCInfoPtrs.reserve(DevImgWithDeps.size());
KernelNameSetT MergedKernelNames;
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
std::map<std::string, KernelArgMask, std::less<>>
MergedEliminatedKernelArgMasks;
for (const device_image_plain &DevImg : DevImgWithDeps) {
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
RTCInfoPtrs.emplace_back(&(DevImgImpl.getRTCInfo()));
Expand Down
Loading