diff --git a/sycl/include/sycl/ext/oneapi/experimental/current_device.hpp b/sycl/include/sycl/ext/oneapi/experimental/current_device.hpp index a814728c57f16..7dc41bfc0970b 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/current_device.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/current_device.hpp @@ -15,8 +15,15 @@ inline namespace _V1 { namespace ext::oneapi::experimental::this_thread { namespace detail { -inline sycl::device &get_current_device_ref() { - static thread_local sycl::device current_device{sycl::default_selector_v}; +using namespace sycl::detail; +// Underlying `std::shared_ptr`'s lifetime is tied to the +// `global_handler`, so a subsequent `lock()` is expected to be successful when +// used from user app. We still go through `std::weak_ptr` here because our own +// unittests are linked statically against SYCL RT objects and have to implement +// some hacks to emulate the lifetime management done by the `global_handler`. +inline std::weak_ptr &get_current_device_impl() { + static thread_local std::weak_ptr current_device{ + getSyclObjImpl(sycl::device{sycl::default_selector_v})}; return current_device; } } // namespace detail @@ -28,7 +35,8 @@ inline sycl::device &get_current_device_ref() { /// @pre The function is called from a host thread, executing outside of a host /// task or an asynchronous error handler. inline sycl::device get_current_device() { - return detail::get_current_device_ref(); + return detail::createSyclObjFromImpl( + detail::get_current_device_impl().lock()); } /// @brief Sets the current default device to `dev` for the calling host thread. @@ -36,7 +44,7 @@ inline sycl::device get_current_device() { /// @pre The function is called from a host thread, executing outside of a host /// task or an asynchronous error handler. inline void set_current_device(sycl::device dev) { - detail::get_current_device_ref() = dev; + detail::get_current_device_impl() = detail::getSyclObjImpl(dev); } } // namespace ext::oneapi::experimental::this_thread diff --git a/sycl/source/detail/context_impl.cpp b/sycl/source/detail/context_impl.cpp index b6deeaba22d23..5e027466d7949 100644 --- a/sycl/source/detail/context_impl.cpp +++ b/sycl/source/detail/context_impl.cpp @@ -29,17 +29,15 @@ namespace sycl { inline namespace _V1 { namespace detail { -context_impl::context_impl(const std::vector Devices, - async_handler AsyncHandler, +context_impl::context_impl(devices_range Devices, async_handler AsyncHandler, const property_list &PropList, private_tag) : MOwnedByRuntime(true), MAsyncHandler(std::move(AsyncHandler)), - MDevices(std::move(Devices)), MContext(nullptr), - MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform())), - MPropList(PropList), MKernelProgramCache(*this), - MSupportBufferLocationByDevices(NotChecked) { + MDevices(Devices.to>()), MContext(nullptr), + MPlatform(MDevices[0]->getPlatformImpl()), MPropList(PropList), + MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) { verifyProps(PropList); std::vector DeviceIds; - for (const auto &D : MDevices) { + for (device_impl &D : devices_range{MDevices}) { if (D.has(aspect::ext_oneapi_is_composite)) { // Component devices are considered to be descendent devices from a // composite device and therefore context created for a composite @@ -52,7 +50,7 @@ context_impl::context_impl(const std::vector Devices, DeviceIds.push_back(getSyclObjImpl(CD)->getHandleRef()); } - DeviceIds.push_back(getSyclObjImpl(D)->getHandleRef()); + DeviceIds.push_back(D.getHandleRef()); } getAdapter().call( @@ -61,39 +59,42 @@ context_impl::context_impl(const std::vector Devices, context_impl::context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, adapter_impl &Adapter, - const std::vector &DeviceList, - bool OwnedByRuntime, private_tag) + devices_range DeviceList, bool OwnedByRuntime, + private_tag) : MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(std::move(AsyncHandler)), - MDevices(DeviceList), MContext(UrContext), MPlatform(), + MDevices([&]() { + if (!DeviceList.empty()) + return DeviceList.to>(); + + std::vector DeviceIds; + uint32_t DevicesNum = 0; + // TODO catch an exception and put it to list of asynchronous + // exceptions. + Adapter.call( + UrContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), + &DevicesNum, nullptr); + DeviceIds.resize(DevicesNum); + // TODO catch an exception and put it to list of asynchronous + // exceptions. + Adapter.call( + UrContext, UR_CONTEXT_INFO_DEVICES, + sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr); + + if (DeviceIds.empty()) + throw exception( + make_error_code(errc::invalid), + "No devices in the provided device list and native context."); + + platform_impl &Platform = + platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter); + std::vector Devices; + for (ur_device_handle_t Dev : DeviceIds) + Devices.emplace_back(&Platform.getOrMakeDeviceImpl(Dev)); + + return Devices; + }()), + MContext(UrContext), MPlatform(MDevices[0]->getPlatformImpl()), MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) { - if (!MDevices.empty()) { - MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform()); - } else { - std::vector DeviceIds; - uint32_t DevicesNum = 0; - // TODO catch an exception and put it to list of asynchronous exceptions - Adapter.call( - MContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum, - nullptr); - DeviceIds.resize(DevicesNum); - // TODO catch an exception and put it to list of asynchronous exceptions - Adapter.call( - MContext, UR_CONTEXT_INFO_DEVICES, - sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr); - - if (DeviceIds.empty()) - throw exception( - make_error_code(errc::invalid), - "No devices in the provided device list and native context."); - - platform_impl &Platform = - platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter); - for (ur_device_handle_t Dev : DeviceIds) { - MDevices.emplace_back( - createSyclObjFromImpl(Platform.getOrMakeDeviceImpl(Dev))); - } - MPlatform = Platform.shared_from_this(); - } // TODO catch an exception and put it to list of asynchronous exceptions // getAdapter() will be the same as the Adapter passed. This should be taken // care of when creating device object. @@ -144,12 +145,12 @@ uint32_t context_impl::get_info() const { this->getAdapter()); } template <> platform context_impl::get_info() const { - return createSyclObjFromImpl(*MPlatform); + return createSyclObjFromImpl(MPlatform); } template <> std::vector context_impl::get_info() const { - return MDevices; + return devices_range{MDevices}.to>(); } template <> std::vector @@ -219,7 +220,7 @@ context_impl::get_backend_info() const { "the info::platform::version info descriptor can " "only be queried with an OpenCL backend"); } - return MDevices[0].get_platform().get_info(); + return MDevices[0]->get_platform().get_info(); } #endif @@ -271,17 +272,17 @@ KernelProgramCache &context_impl::getKernelProgramCache() const { } bool context_impl::hasDevice(const detail::device_impl &Device) const { - for (auto D : MDevices) - if (getSyclObjImpl(D).get() == &Device) + for (device_impl *D : MDevices) + if (D == &Device) return true; return false; } device_impl * context_impl::findMatchingDeviceImpl(ur_device_handle_t &DeviceUR) const { - for (device D : MDevices) - if (getSyclObjImpl(D)->getHandleRef() == DeviceUR) - return getSyclObjImpl(D).get(); + for (device_impl *D : MDevices) + if (D->getHandleRef() == DeviceUR) + return D; return nullptr; } @@ -301,8 +302,8 @@ bool context_impl::isBufferLocationSupported() const { return MSupportBufferLocationByDevices == Supported ? true : false; // Check that devices within context have support of buffer location MSupportBufferLocationByDevices = Supported; - for (auto &Device : MDevices) { - if (!Device.has_extension("cl_intel_mem_alloc_buffer_location")) { + for (device_impl *Device : MDevices) { + if (!Device->has_extension("cl_intel_mem_alloc_buffer_location")) { MSupportBufferLocationByDevices = NotSupported; break; } diff --git a/sycl/source/detail/context_impl.hpp b/sycl/source/detail/context_impl.hpp index 6d97f1c9ca47e..91d10a235d360 100644 --- a/sycl/source/detail/context_impl.hpp +++ b/sycl/source/detail/context_impl.hpp @@ -47,9 +47,8 @@ class context_impl : public std::enable_shared_from_this { /// \param DeviceList is a list of SYCL device instances. /// \param AsyncHandler is an instance of async_handler. /// \param PropList is an instance of property_list. - context_impl(const std::vector DeviceList, - async_handler AsyncHandler, const property_list &PropList, - private_tag); + context_impl(devices_range DeviceList, async_handler AsyncHandler, + const property_list &PropList, private_tag); /// Construct a context_impl using plug-in interoperability handle. /// @@ -62,9 +61,8 @@ class context_impl : public std::enable_shared_from_this { /// \param OwnedByRuntime is the flag if ownership is kept by user or /// transferred to runtime context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, - adapter_impl &Adapter, - const std::vector &DeviceList, bool OwnedByRuntime, - private_tag); + adapter_impl &Adapter, devices_range DeviceList, + bool OwnedByRuntime, private_tag); context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, adapter_impl &Adapter, private_tag tag) @@ -94,10 +92,10 @@ class context_impl : public std::enable_shared_from_this { const async_handler &get_async_handler() const; /// \return the Adapter associated with the platform of this context. - adapter_impl &getAdapter() const { return MPlatform->getAdapter(); } + adapter_impl &getAdapter() const { return MPlatform.getAdapter(); } /// \return the PlatformImpl associated with this context. - platform_impl &getPlatformImpl() const { return *MPlatform; } + platform_impl &getPlatformImpl() const { return MPlatform; } /// Queries this context for information. /// @@ -191,10 +189,7 @@ class context_impl : public std::enable_shared_from_this { } // Returns the backend of this context - backend getBackend() const { - assert(MPlatform && "MPlatform must be not null"); - return MPlatform->getBackend(); - } + backend getBackend() const { return MPlatform.getBackend(); } /// Given a UR device, returns the matching shared_ptr /// within this context. May return nullptr if no match discovered. @@ -262,10 +257,9 @@ class context_impl : public std::enable_shared_from_this { private: bool MOwnedByRuntime; async_handler MAsyncHandler; - std::vector MDevices; + std::vector MDevices; ur_context_handle_t MContext; - // TODO: Make it a reference instead, but that needs a bit more refactoring: - std::shared_ptr MPlatform; + platform_impl &MPlatform; property_list MPropList; CachedLibProgramsT MCachedLibPrograms; std::mutex MCachedLibProgramsMutex;