Skip to content

Commit

Permalink
Make torch.cat not synchronize the host and device
Browse files Browse the repository at this point in the history
  • Loading branch information
colesbury authored and soumith committed May 10, 2017
1 parent 5f308b5 commit d5e8210
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 29 deletions.
22 changes: 21 additions & 1 deletion THCGeneral.c
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,27 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
return allocator->free(allocator->state, ptr);
}

/*
 * Allocate `size` bytes of pinned (page-locked) host memory through the
 * state's configured host allocator.
 *
 * Any pending CUDA error is checked (and thereby cleared) up front so a
 * stale sticky error from earlier work is not misattributed to this
 * allocation.  Returns the allocated host pointer.
 */
void* THCudaHostAlloc(THCState *state, size_t size)
{
  THCudaCheck(cudaGetLastError());
  /* NOTE(review): NULL is passed as the allocator context here, unlike
     THCudaFree which passes allocator->state — presumably the host
     allocator is stateless; confirm against its implementation. */
  return state->cudaHostAllocator->malloc(NULL, size);
}

/*
 * Release pinned host memory previously obtained via THCudaHostAlloc,
 * using the state's configured host allocator.
 *
 * Fix: the original wrote `return allocator->free(...)`.  Returning a
 * void-typed expression from a void function is a C99 constraint
 * violation (this is a .c file), even though some compilers accept it
 * as an extension; a plain call has identical behavior.
 */
void THCudaHostFree(THCState *state, void *ptr)
{
  THAllocator* allocator = state->cudaHostAllocator;
  /* NULL context matches the THCudaHostAlloc call site. */
  allocator->free(NULL, ptr);
}

/*
 * Record a usage event for pinned host buffer `ptr` on the state's
 * current stream.
 *
 * This is a no-op unless the caching host allocator is active.  When it
 * is, the recorded event lets the allocator track that `ptr` is still
 * referenced by in-flight work on that stream — presumably so the
 * buffer is not handed out again prematurely (confirm against
 * THCCachingHostAllocator_recordEvent).
 */
void THCudaHostRecord(THCState *state, void *ptr)
{
  if (state->cudaHostAllocator != &THCCachingHostAllocator) {
    return;  /* non-caching allocators need no event tracking */
  }
  THCCachingHostAllocator_recordEvent(ptr, THCState_getStream(state));
}

cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes)
{
size_t cachedBytes = 0;
Expand Down Expand Up @@ -932,4 +953,3 @@ float THC_half2float(half h)
TH_halfbits2float(&h.x, &f);
return f;
}

4 changes: 4 additions & 0 deletions THCGeneral.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ THC_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const

THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
THC_API cudaError_t THCudaFree(THCState *state, void *ptr);
THC_API void* THCudaHostAlloc(THCState *state, size_t size);
THC_API void THCudaHostFree(THCState *state, void *ptr);
THC_API void THCudaHostRecord(THCState *state, void *ptr);

THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes);
THC_API void THCSetGCHandler(THCState *state,
void (*torchGCHandlerFunction)(void *data),
Expand Down
51 changes: 23 additions & 28 deletions generic/THCTensorMath.cu
Original file line number Diff line number Diff line change
Expand Up @@ -175,23 +175,9 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
real *data = THCTensor_(data)(state, result);

// Kernel Parameter
CatArrInputTensor<real, unsigned int> stackInputs[CAT_ARRAY_BATCH_SIZE];
CatArrInputTensor<real, unsigned int> *d_inputs;

// Attempt to re-use stream's scratch space for the input metadata
bool usedScratch = false;
size_t tensorMetadataSize = sizeof(CatArrInputTensor<real, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
if (THCState_getCurrentDeviceScratchSpaceSize(state) > tensorMetadataSize) {
void* space = THCState_getCurrentDeviceScratchSpace(state);
if (space) {
d_inputs = (CatArrInputTensor<real, unsigned int> *) space;
usedScratch = true;
}
}
if (!usedScratch) {
// Fallback to allocating GPU memory
THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize));
}
CatArrInputTensor<real, unsigned int> *d_inputs;
THCudaCheck(THCudaMalloc(state, (void**) &d_inputs, tensorMetadataSize));

OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;

Expand All @@ -201,13 +187,17 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
param.outputStride[i] = THCTensor_(stride)(state, result, i);
}

THCStream* stream = THCState_getStream(state);

// Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, THCState_getCurrentStream(state)>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);
CatArrayBatchedCopy<real, unsigned int, DIMS><<<applyGrid, applyBlock, 0, stream->stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]);

// Now we loop
offset = 0;
for (i = 0; i < numInputs; i += CAT_ARRAY_BATCH_SIZE) {
// Re-allocate stackInputs every iteration to avoid read-after-write hazard
CatArrInputTensor<real, unsigned int>* stackInputs = (CatArrInputTensor<real, unsigned int>*) THCudaHostAlloc(state, tensorMetadataSize);
cohortMax = 0;
for (j = 0; j < CAT_ARRAY_BATCH_SIZE && (i+j) < numInputs; ++j) {
long dimSize = cat_dimension < THCTensor_(nDimension)(state, inputs[i+j])
Expand All @@ -223,7 +213,14 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
// update offset
offset += dimSize;
}
THCudaCheck(cudaMemcpy(d_inputs, stackInputs, j * sizeof(CatArrInputTensor<real, unsigned int>), cudaMemcpyHostToDevice));
THCudaCheck(cudaMemcpyAsync(
d_inputs,
stackInputs,
j * sizeof(CatArrInputTensor<real, unsigned int>),
cudaMemcpyHostToDevice,
stream->stream));
THCudaHostRecord(state, stackInputs);
THCudaHostFree(state, stackInputs);

// Next, let's consider how we set our kernel launch parameters.
// We borrow from THCApply, which the kernel's internal indexing
Expand Down Expand Up @@ -256,9 +253,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
}
THCudaCheck(cudaGetLastError());
}
if (!usedScratch) {
THCudaCheck(THCudaFree(state, (void *)d_inputs));
}
THCudaCheck(THCudaFree(state, d_inputs));
#undef HANDLE_CASE
} else {
offset = 0;
Expand Down Expand Up @@ -399,10 +394,10 @@ void THCTensor_(linspace)(THCState *state, THCTensor *r_, real a, real b, long n
if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n);
if (n == 1) THCTensor_(fill)(state, r_, a);
else {
THCTensor *r = THCTensor_(isContiguous)(state, r_)
THCTensor *r = THCTensor_(isContiguous)(state, r_)
? r_ // if r_ is contiguous we can direct work on it
: THCTensor_(newContiguous)(state, r_);
real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
ScalarConvert<long,real>::to(n - 1));
LinspaceOp<real> linspace_method(a, step);
thrust::device_ptr<real> data_(THCTensor_(data)(state, r));
Expand All @@ -420,10 +415,10 @@ void THCTensor_(logspace)(THCState *state, THCTensor *r_, real a, real b, long n
if (THCTensor_(nElement)(state, r_) != n) THCTensor_(resize1d)(state, r_, n);
if (n == 1) THCTensor_(fill)(state, r_, THCNumerics<real>::exp10(a));
else {
THCTensor *r = THCTensor_(isContiguous)(state, r_)
? r_
THCTensor *r = THCTensor_(isContiguous)(state, r_)
? r_
: THCTensor_(newContiguous)(state, r_);
real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
real step = THCNumerics<real>::div(THCNumerics<real>::sub(b, a),
ScalarConvert<long,real>::to(n - 1));
LogspaceOp<real> logspace_method(a, step);
thrust::device_ptr<real> data_(THCTensor_(data)(state, r));
Expand All @@ -444,8 +439,8 @@ void THCTensor_(range)(THCState *state, THCTensor *r_, accreal xmin, accreal xma
, 2, "upper bound and larger bound incoherent with step sign");
ptrdiff_t size = (ptrdiff_t) (((xmax - xmin) / step) + 1);
if (THCTensor_(nElement)(state, r_) != size) THCTensor_(resize1d)(state, r_, size);
THCTensor *r = THCTensor_(isContiguous)(state, r_)
? r_
THCTensor *r = THCTensor_(isContiguous)(state, r_)
? r_
: THCTensor_(newContiguous)(state, r_);
LinspaceOp<real,accreal> linspace_method(xmin, step);
thrust::device_ptr<real> data_(THCTensor_(data)(state, r));
Expand Down

0 comments on commit d5e8210

Please sign in to comment.