Skip to content

Commit

Permalink
cutorch gc
Browse files Browse the repository at this point in the history
  • Loading branch information
adamlerer committed Aug 19, 2015
1 parent a16af4b commit 843374d
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 29 deletions.
19 changes: 19 additions & 0 deletions init.c
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,24 @@ static int cutorch_getState(lua_State *L)
return 1;
}

static void luaCutorchGCFunction(void *data)
{
lua_State *L = data;
lua_gc(L, LUA_GCCOLLECT, 0);
}

/* cutorch.setHeapTracking(enabled)
   Installs (enabled == true) or removes (enabled == false) the Lua GC
   hook on the THC state. With the hook installed, THC can ask Lua to
   collect garbage when device allocations fail or the heap grows.
   Returns no values to Lua. */
static int cutorch_setHeapTracking(lua_State *L)
{
  THCState *state = cutorch_getstate(L);
  if (luaT_checkboolean(L, 1)) {
    THCSetGCHandler(state, luaCutorchGCFunction, L);
  } else {
    THCSetGCHandler(state, NULL, NULL);
  }
  return 0;
}

static const struct luaL_Reg cutorch_stuff__ [] = {
{"synchronize", cutorch_synchronize},
{"reserveBlasHandles", cutorch_reserveBlasHandles},
Expand Down Expand Up @@ -753,6 +771,7 @@ static const struct luaL_Reg cutorch_stuff__ [] = {
{"getRNGState", cutorch_getRNGState},
{"setRNGState", cutorch_setRNGState},
{"getState", cutorch_getState},
{"setHeapTracking", cutorch_setHeapTracking},
{NULL, NULL}
};

Expand Down
2 changes: 2 additions & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ function cutorch.createCudaHostTensor(...)
return torch.FloatTensor(storage, 1, size:storage())
end

cutorch.setHeapTracking(true)

return cutorch
64 changes: 61 additions & 3 deletions lib/THC/THCGeneral.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void THCudaInit(THCState* state)

/* Allocate scratch space for each stream */
res->devScratchSpacePerStream = (void**) malloc(sizeof(void*));
THCudaCheck(cudaMalloc(&res->devScratchSpacePerStream[0],
THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0],
sizePerStream));
}

Expand All @@ -71,6 +71,10 @@ void THCudaInit(THCState* state)
THCState_reserveBlasHandles(state, 1);
state->currentPerDeviceBlasHandle = 1;
state->currentBlasHandle = THCState_getDeviceBlasHandle(state, device, 1);

state->cutorchGCFunction = NULL;
state->cutorchGCData = NULL;
state->heapSoftmax = 300000000; // 300MB, adjusted upward dynamically
}

void THCudaShutdown(THCState* state)
Expand Down Expand Up @@ -100,7 +104,7 @@ void THCudaShutdown(THCState* state)
/* Free per-stream scratch space; starts at 0 because there is space for
the default stream as well*/
for (int stream = 0; stream <= state->numUserStreams; ++stream) {
THCudaCheck(cudaFree(THCState_getDeviceScratchSpace(state, dev, stream)));
THCudaCheck(THCudaFree(state, THCState_getDeviceScratchSpace(state, dev, stream)));
}

free(state->resourcesPerDevice[dev].streams);
Expand Down Expand Up @@ -199,7 +203,7 @@ void THCState_reserveStreams(THCState* state, int numStreams)
newStreams[stream] = NULL;
THCudaCheck(cudaStreamCreate(newStreams + stream));
newScratchSpace[stream] = NULL;
THCudaCheck(cudaMalloc(&newScratchSpace[stream], scratchSpaceSize));
THCudaCheck(THCudaMalloc(state, &newScratchSpace[stream], scratchSpaceSize));
}

THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
Expand Down Expand Up @@ -489,4 +493,58 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
}
}

/* Process-wide bookkeeping for THC-allocated device storage.
   heapSize is shared by all threads (updated atomically, see THCHeapUpdate);
   the two constants control how the per-state softmax threshold adapts
   after a GC pass (see maybeTriggerGC). */
static long heapSize = 0; // total tracked bytes; shared, not thread-local
static const double heapSoftmaxGrowthThresh = 0.8; // grow softmax if >80% max after GC
static const double heapSoftmaxGrowthFactor = 1.4; // grow softmax by 40%

/* Registers the callback THC invokes when it wants the host-side
   garbage collector to run (e.g. on allocation failure or heap growth).
   Passing NULL for both arguments uninstalls the handler. */
void THCSetGCHandler(THCState *state, void (*cutorchGCFunction_)(void *data), void *data)
{
  state->cutorchGCData = data;
  state->cutorchGCFunction = cutorchGCFunction_;
}

/* Allocates `size` bytes of device memory into *ptr.
   On out-of-memory, runs the registered GC hook (if any) to release
   cached device storages and retries the allocation once.
   Returns the cudaMalloc status; callers are expected to THCudaCheck it.

   Fixes relative to the original:
   - the GC retry is attempted only for cudaErrorMemoryAllocation, not
     for unrelated failures (e.g. invalid arguments) where retrying is
     pointless;
   - a still-failing allocation no longer leaves a sticky error pending,
     which would otherwise trip the next caller's
     THCudaCheck(cudaGetLastError()) even though the error was already
     reported through the return value here. */
cudaError_t THCudaMalloc(THCState *state, void** ptr, size_t size)
{
  /* Surface any error left pending by earlier asynchronous work before
     it can be misattributed to this allocation. */
  THCudaCheck(cudaGetLastError());
  cudaError_t err = cudaMalloc(ptr, size);
  if (err == cudaErrorMemoryAllocation && state->cutorchGCFunction != NULL) {
    cudaGetLastError(); // clear the sticky OOM error before retrying
    (state->cutorchGCFunction)(state->cutorchGCData);
    err = cudaMalloc(ptr, size);
  }
  if (err != cudaSuccess) {
    cudaGetLastError(); // error is reported via `err`; don't leave it pending
  }
  return err;
}

/* Frees device memory previously obtained via THCudaMalloc and returns
   the cudaFree status. `state` is currently unused; it is part of the
   signature so call sites need not change if per-state bookkeeping is
   added later. */
cudaError_t THCudaFree(THCState *state, void *ptr)
{
  (void) state; /* unused for now */
  return cudaFree(ptr);
}

// Here we maintain a dynamic softmax threshold for THC-allocated storages.
// When THC heap size goes above this softmax, the GC hook is triggered.
// If heap size is above 80% of the softmax after GC, then the softmax is
// increased.
// Dynamic softmax threshold for THC-allocated storages: when the tracked
// heap exceeds state->heapSoftmax, run the GC hook. If the heap is still
// above 80% of the softmax afterwards, grow the softmax so we don't GC on
// every subsequent allocation.
static void maybeTriggerGC(THCState *state, long curHeapSize) {
  if (state->cutorchGCFunction == NULL || curHeapSize <= state->heapSoftmax) {
    return;
  }
  (state->cutorchGCFunction)(state->cutorchGCData);
  // Re-read the global counter: the collection may have freed storages.
  long afterGC = THAtomicGetLong(&heapSize);
  if (afterGC > state->heapSoftmax * heapSoftmaxGrowthThresh) {
    state->heapSoftmax = state->heapSoftmax * heapSoftmaxGrowthFactor;
  }
}

/* Adjusts the tracked THC heap size by `size` bytes (negative on free).
   Growth may trigger the registered GC hook via maybeTriggerGC; frees
   never do. */
void THCHeapUpdate(THCState *state, long size) {
  /* NOTE(review): assumes THAtomicAddLong returns the pre-add value, so
     adding `size` again yields the updated total — confirm against
     THAtomic.h. */
  long total = THAtomicAddLong(&heapSize, size) + size;
#ifdef THC_CHECK_HEAP_UPDATE
  if (total < 0) {
    THError("Internal error: THC heapSize < 0");
  }
#endif
  if (size <= 0) {
    return;
  }
  maybeTriggerGC(state, total);
}

#undef GLOBAL_SCRATCH_SPACE_PER_SM_STREAM
11 changes: 11 additions & 0 deletions lib/THC/THCGeneral.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ typedef struct THCState
int currentPerDeviceBlasHandle;
/* Allocator using cudaMallocHost. */
THAllocator* cudaHostAllocator;

void (*cutorchGCFunction)(void *data);
void *cutorchGCData;
long heapSoftmax;
} THCState;

THC_API void THCudaInit(THCState* state);
Expand Down Expand Up @@ -116,4 +120,11 @@ THC_API size_t THCState_getDeviceScratchSpaceSize(THCState* state, int device);
THC_API void __THCudaCheck(cudaError_t err, const char *file, const int line);
THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line);

THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
THC_API cudaError_t THCudaFree(THCState *state, void *ptr);
THC_API void THCSetGCHandler(THCState *state,
void (*torchGCHandlerFunction)(void *data),
void *data );
THC_API void THCHeapUpdate(THCState *state, long size);

#endif
13 changes: 11 additions & 2 deletions lib/THC/THCStorage.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ THCudaStorage* THCudaStorage_newWithSize(THCState *state, long size)
if(size > 0)
{
THCudaStorage *storage = (THCudaStorage*)THAlloc(sizeof(THCudaStorage));
THCudaCheck(cudaMalloc((void**)&(storage->data), size * sizeof(float)));

// update heap *before* attempting malloc, to free space for the malloc
THCHeapUpdate(state, size * sizeof(float));
cudaError_t err =
THCudaMalloc(state, (void**)&(storage->data), size * sizeof(float));
if(err != cudaSuccess){
THCHeapUpdate(state, -size * sizeof(float));
}
THCudaCheck(err);

storage->size = size;
storage->refcount = 1;
Expand Down Expand Up @@ -110,7 +118,8 @@ void THCudaStorage_free(THCState *state, THCudaStorage *self)
if (THAtomicDecrementRef(&self->refcount))
{
if(self->flag & TH_STORAGE_FREEMEM) {
THCudaCheck(cudaFree(self->data));
THCHeapUpdate(state, -self->size * sizeof(float));
THCudaCheck(THCudaFree(state, self->data));
}
THFree(self);
}
Expand Down
14 changes: 11 additions & 3 deletions lib/THC/THCStorage.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,31 @@ void THCudaStorage_resize(THCState *state, THCudaStorage *self, long size)
if(size == 0)
{
if(self->flag & TH_STORAGE_FREEMEM) {
THCudaCheck(cudaFree(self->data));
THCudaCheck(THCudaFree(state, self->data));
THCHeapUpdate(state, -self->size * sizeof(float));
}
self->data = NULL;
self->size = 0;
}
else
{
float *data = NULL;
THCudaCheck(cudaMalloc((void**)(&data), size * sizeof(float)));
// update heap *before* attempting malloc, to free space for the malloc
THCHeapUpdate(state, size * sizeof(float));
cudaError_t err = THCudaMalloc(state, (void**)(&data), size * sizeof(float));
if(err != cudaSuccess) {
THCHeapUpdate(state, -size * sizeof(float));
}
THCudaCheck(err);

if (self->data) {
THCudaCheck(cudaMemcpyAsync(data,
self->data,
THMin(self->size, size) * sizeof(float),
cudaMemcpyDeviceToDevice,
THCState_getCurrentStream(state)));
THCudaCheck(cudaFree(self->data));
THCudaCheck(THCudaFree(state, self->data));
THCHeapUpdate(state, -self->size * sizeof(float));
}

self->data = data;
Expand Down
12 changes: 6 additions & 6 deletions lib/THC/THCTensorIndex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void THCudaTensor_indexCopy(THCState *state, THCudaTensor *res_, int dim, THCuda
dim3 nthreads(16, 16);
dim3 nblocks(ceil((float)nRes / nIndex / (16*16)));

THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long)));
THCudaCheck(THCudaMalloc(state, (void**)&stride_, res_->nDimension * sizeof(long)));
THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice));

THCudaTensor_kernel_indexCopy<<<nblocks, nthreads, 0, THCState_getCurrentStream(state)>>>(
Expand All @@ -125,7 +125,7 @@ void THCudaTensor_indexCopy(THCState *state, THCudaTensor *res_, int dim, THCuda
THCudaTensor_nElement(state, src), res_->size[dim]
);

THCudaCheck(cudaFree(stride_));
THCudaCheck(THCudaFree(state, stride_));
THCudaTensor_free(state, indices);
THCudaTensor_free(state, src);
}
Expand Down Expand Up @@ -159,15 +159,15 @@ void THCudaTensor_indexFill(THCState *state, THCudaTensor *res_, int dim, THCuda
dim3 nthreads(16, 16);
dim3 nblocks(ceil((float)nRes / nIndex / (16*16)));

THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long)));
THCudaCheck(THCudaMalloc(state, (void**)&stride_, res_->nDimension * sizeof(long)));
THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice));

THCudaTensor_kernel_indexFill<<<nblocks, nthreads, 0, THCState_getCurrentStream(state)>>>(
THCudaTensor_data(state, res_), stride_, THCudaTensor_data(state, indices),
res_->nDimension, dim, nIndex, nRes, res_->size[dim], val
);

THCudaCheck(cudaFree(stride_));
THCudaCheck(THCudaFree(state, stride_));
THCudaTensor_free(state, indices);
}

Expand Down Expand Up @@ -299,7 +299,7 @@ void THCudaTensor_indexSelect(THCState *state, THCudaTensor *res_, THCudaTensor
dim3 nthreads(16, 16);
dim3 nblocks(ceil((float)nRes / nIndex / (16*16)));

THCudaCheck(cudaMalloc((void**)&stride_, src->nDimension * sizeof(long)));
THCudaCheck(THCudaMalloc(state, (void**)&stride_, src->nDimension * sizeof(long)));
THCudaCheck(cudaMemcpy(stride_, src->stride, src->nDimension * sizeof(long), cudaMemcpyHostToDevice));

THCudaTensor_kernel_indexSelect<<<nblocks, nthreads, 0, stream>>>(
Expand All @@ -308,7 +308,7 @@ void THCudaTensor_indexSelect(THCState *state, THCudaTensor *res_, THCudaTensor
src->nDimension, dim, nIndex, nRes, src->size[dim]
);

THCudaCheck(cudaFree(stride_));
THCudaCheck(THCudaFree(state, stride_));
THCudaTensor_free(state, indices);
THCudaTensor_freeCopyTo(state, res, res_);
}
12 changes: 6 additions & 6 deletions lib/THC/THCTensorMathBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,9 @@ void THCudaTensor_baddbmm(THCState *state, THCudaTensor *result, float beta, THC
// Copy pointers to device.
const float **d_matrices1, **d_matrices2;
float **d_result_matrices;
THCudaCheck(cudaMalloc((void**)&d_matrices1, matrices_size));
THCudaCheck(cudaMalloc((void**)&d_matrices2, matrices_size));
THCudaCheck(cudaMalloc((void**)&d_result_matrices, matrices_size));
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size));
THCudaCheck(THCudaMalloc(state, (void**)&d_result_matrices, matrices_size));

THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size,
cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
Expand All @@ -366,9 +366,9 @@ void THCudaTensor_baddbmm(THCState *state, THCudaTensor *result, float beta, THC
d_result_matrices, ldc,
num_batches);

cudaFree(d_matrices1);
cudaFree(d_matrices2);
cudaFree(d_result_matrices);
THCudaFree(state, d_matrices1);
THCudaFree(state, d_matrices2);
THCudaFree(state, d_result_matrices);
THFree(matrices1);
THFree(matrices2);
THFree(result_matrices);
Expand Down
18 changes: 9 additions & 9 deletions lib/THC/THCTensorRandom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@
#define BLOCK_SIZE 256

/* Sets up generator. Allocates but does not create the generator states. */
__host__ void initializeGenerator(THCState *state, Generator* gen)
{
  /* Allocate through THCudaMalloc so an OOM here can trigger the
     registered GC hook and be retried. */
  THCudaCheck(THCudaMalloc(state, (void**)&gen->gen_states, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32)));
  THCudaCheck(THCudaMalloc(state, (void**)&gen->kernel_params, sizeof(mtgp32_kernel_params)));
}

/* Frees memory allocated during setup. */
__host__ void destroyGenerator(THCState *state, Generator* gen)
{
  /* Each field may be NULL (never initialized, or already destroyed);
     free and re-NULL independently so destruction is idempotent. */
  if (gen->gen_states)
  {
    THCudaCheck(THCudaFree(state, gen->gen_states));
    gen->gen_states = NULL;
  }
  if (gen->kernel_params)
  {
    THCudaCheck(THCudaFree(state, gen->kernel_params));
    gen->kernel_params = NULL;
  }
}
Expand Down Expand Up @@ -66,7 +66,7 @@ __host__ void THCRandom_init(THCState* state, int devices, int current_device)
rng_state->current_gen = &rng_state->gen[current_device];
// Initialize the generator for the current device. Other generators will be
// initialized on-demand in THCRandom_setGenerator.
initializeGenerator(rng_state->current_gen);
initializeGenerator(state, rng_state->current_gen);
THCRandom_seed(state);
}

Expand All @@ -77,7 +77,7 @@ __host__ void THCRandom_shutdown(THCState* state)
if (rng_state->gen == NULL) return;
for (int i = 0; i < rng_state->num_devices; ++i)
{
destroyGenerator(&rng_state->gen[i]);
destroyGenerator(state, &rng_state->gen[i]);
}
free(rng_state->gen);
rng_state->gen = NULL;
Expand All @@ -92,7 +92,7 @@ __host__ void THCRandom_setGenerator(THCState* state, int device)
rng_state->current_gen = &rng_state->gen[device];
if (rng_state->current_gen->initf == 0)
{
initializeGenerator(rng_state->current_gen);
initializeGenerator(state, rng_state->current_gen);
THCRandom_seed(state);
}
}
Expand Down

0 comments on commit 843374d

Please sign in to comment.