128 changes: 88 additions & 40 deletions unified-runtime/source/adapters/level_zero/enqueued_pool.cpp
@@ -13,32 +13,71 @@

#include <ur_api.h>

EnqueuedPool::~EnqueuedPool() { cleanup(); }
namespace {

std::optional<EnqueuedPool::Allocation>
EnqueuedPool::getBestFit(size_t Size, size_t Alignment, void *Queue) {
auto Lock = std::lock_guard(Mutex);
getBestFitHelper(size_t Size, size_t Alignment, void *Queue,
EnqueuedPool::AllocationGroupMap &Freelist) {
// Iterate over the alignments for a given queue.
auto GroupIt = Freelist.lower_bound({Queue, Alignment});
for (; GroupIt != Freelist.end() && GroupIt->first.Queue == Queue;
++GroupIt) {
auto &AllocSet = GroupIt->second;
// Find the first allocation that is large enough.
auto AllocIt = AllocSet.lower_bound({nullptr, Size, nullptr, nullptr, 0});
if (AllocIt != AllocSet.end()) {
auto BestFit = *AllocIt;
AllocSet.erase(AllocIt);
if (AllocSet.empty()) {
Freelist.erase(GroupIt);
}
return BestFit;
}
}
return std::nullopt;
}

Allocation Alloc = {nullptr, Size, nullptr, Queue, Alignment};
void removeFromFreelist(const EnqueuedPool::Allocation &Alloc,
EnqueuedPool::AllocationGroupMap &Freelist,
bool IsGlobal) {
const EnqueuedPool::AllocationGroupKey Key = {
IsGlobal ? nullptr : Alloc.Queue, Alloc.Alignment};

auto It = Freelist.lower_bound(Alloc);
if (It != Freelist.end() && It->Size >= Size && It->Queue == Queue &&
It->Alignment >= Alignment) {
Allocation BestFit = *It;
Freelist.erase(It);
auto GroupIt = Freelist.find(Key);
assert(GroupIt != Freelist.end() && "Allocation group not found in freelist");

return BestFit;
auto &AllocSet = GroupIt->second;
auto AllocIt = AllocSet.find(Alloc);
assert(AllocIt != AllocSet.end() && "Allocation not found in group");

AllocSet.erase(AllocIt);
if (AllocSet.empty()) {
Freelist.erase(GroupIt);
}
}

// To make sure there's no match on other queues, we need to reset it to
// nullptr and try again.
Alloc.Queue = nullptr;
It = Freelist.lower_bound(Alloc);
} // namespace

if (It != Freelist.end() && It->Size >= Size && It->Alignment >= Alignment) {
Allocation BestFit = *It;
Freelist.erase(It);
EnqueuedPool::~EnqueuedPool() { cleanup(); }

std::optional<EnqueuedPool::Allocation>
EnqueuedPool::getBestFit(size_t Size, size_t Alignment, void *Queue) {
auto Lock = std::lock_guard(Mutex);

// First, try to find the best fit in the queue-specific freelist.
auto BestFit = getBestFitHelper(Size, Alignment, Queue, FreelistByQueue);
if (BestFit) {
// Remove the allocation from the global freelist as well.
removeFromFreelist(*BestFit, FreelistGlobal, true);
return BestFit;
}

// If no fit was found in the queue-specific freelist, try the global
// freelist.
BestFit = getBestFitHelper(Size, Alignment, nullptr, FreelistGlobal);
if (BestFit) {
// Remove the allocation from the queue-specific freelist.
removeFromFreelist(*BestFit, FreelistByQueue, false);
return BestFit;
}

@@ -52,45 +91,54 @@ void EnqueuedPool::insert(void *Ptr, size_t Size, ur_event_handle_t Event,
uintptr_t Address = (uintptr_t)Ptr;
size_t Alignment = Address & (~Address + 1);

Freelist.emplace(Allocation{Ptr, Size, Event, Queue, Alignment});
Allocation Alloc = {Ptr, Size, Event, Queue, Alignment};
FreelistByQueue[{Queue, Alignment}].emplace(Alloc);
FreelistGlobal[{nullptr, Alignment}].emplace(Alloc);
}

bool EnqueuedPool::cleanup() {
auto Lock = std::lock_guard(Mutex);
auto FreedAllocations = !Freelist.empty();
auto FreedAllocations = !FreelistGlobal.empty();

auto Ret [[maybe_unused]] = UR_RESULT_SUCCESS;
for (auto It : Freelist) {
Ret = MemFreeFn(It.Ptr);
assert(Ret == UR_RESULT_SUCCESS);

if (It.Event)
EventReleaseFn(It.Event);
for (const auto &[GroupKey, AllocSet] : FreelistGlobal) {
for (const auto &Alloc : AllocSet) {
Ret = MemFreeFn(Alloc.Ptr);
assert(Ret == UR_RESULT_SUCCESS);

if (Alloc.Event) {
EventReleaseFn(Alloc.Event);
}
}
}
Freelist.clear();

FreelistGlobal.clear();
FreelistByQueue.clear();

return FreedAllocations;
}

bool EnqueuedPool::cleanupForQueue(void *Queue) {
auto Lock = std::lock_guard(Mutex);

Allocation Alloc = {nullptr, 0, nullptr, Queue, 0};
// first allocation on the freelist with the specific queue
auto It = Freelist.lower_bound(Alloc);

bool FreedAllocations = false;

auto Ret [[maybe_unused]] = UR_RESULT_SUCCESS;
while (It != Freelist.end() && It->Queue == Queue) {
Ret = MemFreeFn(It->Ptr);
assert(Ret == UR_RESULT_SUCCESS);

if (It->Event)
EventReleaseFn(It->Event);

// Erase the current allocation and move to the next one
It = Freelist.erase(It);
auto GroupIt = FreelistByQueue.lower_bound({Queue, 0});
while (GroupIt != FreelistByQueue.end() && GroupIt->first.Queue == Queue) {
auto &AllocSet = GroupIt->second;
for (const auto &Alloc : AllocSet) {
Ret = MemFreeFn(Alloc.Ptr);
assert(Ret == UR_RESULT_SUCCESS);

if (Alloc.Event) {
EventReleaseFn(Alloc.Event);
}

removeFromFreelist(Alloc, FreelistGlobal, true);
}

// Move to the next group.
GroupIt = FreelistByQueue.erase(GroupIt);
FreedAllocations = true;
}

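For orientation, here is a minimal standalone sketch of the new freelist layout used above: allocations are bucketed by {queue, alignment} and size-ordered inside each bucket, so a best-fit lookup is a lower_bound on the bucket key followed by a lower_bound on the size. The names below (Alloc, BySize, GroupKey, bestFit) are illustrative only, not the adapter's real types or callbacks.

#include <cstddef>
#include <map>
#include <optional>
#include <set>
#include <tuple>
#include <utility>

struct Alloc {
  void *Ptr;
  std::size_t Size;
};

// Inside a group, order by size (then pointer) so lower_bound yields the
// smallest block that still fits.
struct BySize {
  bool operator()(const Alloc &L, const Alloc &R) const {
    return std::tie(L.Size, L.Ptr) < std::tie(R.Size, R.Ptr);
  }
};

using GroupKey = std::pair<void *, std::size_t>; // {queue, alignment}
using Freelist = std::map<GroupKey, std::set<Alloc, BySize>>;

// Return (and remove) the smallest block of at least Size bytes whose group
// alignment is at least Alignment and whose group queue matches Queue.
std::optional<Alloc> bestFit(Freelist &FL, void *Queue, std::size_t Size,
                             std::size_t Alignment) {
  for (auto GroupIt = FL.lower_bound({Queue, Alignment});
       GroupIt != FL.end() && GroupIt->first.first == Queue; ++GroupIt) {
    auto &Set = GroupIt->second;
    auto It = Set.lower_bound(Alloc{nullptr, Size}); // first block >= Size
    if (It == Set.end())
      continue; // no block in this alignment group is large enough
    Alloc Found = *It;
    Set.erase(It);
    if (Set.empty())
      FL.erase(GroupIt); // drop empty groups, as the real code does
    return Found;
  }
  return std::nullopt;
}

The real getBestFit runs this lookup twice: first on FreelistByQueue with the caller's queue, then on FreelistGlobal with a nullptr queue, and removes the winner from the other map via removeFromFreelist so the two stay in sync.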
42 changes: 31 additions & 11 deletions unified-runtime/source/adapters/level_zero/enqueued_pool.hpp
@@ -13,6 +13,7 @@

#include "ur_api.h"
#include "ur_pool_manager.hpp"
#include <map>
#include <set>
#include <umf_helpers.hpp>

@@ -43,25 +44,44 @@ class EnqueuedPool {
bool cleanup();
bool cleanupForQueue(void *Queue);

private:
struct Comparator {
bool operator()(const Allocation &lhs, const Allocation &rhs) const {
// Allocations are grouped by queue and alignment.
struct AllocationGroupKey {
Contributor review comment: We should be consistent, either use a custom comparator or implement operator< in the key for both types. (A sketch of the operator< alternative follows this header's diff below.)

void *Queue;
size_t Alignment;
};

struct GroupComparator {
bool operator()(const AllocationGroupKey &lhs,
const AllocationGroupKey &rhs) const {
if (lhs.Queue != rhs.Queue) {
return lhs.Queue < rhs.Queue; // Compare by queue handle first
}
if (lhs.Alignment != rhs.Alignment) {
return lhs.Alignment < rhs.Alignment; // Then by alignment
return lhs.Queue < rhs.Queue;
}
return lhs.Alignment < rhs.Alignment;
}
};

// Then, the allocations are sorted by size.
struct SizeComparator {
bool operator()(const Allocation &lhs, const Allocation &rhs) const {
if (lhs.Size != rhs.Size) {
return lhs.Size < rhs.Size; // Then by size
return lhs.Size < rhs.Size;
}
return lhs.Ptr < rhs.Ptr; // Finally by pointer address
return lhs.Ptr < rhs.Ptr;
}
};

using AllocationSet = std::set<Allocation, Comparator>;
using AllocationGroup = std::set<Allocation, SizeComparator>;
using AllocationGroupMap =
std::map<AllocationGroupKey, AllocationGroup, GroupComparator>;

private:
ur_mutex Mutex;
AllocationSet Freelist;

// Freelist grouped by queue and alignment.
AllocationGroupMap FreelistByQueue;
// Freelist grouped by alignment only.
AllocationGroupMap FreelistGlobal;

event_release_callback_t EventReleaseFn;
memory_free_callback_t MemFreeFn;
};
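On the review comment above about consistency: one option is to give AllocationGroupKey a member operator< and drop GroupComparator, and to do the same for Allocation so both containers use the default std::less. A hypothetical sketch, assuming the key stays a plain {Queue, Alignment} aggregate:

#include <cstddef>
#include <tuple>

struct AllocationGroupKey {
  void *Queue;
  std::size_t Alignment;

  // Same ordering as GroupComparator: by queue handle, then by alignment.
  bool operator<(const AllocationGroupKey &Other) const {
    return std::tie(Queue, Alignment) < std::tie(Other.Queue, Other.Alignment);
  }
};

// The map could then default its comparator:
// using AllocationGroupMap = std::map<AllocationGroupKey, AllocationGroup>;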
44 changes: 44 additions & 0 deletions unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp
@@ -758,3 +758,47 @@ TEST_P(urL0EnqueueAllocMultiQueueMultiDeviceTest,
ASSERT_NE(freeEvent, nullptr);
}
}

using urL0EnqueueAllocStandaloneTest = uur::urQueueTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE(urL0EnqueueAllocStandaloneTest);

TEST_P(urL0EnqueueAllocStandaloneTest, ReuseFittingAllocation) {
ur_usm_pool_handle_t pool = nullptr;
ur_usm_pool_desc_t pool_desc = {};
ASSERT_SUCCESS(urUSMPoolCreate(context, &pool_desc, &pool));

auto makeAllocation = [&](uint32_t alignment, size_t size, void **ptr) {
const ur_usm_device_desc_t usm_device_desc{
UR_STRUCTURE_TYPE_USM_DEVICE_DESC, nullptr,
/* device flags */ 0};

const ur_usm_desc_t usm_desc{UR_STRUCTURE_TYPE_USM_DESC, &usm_device_desc,
UR_USM_ADVICE_FLAG_DEFAULT, alignment};

ASSERT_SUCCESS(
urUSMDeviceAlloc(context, device, &usm_desc, pool, size, ptr));
};

std::array<void *, 4> allocations = {};
makeAllocation(64, 128, &allocations[0]);
makeAllocation(64, 256, &allocations[1]);
makeAllocation(4096, 512, &allocations[2]);
makeAllocation(4096, 8192, &allocations[3]);

ASSERT_SUCCESS(
urEnqueueUSMFreeExp(queue, pool, allocations[0], 0, nullptr, nullptr));
ASSERT_SUCCESS(
urEnqueueUSMFreeExp(queue, pool, allocations[1], 0, nullptr, nullptr));
ASSERT_SUCCESS(
urEnqueueUSMFreeExp(queue, pool, allocations[2], 0, nullptr, nullptr));
ASSERT_SUCCESS(
urEnqueueUSMFreeExp(queue, pool, allocations[3], 0, nullptr, nullptr));

void *ptr = nullptr;
ASSERT_SUCCESS(urEnqueueUSMDeviceAllocExp(queue, pool, 8192, nullptr, 0,
nullptr, &ptr, nullptr));

ASSERT_EQ(ptr, allocations[3]); // Fitting allocation should be reused.
ASSERT_SUCCESS(urEnqueueUSMFreeExp(queue, pool, ptr, 0, nullptr, nullptr));
ASSERT_SUCCESS(urQueueFinish(queue));
}
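A hypothetical follow-up test (not part of this PR): since getBestFit falls back to the global freelist, a block freed on one queue should be reusable from another queue. The second queue, the sizes, and the reuse expectation below are assumptions for illustration, not asserted behavior of the adapter.

TEST_P(urL0EnqueueAllocStandaloneTest, ReuseAcrossQueuesSketch) {
  ur_usm_pool_handle_t pool = nullptr;
  ur_usm_pool_desc_t pool_desc = {};
  ASSERT_SUCCESS(urUSMPoolCreate(context, &pool_desc, &pool));

  // Allocate a block and free it asynchronously on the fixture's queue.
  void *ptr = nullptr;
  ASSERT_SUCCESS(urUSMDeviceAlloc(context, device, nullptr, pool, 4096, &ptr));
  ASSERT_SUCCESS(urEnqueueUSMFreeExp(queue, pool, ptr, 0, nullptr, nullptr));

  // Request a same-sized block from a second queue in the same context.
  ur_queue_handle_t otherQueue = nullptr;
  ASSERT_SUCCESS(urQueueCreate(context, device, nullptr, &otherQueue));

  void *reused = nullptr;
  ASSERT_SUCCESS(urEnqueueUSMDeviceAllocExp(otherQueue, pool, 4096, nullptr, 0,
                                            nullptr, &reused, nullptr));

  // Assumption: the freed block is served from the global freelist.
  ASSERT_EQ(reused, ptr);

  ASSERT_SUCCESS(
      urEnqueueUSMFreeExp(otherQueue, pool, reused, 0, nullptr, nullptr));
  ASSERT_SUCCESS(urQueueFinish(otherQueue));
  ASSERT_SUCCESS(urQueueRelease(otherQueue));
}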