Add CUDA caching allocator accessor
guillaumekln committed Mar 8, 2017
1 parent d4c2b1d commit f7c6799
Showing 4 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -51,6 +51,7 @@ With the caching memory allocator, device allocations and frees should logically
- `cutorch.getState()` - Returns the global state of the cutorch package. This state is not for users, it stores the raw RNG states, cublas handles and other thread and device-specific stuff.
- `cutorch.withDevice(devID, f)` - This is a convenience for multi-GPU code, that takes in a device ID as well as a function f. It switches cutorch to the new device, executes the function f, and switches back cutorch to the original device.
- `cutorch.createCudaHostTensor([...])` - Allocates a `torch.FloatTensor` of [host-pinned memory](https://devblogs.nvidia.com/parallelforall/how-optimize-data-transfers-cuda-cc/), where dimensions can be given as an argument list of sizes or a `torch.LongStorage`.
- `cutorch.isCachingAllocatorEnabled()` - Returns whether the caching CUDA memory allocator is enabled or not.

#### Low-level streams functions (don't use this as a user, easy to shoot yourself in the foot):
- `cutorch.reserveStreams(n [, nonblocking])`: creates n user streams for use on every device. NOTE: stream index `s` on device 1 is a different cudaStream_t than stream `s` on device 2. Takes an optional non-blocking flag; by default, this is assumed to be false. If true, then the stream is created with cudaStreamNonBlocking.
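For context, here is a minimal Lua sketch of how a user script might call the new `cutorch.isCachingAllocatorEnabled()` accessor documented in the README entry above. It is not part of this commit: it assumes cutorch is installed, and it assumes the caching allocator itself is toggled through the `THC_CACHING_ALLOCATOR` environment variable before the process starts.

```lua
-- Minimal usage sketch (not part of this commit).
-- Assumes cutorch is installed; the caching allocator is assumed to be
-- toggled via the THC_CACHING_ALLOCATOR environment variable at startup.
require 'cutorch'

if cutorch.isCachingAllocatorEnabled() then
  print('caching CUDA memory allocator: enabled')
else
  print('caching CUDA memory allocator: disabled')
end
```

Since the accessor only reads allocator state, it can be called at any point after `require 'cutorch'`.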
9 changes: 9 additions & 0 deletions init.c
@@ -699,6 +699,14 @@ static int cutorch_setKernelPeerToPeerAccess(lua_State *L)
  return 0;
}

static int cutorch_isCachingAllocatorEnabled(lua_State *L)
{
  THCState *state = cutorch_getstate(L);
  lua_pushboolean(L, THCState_isCachingAllocatorEnabled(state));

  return 1;
}

static int cutorch_getMemoryUsage(lua_State *L) {
  size_t freeBytes = 0;
  size_t totalBytes = 0;
@@ -993,6 +1001,7 @@ static const struct luaL_Reg cutorch_stuff__ [] = {
  {"setPeerToPeerAccess", cutorch_setPeerToPeerAccess},
  {"setKernelPeerToPeerAccess", cutorch_setKernelPeerToPeerAccess},
  {"getKernelPeerToPeerAccess", cutorch_getKernelPeerToPeerAccess},
  {"isCachingAllocatorEnabled", cutorch_isCachingAllocatorEnabled},
  {"getDeviceProperties", cutorch_getDeviceProperties},
  {"getRuntimeVersion", cutorch_getRuntimeVersion},
  {"getDriverVersion", cutorch_getDriverVersion},
4 changes: 4 additions & 0 deletions lib/THC/THCGeneral.c
@@ -303,6 +303,10 @@ void THCState_setDeviceAllocator(THCState* state, THCDeviceAllocator* allocator)
  state->cudaDeviceAllocator = allocator;
}

int THCState_isCachingAllocatorEnabled(THCState* state) {
  return state->cudaHostAllocator == &THCCachingHostAllocator;
}

int THCState_getNumDevices(THCState *state)
{
  return state->numDevices;
1 change: 1 addition & 0 deletions lib/THC/THCGeneral.h.in
@@ -141,6 +141,7 @@ THC_API THAllocator* THCState_getCudaHostAllocator(THCState* state);
THC_API THAllocator* THCState_getCudaUVAAllocator(THCState* state);
THC_API THCDeviceAllocator* THCState_getDeviceAllocator(THCState* state);
THC_API void THCState_setDeviceAllocator(THCState* state, THCDeviceAllocator* allocator);
THC_API int THCState_isCachingAllocatorEnabled(THCState* state);

THC_API void THCMagma_init(THCState *state);
