forked from facebookresearch/fairseq
Commit
initial light and dynamic convolution kernels (facebookresearch#547)
Summary: CUDA code for light/dynamicconv kernels, including PyTorch modules. Modules can be built by running setup.py in each respective folder, and can then be imported and used like any other module.

Pull Request resolved: fairinternal/fairseq-py#547
Reviewed By: myleott, shubho
Differential Revision: D15703660
Pulled By: nng555
fbshipit-source-id: e9c913753be3a1cd571965f7200df6678b644520
1 parent b870468 · commit f840564
Showing 23 changed files with 1,958 additions and 27 deletions.
@@ -0,0 +1,202 @@
/**
 * Copyright (c) 2018-present, Facebook, Inc.
 * All rights reserved.
 *
 */

#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdlib>
#include <iostream>

// SHFL_MASK is used by the shuffle intrinsics below; the surrounding tree is
// expected to define it elsewhere. 0xffffffff is the standard all-lanes mask.
#ifndef SHFL_MASK
#define SHFL_MASK 0xffffffff
#endif

// Integer ceiling division, e.g. divUp(10, 4) == 3.
template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  return (a + b - 1) / b;
}

template<int FS, int SB, int padding_l, typename scalar_t>
__inline__ __device__
void zeroSharedMem(scalar_t* data) {
  /*
    Given an array of length FS + SB, zero out the first padding_l and last
    (FS - padding_l) values in the array
  */

  int tid = threadIdx.x;

  if (FS < SB) {

    // zero all if we have enough threads in a block to do all of them
    if (tid < padding_l || tid > SB - FS + padding_l - 1) {
      data[tid] = scalar_t(0.0);
    }
  } else {

    // otherwise zero out one block at a time
    const int numIterations = divUp<int, int>(FS, SB);
    for (int i = 0; i < numIterations; i++) {
      int offset = i * SB;
      if (tid + offset < padding_l) {
        data[tid + offset] = scalar_t(0.0);
      } else if (tid + offset < FS) {
        data[SB + tid + offset] = scalar_t(0.0);
      }
    }
  }
}
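
For orientation, a minimal usage sketch (hypothetical kernel, not part of this commit): a block allocates SB + FS shared values per tile and clears the halo regions before any data is loaded, so a size-FS filter can safely read past either edge of the tile.

// Hypothetical demo of the call pattern; launch with SB threads per block,
// e.g. zeroDemo<3, 128, 1><<<1, 128>>>();
template<int FS, int SB, int padding_l>
__global__ void zeroDemo() {
  // layout: [left halo | SB-element tile | right halo]
  __shared__ float tile[SB + FS];
  zeroSharedMem<FS, SB, padding_l, float>(tile);
  __syncthreads();
  // ... tile is now zero-padded and ready for loads ...
}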

template<typename scalar_t>
__inline__ __device__
scalar_t warpReduce(scalar_t data) {
  /*
    Reduce a value within each warp via a butterfly of XOR shuffles. After
    processing, every lane in the warp will contain the sum of all original
    values in that warp.
    data - register value to reduce
  */
  data += __shfl_xor_sync(SHFL_MASK, data, 16);
  data += __shfl_xor_sync(SHFL_MASK, data, 8);
  data += __shfl_xor_sync(SHFL_MASK, data, 4);
  data += __shfl_xor_sync(SHFL_MASK, data, 2);
  data += __shfl_xor_sync(SHFL_MASK, data, 1);
  return data;
}
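
A quick sanity-check sketch (hypothetical kernel, not from this commit): each lane contributes its lane id, and after the shuffle butterfly every lane of the warp holds 0 + 1 + ... + 31 = 496. Assumes the block size is a multiple of 32, since SHFL_MASK covers all 32 lanes.

// Hypothetical demo: one sum per warp, written by lane 0.
// e.g. warpSumDemo<<<1, 64>>>(d_out);  // d_out holds blockDim.x / 32 ints
__global__ void warpSumDemo(int* out) {
  int lane = threadIdx.x % 32;
  int total = warpReduce<int>(lane);  // every lane now holds 496
  if (lane == 0) {
    out[threadIdx.x / 32] = total;
  }
}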

template<typename scalar_t>
__inline__ __device__
scalar_t blockReduce(scalar_t data) {
  /*
    Reduce a value across an entire block. Each warp reduces its own values,
    the per-warp sums are staged in shared memory, and warp 0 reduces those
    partial sums. On return, the lanes of warp 0 hold the block-wide sum
    (threads in other warps return zero), so callers typically read the
    result from thread 0.
    data - register value to reduce
  */

  static __shared__ scalar_t warpSum[32];
  const int tid = threadIdx.x;
  int wid = tid / 32;
  int lane = tid % 32;

  __syncthreads();

  // reduce each warp then write to shared memory
  scalar_t sum = warpReduce(data);
  if (lane == 0) {
    warpSum[wid] = sum;
  }

  __syncthreads();

  scalar_t v;
  // perform final sum of partial warp sums
  if (tid < blockDim.x / 32) {
    v = warpSum[lane];
  } else {
    v = scalar_t(0.0);
  }

  if (wid == 0) {
    v = warpReduce(v);
  }
  __syncthreads();

  return v;
}
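
A minimal usage sketch (hypothetical kernel and launch, not part of this commit) of how blockReduce is typically driven: each thread contributes one value and thread 0 publishes the block-wide sum. Assumes blockDim.x is a multiple of 32.

// Hypothetical demo: sum up to blockDim.x floats with a single block.
// e.g. blockSumKernel<<<1, 256>>>(d_in, d_out, n);  // requires n <= 256
__global__ void blockSumKernel(const float* in, float* out, int n) {
  const int tid = threadIdx.x;
  float v = (tid < n) ? in[tid] : 0.0f;
  float sum = blockReduce<float>(v);
  if (tid == 0) {
    *out = sum;  // the reduced value is valid in warp 0; thread 0 writes it
  }
}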

void checkCudaStatus(cudaError_t status, int lineNumber = -1) {

  if (status != cudaSuccess) {
    std::cout << cudaGetErrorString(status)
              << " at line " << lineNumber << std::endl;
    std::cout << "Exiting" << std::endl;
    exit(1);
  }
}
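
A hedged convenience sketch (the macro is illustrative, not part of this commit): wrapping the helper so the failing line number is captured automatically at each call site.

// Hypothetical wrapper macro: forwards the current source line.
#define CHECK_CUDA(expr) checkCudaStatus((expr), __LINE__)

// Example call sites:
//   CHECK_CUDA(cudaMalloc(&d_buf, bytes));
//   CHECK_CUDA(cudaGetLastError());   // after a kernel launch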

template<int FS, int SB, int padding_l, typename scalar_t>
__device__
void load_input_to_shared(const scalar_t* input, // global memory
                          int inputOffset, int sequenceLength,
                          int iteration, int numIterations,
                          bool no_prev, scalar_t* output /* shared memory */) {
  /*
    Load one block-sized tile of input into shared memory, with left and
    right overhangs whose combined size is FS. If the previous tile is still
    resident in shared memory, the overlapping region is copied from there
    instead of re-read from global memory.
    input - pointer to start of channel sequence
    inputOffset - how far into the sequence to start loading
    sequenceLength - total length of sequence
    iteration - which tile of the sequence we are loading
    numIterations - total number of tiles to load
    no_prev - true if the previous tile was not loaded, so the left
              overhang must be read from global memory rather than reused
              from shared memory
    output - shared memory to write input to (length SB + FS; the caller
             must __syncthreads() before the tile is reused or overwritten)
  */

  const int tid = threadIdx.x;

  // Load the left "overhang" of input
  if (iteration > 0) {
    if (padding_l < SB) {

      // load all at once
      if (tid < padding_l) {
        output[tid] = (no_prev) ? input[inputOffset - padding_l + tid]
                                : output[tid + SB];
      }
    } else {

      // load in chunks of size SB
      const int numLoadIterations = divUp<int, int>(padding_l, SB);
      for (int i = 0; i < numLoadIterations; i++) {
        int offset = i * SB;
        if ((tid + offset) < padding_l) {
          output[tid + offset] = (no_prev)
              ? input[inputOffset - padding_l + tid + offset]
              : output[tid + offset + SB];
        }
      }
    }
  }

  // Load the right "overhang" of input
  if (iteration < (numIterations - 1)) {
    const int elementsLeft = sequenceLength - (iteration + 1) * SB;

    if ((FS - padding_l) < SB) {

      // load all at once
      if (tid < (FS - padding_l)) {
        output[padding_l + SB + tid] =
            (tid < elementsLeft) ? input[inputOffset + SB + tid]
                                 : scalar_t(0.0);
      }
    } else {

      // load in chunks of size SB
      const int numLoadIterations = divUp<int, int>(FS - padding_l, SB);
      for (int i = 0; i < numLoadIterations; i++) {
        int offset = i * SB;
        if ((tid + offset) < (FS - padding_l)) {
          output[padding_l + SB + tid + offset] =
              ((tid + offset) < elementsLeft)
                  ? input[inputOffset + SB + tid + offset]
                  : scalar_t(0.0);
        }
      }
    }
  }

  // On the last tile, zero out the right "overhang" instead
  if (iteration == (numIterations - 1)) {
    if ((FS - padding_l) < SB) {

      // clear out all at once
      if (tid < (FS - padding_l)) {
        output[padding_l + SB + tid] = scalar_t(0.0);
      }
    } else {

      // clear in chunks of size SB
      const int numClearIterations = divUp<int, int>(FS - padding_l, SB);
      for (int i = 0; i < numClearIterations; i++) {
        int offset = i * SB;
        if ((tid + offset) < (FS - padding_l)) {
          output[padding_l + SB + tid + offset] = scalar_t(0.0);
        }
      }
    }
  }

  // Load the main body of the tile (zero-padded past the end of the sequence)
  output[tid + padding_l] =
      ((inputOffset + tid) < sequenceLength) ? input[inputOffset + tid]
                                             : scalar_t(0.0);
}
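
To show how these helpers compose, a hedged sketch of a tiled pass (hypothetical kernel; the actual lightconv/dynamicconv kernels in this commit are more involved): one block walks a channel's sequence tile by tile, keeping an FS-wide halo in shared memory, then publishes a block-reduced result. It always passes no_prev = true, re-reading the halo from global memory, which avoids ordering concerns between reusing and overwriting the previous tile.

// Hypothetical skeleton; launch with SB threads per block (SB a multiple of
// 32), e.g. tiledPassKernel<3, 128, 1, float><<<1, 128>>>(d_in, d_sum, len);
// As a stand-in for real per-tap work, it accumulates each tile element, so
// d_sum ends up holding the sum of the whole sequence.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__ void tiledPassKernel(const scalar_t* input, scalar_t* outputSum,
                                int sequenceLength) {
  __shared__ scalar_t tile[SB + FS];  // [left halo | SB elements | right halo]

  const int numIterations = divUp<int, int>(sequenceLength, SB);
  zeroSharedMem<FS, SB, padding_l, scalar_t>(tile);
  __syncthreads();

  scalar_t acc = scalar_t(0.0);
  for (int i = 0; i < numIterations; i++) {
    int inputOffset = i * SB;
    load_input_to_shared<FS, SB, padding_l, scalar_t>(
        input, inputOffset, sequenceLength, i, numIterations,
        /*no_prev=*/true, tile);
    __syncthreads();
    acc += tile[threadIdx.x + padding_l];  // per-element work goes here
    __syncthreads();  // tile must not be overwritten while still being read
  }

  scalar_t total = blockReduce<scalar_t>(acc);
  if (threadIdx.x == 0) {
    *outputSum = total;
  }
}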