Skip to content

[NFC][SYCL] Explicit types/more comments around getOrBuild #19556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ class KernelProgramCache {
std::pair<std::shared_ptr<ProgramBuildResult>, bool>
getOrInsertProgram(const ProgramCacheKeyT &CacheKey) {
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
ProgramCache &ProgCache = LockedCache.get();
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
if (DidInsert) {
It->second = std::make_shared<ProgramBuildResult>(getAdapter());
Expand All @@ -426,7 +426,7 @@ class KernelProgramCache {
bool insertBuiltProgram(const ProgramCacheKeyT &CacheKey,
ur_program_handle_t Program) {
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
ProgramCache &ProgCache = LockedCache.get();
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
if (DidInsert) {
It->second = std::make_shared<ProgramBuildResult>(getAdapter(),
Expand Down Expand Up @@ -491,7 +491,7 @@ class KernelProgramCache {
// Save kernel in fast cache only if the corresponding program is also
// in the cache.
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
ProgramCache &ProgCache = LockedCache.get();
if (ProgCache.ProgramSizeMap.find(CacheVal->MProgramHandle) ==
ProgCache.ProgramSizeMap.end())
return;
Expand Down Expand Up @@ -631,7 +631,7 @@ class KernelProgramCache {
while (CurrCacheSize > DesiredCacheSize && !MEvictionList.empty()) {
ProgramCacheKeyT CacheKey = ProgramEvictionList.front();
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
ProgramCache &ProgCache = LockedCache.get();
CurrCacheSize = removeProgramByKey(CacheKey, ProgCache);
// Remove the program from the eviction list.
MEvictionList.popFront();
Expand Down Expand Up @@ -748,15 +748,23 @@ class KernelProgramCache {
///
/// \return a pointer to cached build result, return value must not be
/// nullptr.
///
/// Note that build result might be immediately evicted (if it's bigger than
/// current threshold), so the caller *must* assume (potentially shared)
/// ownership. In other words, `std::shared_ptr` in the return type is
/// unavoidable.
template <errc Errc, typename GetCachedBuildFT, typename BuildFT,
typename EvictFT = void *>
auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build,
EvictFT &&EvictFunc = nullptr) {
auto /* std::shared_ptr<BuildResult> */
getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build,
EvictFT &&EvictFunc = nullptr) {
using BuildState = KernelProgramCache::BuildState;
constexpr size_t MaxAttempts = 2;
for (size_t AttemptCounter = 0;; ++AttemptCounter) {
auto Res = GetCachedBuild();
auto /* std::pair<std::shared_ptr<BuildResult>, bool> */ Res =
GetCachedBuild();
auto &BuildResult = Res.first;
assert(BuildResult != nullptr);
BuildState Expected = BuildState::BS_Initial;
BuildState Desired = BuildState::BS_InProgress;
if (!BuildResult->State.compare_exchange_strong(Expected, Desired)) {
Expand Down Expand Up @@ -825,7 +833,7 @@ class KernelProgramCache {

void removeAllRelatedEntries(uint32_t ImageId) {
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
ProgramCache &ProgCache = LockedCache.get();

auto It = std::find_if(
ProgCache.KeyMap.begin(), ProgCache.KeyMap.end(),
Expand Down
24 changes: 10 additions & 14 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,15 +1016,13 @@ ProgramManager::getBuiltURProgram(const BinImgWithDeps &ImgWithDeps,
};

auto EvictFunc = [&Cache, &CacheKey](ur_program_handle_t Program,
bool isBuilt) {
return Cache.registerProgramFetch(CacheKey, Program, isBuilt);
bool isBuilt) -> void {
Cache.registerProgramFetch(CacheKey, Program, isBuilt);
};

auto BuildResult =
std::shared_ptr<KernelProgramCache::ProgramBuildResult> BuildResult =
Cache.getOrBuild<errc::build>(GetCachedBuildF, BuildF, EvictFunc);

// getOrBuild is not supposed to return nullptr
assert(BuildResult != nullptr && "Invalid build result");
assert(BuildResult && "getOrBuild isn't supposed to return nullptr!");

ur_program_handle_t ResProgram = BuildResult->Val;

Expand Down Expand Up @@ -1082,8 +1080,6 @@ ProgramManager::getBuiltURProgram(const BinImgWithDeps &ImgWithDeps,
Adapter.call<UrApiKind::urProgramRetain>(ResProgram);
}
CacheLinkedImages();
// getOrBuild is not supposed to return nullptr
assert(BuildResult != nullptr && "Invalid build result");
}
}

Expand Down Expand Up @@ -1155,9 +1151,9 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
Kernel, nullptr, ArgMask, Program, ContextImpl.getAdapter());
}

auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
// getOrBuild is not supposed to return nullptr
assert(BuildResult != nullptr && "Invalid build result");
std::shared_ptr<KernelProgramCache::KernelBuildResult> BuildResult =
Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
assert(BuildResult && "getOrBuild isn't supposed to return nullptr!");
const std::pair<ur_kernel_handle_t, const KernelArgMask *>
&KernelArgMaskPair = BuildResult->Val;
auto ret_val = std::make_shared<FastKernelCacheVal>(
Expand Down Expand Up @@ -3192,9 +3188,9 @@ ProgramManager::getOrCreateKernel(const context &Context,
return make_tuple(Kernel, nullptr, ArgMask);
}

auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
// getOrBuild is not supposed to return nullptr
assert(BuildResult != nullptr && "Invalid build result");
std::shared_ptr<KernelProgramCache::KernelBuildResult> BuildResult =
Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
assert(BuildResult && "getOrBuild isn't supposed to return nullptr!");
// If caching is enabled, one copy of the kernel handle will be
// stored in the cache, and one handle is returned to the
// caller. In that case, we need to increase the ref count of the
Expand Down