Uses MemPoolContext to route allocations from CUDACachingAllocator (#134685)

Re-open of pytorch/pytorch#133599 that was mistakenly closed by issuing `ghstack land`
Pull Request resolved: pytorch/pytorch#134685
Approved by: https://github.com/ezyang
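
In short: a MemPool constructed with a `CUDAPluggableAllocator` can now capture the caching allocator's raw allocations inside a `torch.cuda.use_mem_pool` region. A minimal sketch of the intended usage, mirroring the test added below (the .so path is a placeholder, not part of this commit):

    import torch

    # Placeholder path: a compiled library exposing dummy_alloc/dummy_free
    # with the signatures shown in test_mempool_with_allocator below.
    allocator = torch.cuda.memory.CUDAPluggableAllocator(
        "libdummy_allocator.so", "dummy_alloc", "dummy_free"
    )
    pool = torch.cuda.MemPool(allocator.allocator())

    # While the context is active, new CUDA allocations on the current
    # device come from dummy_alloc instead of cudaMalloc.
    with torch.cuda.use_mem_pool(pool):
        out = torch.randn(1, device="cuda")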
syed-ahmed authored and pytorchmergebot committed Aug 29, 2024
1 parent 4b4ba7a commit 4655eb3
Showing 7 changed files with 87 additions and 4 deletions.
7 changes: 6 additions & 1 deletion c10/cuda/CUDACachingAllocator.cpp
@@ -2674,7 +2674,12 @@ class DeviceCachingAllocator {
       // any potential exceptions in the cudaMallocMaybeCapturing function.
       auto sg = c10::make_scope_exit([&]() { lock.lock(); });
       lock.unlock();
-      p.err = cudaMallocMaybeCapturing(&ptr, size);
+    }
+    auto active_pool = MemPoolContext::getActiveMemPool();
+    if (active_pool && active_pool->allocator() &&
+        p.pool->owner_PrivatePool) {
+      ptr = active_pool->allocator()->raw_alloc(size);
+      p.err = ptr ? cudaSuccess : cudaErrorMemoryAllocation;
     } else {
       p.err = cudaMallocMaybeCapturing(&ptr, size);
     }
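In effect: when a MemPool backed by a pluggable allocator is active and the block being allocated belongs to a private pool, `raw_alloc` on that allocator replaces `cudaMallocMaybeCapturing`, and a null return is reported as `cudaErrorMemoryAllocation` so the caching allocator's normal failure handling still applies.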
3 changes: 3 additions & 0 deletions docs/source/cuda.rst
@@ -122,6 +122,9 @@ Memory management
     change_current_allocator
     MemPool
     MemPoolContext
+
+.. autoclass:: torch.cuda.use_mem_pool
+
 .. FIXME The following doesn't seem to exist. Is it supposed to?
    https://github.com/pytorch/pytorch/issues/27785
 .. autofunction:: reset_max_memory_reserved
33 changes: 31 additions & 2 deletions test/test_cuda.py
@@ -2,6 +2,7 @@

 import collections
 import contextlib
+import ctypes
 import gc
 import json
 import os
@@ -4806,10 +4807,25 @@ def test_mempool_with_allocator(self):

         dummy_allocator_source = """
         #include <torch/extension.h>
+        #include <ATen/cuda/Exceptions.h>
+        #include <cuda_runtime_api.h>
         extern "C" {
+          C10_EXPORT int called_dummy_alloc = 0;
+          C10_EXPORT int called_dummy_free = 0;
           // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
-          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { return nullptr; }
-          C10_EXPORT void dummy_free(void* ptr) { }
+          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
+            called_dummy_alloc = 123;
+            void* ptr;
+            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
+            return ptr;
+          }
+          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
+            called_dummy_free = 321;
+            C10_CUDA_CHECK(cudaFree(ptr));
+          }
         }
         """
         dummy_allocator_libname = "dummy_allocator"
@@ -4819,6 +4835,7 @@ def test_mempool_with_allocator(self):
             is_python_module=False,
             keep_intermediates=False,
             verbose=True,
+            with_cuda=True,
         )
         allocator = torch.cuda.memory.CUDAPluggableAllocator(
             dummy_allocator,
@@ -4830,6 +4847,18 @@ def test_mempool_with_allocator(self):
         # pool should point to the same allocator as the one passed into it
         self.assertEqual(allocator.allocator(), pool.allocator)
 
+        # no allocations happened yet, so called_dummy_alloc should be 0
+        alloc_lib = ctypes.CDLL(dummy_allocator)
+        called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
+        self.assertEqual(called_dummy_alloc.value, 0)
+
+        with torch.cuda.use_mem_pool(pool):
+            out = torch.randn(1, device="cuda")
+
+        # called_dummy_alloc should be 123 if dummy_alloc was used to allocate
+        # out tensor
+        self.assertEqual(called_dummy_alloc.value, 123)
+
     def test_mempool_context(self):
         active_pool = torch.cuda.MemPoolContext.active_pool()

1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
@@ -1831,6 +1831,7 @@ def _cuda_cudaHostAllocator() -> _int: ...
 def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
 def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
 def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
+def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
 def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
7 changes: 7 additions & 0 deletions torch/csrc/cuda/Module.cpp
@@ -1281,6 +1281,13 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         });
       });
 
+  m.def(
+      "_cuda_beginAllocateToPool",
+      [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
+        c10::cuda::CUDACachingAllocator::beginAllocateToPool(
+            device, mempool_id, [](cudaStream_t) { return true; });
+      });
+
   m.def(
       "_cuda_endAllocateCurrentStreamToPool",
       [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) {
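Note the stream filter: this binding passes a lambda that accepts every `cudaStream_t`, so allocations from all streams on the device are routed to the pool while the context is active. The pre-existing `_cuda_beginAllocateCurrentStreamToPool` binding, by contrast, only routes allocations made on the current stream.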
1 change: 1 addition & 0 deletions torch/cuda/__init__.py
@@ -1628,6 +1628,7 @@ def addmm_kernel_impl(*args, **kwargs):
"memory_usage",
"MemPool",
"MemPoolContext",
"use_mem_pool",
"temperature",
"power_draw",
"clock_rate",
39 changes: 38 additions & 1 deletion torch/cuda/memory.py
@@ -52,6 +52,7 @@
"change_current_allocator",
"MemPool",
"MemPoolContext",
"use_mem_pool",
]


@@ -64,8 +65,20 @@
     # Define dummy base classes
     torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
     torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
+    torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
+        "_cuda_beginAllocateToPool"
+    )
+    torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type(
+        "_cuda_endAllocateCurrentStreamToPool"
+    )
 
-from torch._C import _cuda_CUDAAllocator, _MemPool, _MemPoolContext  # noqa: F401
+from torch._C import (  # noqa: F401
+    _cuda_beginAllocateToPool,
+    _cuda_CUDAAllocator,
+    _cuda_endAllocateCurrentStreamToPool,
+    _MemPool,
+    _MemPoolContext,
+)


def _host_allocator():
@@ -1002,3 +1015,27 @@ def __init__(self, pool: MemPool):
     def active_pool() -> Optional[_MemPool]:
         r"""Returns the active MemPool"""
         return _MemPoolContext.active_pool()
+
+
+@contextlib.contextmanager
+def use_mem_pool(pool: MemPool, device: Union[Device, int] = None):
+    r"""A context manager that routes allocations to a given pool.
+
+    Args:
+        pool(torch.cuda.MemPool): a MemPool object to be made active so that
+            allocations route to this pool.
+        device (torch.device or int, optional): selected device. Uses MemPool on
+            the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    """
+    ctx = MemPoolContext(pool)
+    device_index = (
+        torch.cuda.current_device() if device is None else _get_device_index(device)
+    )
+    _cuda_beginAllocateToPool(device_index, pool.id)
+    try:
+        yield
+    finally:
+        _cuda_endAllocateCurrentStreamToPool(device_index, pool.id)
+        del ctx
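
Two details worth noting in `use_mem_pool`: the `MemPoolContext` is what `DeviceCachingAllocator` consults via `getActiveMemPool` (first hunk above) to pick the allocator, while `_cuda_beginAllocateToPool`/`_cuda_endAllocateCurrentStreamToPool` scope which allocations are tagged with the pool's id; `del ctx` then lets the previously active pool, saved when the context was constructed, become active again.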
