diff --git a/sycl/source/detail/kernel_program_cache.hpp b/sycl/source/detail/kernel_program_cache.hpp index 23dfa3e122c15..8e9fa563f8874 100644 --- a/sycl/source/detail/kernel_program_cache.hpp +++ b/sycl/source/detail/kernel_program_cache.hpp @@ -404,7 +404,7 @@ class KernelProgramCache { std::pair, bool> getOrInsertProgram(const ProgramCacheKeyT &CacheKey) { auto LockedCache = acquireCachedPrograms(); - auto &ProgCache = LockedCache.get(); + ProgramCache &ProgCache = LockedCache.get(); auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr); if (DidInsert) { It->second = std::make_shared(getAdapter()); @@ -426,7 +426,7 @@ class KernelProgramCache { bool insertBuiltProgram(const ProgramCacheKeyT &CacheKey, ur_program_handle_t Program) { auto LockedCache = acquireCachedPrograms(); - auto &ProgCache = LockedCache.get(); + ProgramCache &ProgCache = LockedCache.get(); auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr); if (DidInsert) { It->second = std::make_shared(getAdapter(), @@ -491,7 +491,7 @@ class KernelProgramCache { // Save kernel in fast cache only if the corresponding program is also // in the cache. auto LockedCache = acquireCachedPrograms(); - auto &ProgCache = LockedCache.get(); + ProgramCache &ProgCache = LockedCache.get(); if (ProgCache.ProgramSizeMap.find(CacheVal->MProgramHandle) == ProgCache.ProgramSizeMap.end()) return; @@ -631,7 +631,7 @@ class KernelProgramCache { while (CurrCacheSize > DesiredCacheSize && !MEvictionList.empty()) { ProgramCacheKeyT CacheKey = ProgramEvictionList.front(); auto LockedCache = acquireCachedPrograms(); - auto &ProgCache = LockedCache.get(); + ProgramCache &ProgCache = LockedCache.get(); CurrCacheSize = removeProgramByKey(CacheKey, ProgCache); // Remove the program from the eviction list. MEvictionList.popFront(); @@ -748,15 +748,23 @@ class KernelProgramCache { /// /// \return a pointer to cached build result, return value must not be /// nullptr. + /// + /// Note that build result might be immediately evicted (if it's bigger than + /// current threshold), so the caller *must* assume (potentially shared) + /// ownership. In other words, `std::shared_ptr` in the return type is + /// unavoidable. template - auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build, - EvictFT &&EvictFunc = nullptr) { + auto /* std::shared_ptr */ + getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build, + EvictFT &&EvictFunc = nullptr) { using BuildState = KernelProgramCache::BuildState; constexpr size_t MaxAttempts = 2; for (size_t AttemptCounter = 0;; ++AttemptCounter) { - auto Res = GetCachedBuild(); + auto /* std::pair, bool> */ Res = + GetCachedBuild(); auto &BuildResult = Res.first; + assert(BuildResult != nullptr); BuildState Expected = BuildState::BS_Initial; BuildState Desired = BuildState::BS_InProgress; if (!BuildResult->State.compare_exchange_strong(Expected, Desired)) { @@ -825,7 +833,7 @@ class KernelProgramCache { void removeAllRelatedEntries(uint32_t ImageId) { auto LockedCache = acquireCachedPrograms(); - auto &ProgCache = LockedCache.get(); + ProgramCache &ProgCache = LockedCache.get(); auto It = std::find_if( ProgCache.KeyMap.begin(), ProgCache.KeyMap.end(), diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index fb8b36915bb9e..485b94e36f658 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -1016,15 +1016,13 @@ ProgramManager::getBuiltURProgram(const BinImgWithDeps &ImgWithDeps, }; auto EvictFunc = [&Cache, &CacheKey](ur_program_handle_t Program, - bool isBuilt) { - return Cache.registerProgramFetch(CacheKey, Program, isBuilt); + bool isBuilt) -> void { + Cache.registerProgramFetch(CacheKey, Program, isBuilt); }; - auto BuildResult = + std::shared_ptr BuildResult = Cache.getOrBuild(GetCachedBuildF, BuildF, EvictFunc); - - // getOrBuild is not supposed to return nullptr - assert(BuildResult != nullptr && "Invalid build result"); + assert(BuildResult && "getOrBuild isn't supposed to return nullptr!"); ur_program_handle_t ResProgram = BuildResult->Val; @@ -1082,8 +1080,6 @@ ProgramManager::getBuiltURProgram(const BinImgWithDeps &ImgWithDeps, Adapter.call(ResProgram); } CacheLinkedImages(); - // getOrBuild is not supposed to return nullptr - assert(BuildResult != nullptr && "Invalid build result"); } } @@ -1155,9 +1151,9 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel( Kernel, nullptr, ArgMask, Program, ContextImpl.getAdapter()); } - auto BuildResult = Cache.getOrBuild(GetCachedBuildF, BuildF); - // getOrBuild is not supposed to return nullptr - assert(BuildResult != nullptr && "Invalid build result"); + std::shared_ptr BuildResult = + Cache.getOrBuild(GetCachedBuildF, BuildF); + assert(BuildResult && "getOrBuild isn't supposed to return nullptr!"); const std::pair &KernelArgMaskPair = BuildResult->Val; auto ret_val = std::make_shared( @@ -3192,9 +3188,9 @@ ProgramManager::getOrCreateKernel(const context &Context, return make_tuple(Kernel, nullptr, ArgMask); } - auto BuildResult = Cache.getOrBuild(GetCachedBuildF, BuildF); - // getOrBuild is not supposed to return nullptr - assert(BuildResult != nullptr && "Invalid build result"); + std::shared_ptr BuildResult = + Cache.getOrBuild(GetCachedBuildF, BuildF); + assert(BuildResult && "getOrBuild isn't supposed to return nullptr!"); // If caching is enabled, one copy of the kernel handle will be // stored in the cache, and one handle is returned to the // caller. In that case, we need to increase the ref count of the