Skip to content

Commit

Permalink
[Intel GPU] Add XPU memory-related APIs (#129919)
Browse files Browse the repository at this point in the history
# Motivation
According to pytorch/pytorch#116322, we will help unify the device allocator. So we introduce a simple xpu device allocator only with the key functionality first. And expect to add some memory statistics-related functionality after the unification.
However, some memory-statistics APIs listed in pytorch/pytorch#127929 have now been requested, and fully unifying the device allocator will take more time. To improve the user experience, we are adding these memory-statistics APIs ahead of the unification.

# Additional Context
Fixes: #127929

Pull Request resolved: pytorch/pytorch#129919
Approved by: https://github.com/dvrogozh, https://github.com/abhilash1910, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #130923
  • Loading branch information
guangyey authored and pytorchmergebot committed Sep 7, 2024
1 parent 6c1da66 commit b53d97c
Show file tree
Hide file tree
Showing 8 changed files with 495 additions and 20 deletions.
133 changes: 130 additions & 3 deletions c10/xpu/XPUCachingAllocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

namespace c10::xpu::XPUCachingAllocator {

using namespace c10::CachingDeviceAllocator;

// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
// all sizes are rounded to at least 512 bytes
Expand Down Expand Up @@ -117,13 +119,15 @@ struct AllocParams {
BlockPool* pool;
size_t alloc_size;
Block* block;
StatTypes stat_types = {};
};

} // anonymous namespace

class DeviceCachingAllocator {
private:
mutable std::recursive_mutex mutex;
DeviceStats stats;
BlockPool large_blocks; // unallocated cached blocks larger than 1 MB
BlockPool small_blocks; // unallocated cached blocks 1 MB or smaller
ska::flat_hash_set<Block*> active_blocks; // allocated or in use by a stream
Expand Down Expand Up @@ -173,6 +177,12 @@ class DeviceCachingAllocator {
active_blocks.erase(block);
bool inserted = pool.blocks.insert(block).second;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);

StatTypes stat_types = get_stat_types_for_pool(pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.active_bytes[stat_type].decrease(block->size);
stats.requested_bytes[stat_type].decrease(block->requested_size);
});
}

void process_events() {
Expand Down Expand Up @@ -250,6 +260,9 @@ class DeviceCachingAllocator {
return false;
}
p.block = new Block(device, p.queue(), size, p.pool, ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(size);
});
return true;
}

Expand Down Expand Up @@ -281,6 +294,12 @@ class DeviceCachingAllocator {
sycl::free(block->ptr, xpu::get_device_context());
auto* pool = block->pool;
pool->blocks.erase(block);

StatTypes stat_types = get_stat_types_for_pool(*pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(block->size);
});

delete block;
}

Expand Down Expand Up @@ -314,6 +333,14 @@ class DeviceCachingAllocator {
}
}

// Builds the set of stat buckets that an operation on `pool` should
// update: always the AGGREGATE bucket, plus the bucket matching the
// pool's size class (SMALL_POOL vs. LARGE_POOL).
// Declared static because it reads no allocator instance state
// (clang-tidy: readability-convert-member-functions-to-static).
static StatTypes get_stat_types_for_pool(const BlockPool& pool) {
  StatTypes stat_types = {};
  stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
  stat_types[static_cast<size_t>(
      pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true;
  return stat_types;
}

Block* alloc_found_block(
AllocParams params,
size_t orig_size,
Expand Down Expand Up @@ -350,6 +377,12 @@ class DeviceCachingAllocator {
bool inserted = active_blocks.insert(block).second;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted)

for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
stats.allocated_bytes[stat_type].increase(block->size);
stats.active_bytes[stat_type].increase(block->size);
stats.requested_bytes[stat_type].increase(block->requested_size);
});

return block;
}

Expand All @@ -376,6 +409,7 @@ class DeviceCachingAllocator {
auto& pool = get_pool(size);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, &queue, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);

// First, try to get a block from the existing pool.
bool block_found = get_free_block(params);
Expand All @@ -384,9 +418,32 @@ class DeviceCachingAllocator {
block_found = alloc_block(params) ||
(release_cached_blocks() && alloc_block(params));
}
TORCH_CHECK(
block_found,
"XPU out of memory, please use `empty_cache` to release all unoccupied cached memory.");
if (!block_found) {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
auto allocated_bytes =
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto reserved_bytes =
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
TORCH_CHECK_WITH(
OutOfMemoryError,
false,
"XPU out of memory. Tried to allocate ",
format_size(alloc_size),
". GPU ",
static_cast<int>(device),
" has a total capacity of ",
format_size(device_total),
". Of the allocated memory ",
format_size(allocated_bytes),
" is allocated by PyTorch, and ",
format_size(reserved_bytes - allocated_bytes),
" is reserved by PyTorch but unallocated.",
" Please use `empty_cache` to release all unoccupied cached memory.");
}
bool split_remainder = should_split(params.block, params.size());
return alloc_found_block(std::move(params), orig_size, split_remainder);
}
Expand All @@ -395,6 +452,11 @@ class DeviceCachingAllocator {
std::scoped_lock<std::recursive_mutex> lock(mutex);
block->allocated = false;

StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.allocated_bytes[stat_type].decrease(block->size);
});

if (!block->stream_uses.empty()) {
insert_events(block);
} else {
Expand All @@ -414,6 +476,35 @@ class DeviceCachingAllocator {
std::scoped_lock<std::recursive_mutex> lock(mutex);
release_cached_blocks();
}

// Returns a snapshot of this device's memory statistics.
// The copy is taken under the lock so the caller sees a consistent view.
DeviceStats getStats() {
  const std::scoped_lock<std::recursive_mutex> guard(mutex);
  return stats;
}

// Clears the accumulated counters of every stat bucket
// (current and peak values are left untouched).
void resetAccumulatedStats() {
  const std::scoped_lock<std::recursive_mutex> guard(mutex);

  const auto num_types = static_cast<size_t>(StatType::NUM_TYPES);
  for (size_t stat_type = 0; stat_type < num_types; ++stat_type) {
    stats.allocated_bytes[stat_type].reset_accumulated();
    stats.reserved_bytes[stat_type].reset_accumulated();
    stats.active_bytes[stat_type].reset_accumulated();
    stats.requested_bytes[stat_type].reset_accumulated();
  }
}

// Resets the peak watermark of every stat bucket to its current value
// (accumulated and current values are left untouched).
void resetPeakStats() {
  const std::scoped_lock<std::recursive_mutex> guard(mutex);

  const auto num_types = static_cast<size_t>(StatType::NUM_TYPES);
  for (size_t stat_type = 0; stat_type < num_types; ++stat_type) {
    stats.allocated_bytes[stat_type].reset_peak();
    stats.reserved_bytes[stat_type].reset_peak();
    stats.active_bytes[stat_type].reset_peak();
    stats.requested_bytes[stat_type].reset_peak();
  }
}
};

void local_raw_delete(void* ptr);
Expand Down Expand Up @@ -547,6 +638,30 @@ class XPUAllocator : public Allocator {
// Copies `count` bytes from `src` to `dest` by enqueuing a memcpy on the
// current XPU stream's SYCL queue. NOTE(review): the returned sycl::event
// is intentionally discarded — completion is presumably stream-ordered;
// confirm callers do not rely on the copy being finished on return.
void copy_data(void* dest, const void* src, std::size_t count) const final {
  auto& queue = xpu::getCurrentXPUStream().queue();
  queue.memcpy(dest, src, count);
}

// Throws (via TORCH_CHECK) when `device` is outside the range of
// initialized per-device allocators.
void assertValidDevice(DeviceIndex device) {
  const auto device_count = static_cast<int64_t>(device_allocators.size());
  TORCH_CHECK(
      device >= 0 && device < device_count,
      "Invalid device argument ",
      device,
      ": did you call init?");
}

// Validates `device`, then returns that device allocator's stats snapshot.
DeviceStats getDeviceStats(DeviceIndex device) {
  assertValidDevice(device);
  auto& device_allocator = device_allocators[device];
  return device_allocator->getStats();
}

// Validates `device`, then resets that device allocator's peak watermarks.
void resetPeakStats(DeviceIndex device) {
  assertValidDevice(device);
  auto& device_allocator = device_allocators[device];
  device_allocator->resetPeakStats();
}

// Validates `device`, then resets that device allocator's accumulated counters.
void resetAccumulatedStats(DeviceIndex device) {
  assertValidDevice(device);
  auto& device_allocator = device_allocators[device];
  device_allocator->resetAccumulatedStats();
}
};

static XPUAllocator allocator;
Expand All @@ -567,6 +682,18 @@ void emptyCache() {
return allocator.emptyCache();
}

// Public entry point: reset peak memory statistics for `device`.
void resetPeakStats(DeviceIndex device) {
  allocator.resetPeakStats(device);
}

// Public entry point: reset accumulated memory statistics for `device`.
void resetAccumulatedStats(DeviceIndex device) {
  allocator.resetAccumulatedStats(device);
}

// Public entry point: fetch the memory statistics snapshot for `device`.
DeviceStats getDeviceStats(DeviceIndex device) {
  return allocator.getDeviceStats(device);
}

void* raw_alloc(size_t size) {
return allocator.raw_alloc(size);
}
Expand Down
9 changes: 8 additions & 1 deletion c10/xpu/XPUCachingAllocator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <c10/core/Allocator.h>
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/xpu/XPUStream.h>

namespace c10::xpu::XPUCachingAllocator {
Expand All @@ -11,6 +11,13 @@ C10_XPU_API void init(DeviceIndex device_count);

C10_XPU_API void emptyCache();

C10_XPU_API void resetPeakStats(DeviceIndex device);

C10_XPU_API void resetAccumulatedStats(DeviceIndex device);

C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
DeviceIndex device);

C10_XPU_API void* raw_alloc(size_t size);

C10_XPU_API void raw_delete(void* ptr);
Expand Down
21 changes: 19 additions & 2 deletions docs/source/xpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ torch.xpu
device
device_count
device_of
empty_cache
get_device_capability
get_device_name
get_device_properties
Expand Down Expand Up @@ -51,7 +50,25 @@ Streams and events
Stream


Memory management
-----------------
.. autosummary::
:toctree: generated
:nosignatures:

empty_cache
max_memory_allocated
max_memory_reserved
memory_allocated
memory_reserved
memory_stats
memory_stats_as_nested_dict
reset_accumulated_memory_stats
reset_peak_memory_stats


.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
.. py:module:: torch.xpu.memory
.. py:module:: torch.xpu.random
.. py:module:: torch.xpu.streams
38 changes: 38 additions & 0 deletions test/test_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,44 @@ def test_serialization_array_with_empty(self):
self.assertIs(type(copy), type(original))
self.assertEqual(copy.get_device(), original.get_device())

def test_out_of_memory(self):
    # Hold a live allocation so the allocator has outstanding memory
    # when the oversized requests below are attempted.
    tensor = torch.zeros(1024, device="xpu")

    # An oversized request should report the requested size in the error.
    with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"):
        torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="xpu")

    # An even larger request is expected to surface the generic
    # "XPU out of memory." message instead of the size line.
    with self.assertRaisesRegex(RuntimeError, "XPU out of memory."):
        torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="xpu")

def test_raises_oom(self):
    # Start from an empty cache, then confirm an impossible allocation
    # raises the dedicated torch.OutOfMemoryError type (not a bare
    # RuntimeError).
    torch.xpu.memory.empty_cache()
    with self.assertRaises(torch.OutOfMemoryError):
        torch.empty(1024 * 1024 * 1024 * 1024, device="xpu")

def test_memory_allocation(self):
    # Drop cached blocks so the reserved-memory check at the end can
    # expect an exact zero.
    torch.xpu.empty_cache()
    baseline = torch.xpu.memory_allocated()

    # A live tensor must raise allocated (and reserved) byte counts.
    live = torch.ones(10, device="xpu")
    self.assertGreater(torch.xpu.memory_allocated(), baseline)
    self.assertGreater(torch.xpu.memory_reserved(), 0)

    # Freeing the tensor returns allocated bytes to the baseline;
    # emptying the cache releases the reserved bytes as well.
    del live
    self.assertEqual(torch.xpu.memory_allocated(), baseline)
    torch.xpu.empty_cache()
    self.assertEqual(torch.xpu.memory_reserved(), 0)

@unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
def test_device_memory_allocated(self):
    # Snapshot the per-device allocation counters before touching device 0.
    device_count = torch.xpu.device_count()
    before = [torch.xpu.memory_allocated(idx) for idx in range(device_count)]

    x = torch.ones(10, device="xpu:0")
    # Device 0 must grow...
    self.assertGreater(torch.xpu.memory_allocated(0), before[0])
    # ...while every other device stays unchanged.
    for idx in range(1, device_count):
        self.assertEqual(torch.xpu.memory_allocated(idx), before[idx])


instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)

Expand Down
3 changes: 3 additions & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -2108,6 +2108,9 @@ def _xpu_getCurrentStream(device: _int) -> Tuple: ...
def _xpu_getCurrentRawStream(device: _int) -> _int: ...
def _xpu_synchronize(device: _int) -> None: ...
def _xpu_emptyCache() -> None: ...
def _xpu_memoryStats(device: _int) -> Dict[str, Any]: ...
def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _xpu_resetPeakMemoryStats(device: _int) -> None: ...

class _XpuDeviceProperties:
name: str
Expand Down
Loading

0 comments on commit b53d97c

Please sign in to comment.