Simple cut of the kernel in place
definitelynotmcarilli committed Apr 9, 2019
1 parent 03100f4 commit e57f5d0
Showing 3 changed files with 60 additions and 1 deletion.
7 changes: 7 additions & 0 deletions csrc/amp_C_frontend.cpp
@@ -14,9 +14,16 @@ void multi_tensor_axpby_cuda(
   float b,
   int arg_to_check);
 
+at::Tensor multi_tensor_l2norm_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
   m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
         "out = a*x + b*y for a list of contiguous tensors");
+  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
+        "Computes L2 norm for a list of contiguous tensors");
 }
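
The new entry point follows the calling convention of the existing multi_tensor_* functions: a chunk size, an int noop_flag tensor shared with the fused kernels, and a list of tensor lists (one inner list per operand). A minimal C++ sketch of a call, not part of this commit; the wrapper name, chunk size, and option handling are illustrative assumptions:

#include <ATen/ATen.h>
#include <vector>

// Assumes the multi_tensor_l2norm_cuda declaration above is visible.
at::Tensor l2norm_of_tensors(const std::vector<at::Tensor>& tensors)
{
  // Int flag tensor, as used by the other multi_tensor_* entry points.
  auto noop_flag = at::zeros({1}, tensors[0].options().dtype(at::kInt));
  // The L2-norm reduction reads a single operand list.
  std::vector<std::vector<at::Tensor>> tensor_lists{tensors};
  // Chunk size mirrors the other fused calls; the exact value is an assumption.
  return multi_tensor_l2norm_cuda(2048*32, noop_flag, tensor_lists);
}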
51 changes: 51 additions & 0 deletions csrc/type_shim.h
@@ -31,3 +31,54 @@ struct TypeShim
     default: \
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
+
+
+template<typename T, typename ReduceOp>
+__device__ __forceinline__ T reduce_block_into_lanes
+  (T *x,
+   T val,
+   int lanes,
+   bool share_result) // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y*blockDim.x;
+  int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
+
+  if(blockSize >= 64)
+  {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+  #pragma unroll
+  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
+  {
+    if(tid < i)
+      x[tid] = x[tid] + x[tid+i];
+    __syncthreads();
+  }
+
+  T final;
+
+  if(tid < 32)
+  {
+    if(blockSize >= 64)
+      final = x[tid] + x[tid+32];
+    else
+      final = val;
+    // __SYNCWARP();
+
+    #pragma unroll
+    for(int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+
+  if(share_result)
+  {
+    if(tid < lanes)
+      x[tid] = final; // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
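
reduce_block_into_lanes folds one value per thread into `lanes` partial results held by the first `lanes` threads: shared memory handles the tree reduction down to 64 partials, the first warp combines those to 32, and __shfl_down_sync finishes the job (for lanes = 1, thread 0 ends up with the full block-wide sum). A minimal CUDA sketch of a caller, not part of this commit; the kernel name, block size limit, and output layout are assumptions:

#include "type_shim.h"

// One partial sum of squares per block. Assumes at most 512 threads per block,
// with blockDim.x*blockDim.y a multiple of 32, as the helper expects.
__global__ void block_sum_of_squares(const float* in, float* block_out, int n)
{
  __shared__ float s_vals[512];

  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int block_size = blockDim.x*blockDim.y;

  // Grid-stride accumulation of x*x into a per-thread partial sum.
  float val = 0.f;
  for(int i = blockIdx.x*block_size + tid; i < n; i += block_size*gridDim.x)
    val += in[i]*in[i];

  // lanes = 1 and share_result = false: only thread 0 holds the block total.
  // The ReduceOp template parameter is unused in this version of the helper,
  // so any placeholder type satisfies it.
  float total = reduce_block_into_lanes<float, float>(s_vals, val, 1, false);

  if(tid == 0)
    block_out[blockIdx.x] = total;
}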
3 changes: 2 additions & 1 deletion setup.py
@@ -71,7 +71,8 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     CUDAExtension(name='amp_C',
                   sources=['csrc/amp_C_frontend.cpp',
                            'csrc/multi_tensor_scale_kernel.cu',
-                           'csrc/multi_tensor_axpby_kernel.cu'],
+                           'csrc/multi_tensor_axpby_kernel.cu',
+                           'csrc/multi_tensor_l2norm_kernel.cu'],
                   extra_compile_args={'cxx': ['-O3'],
                                       'nvcc':['-lineinfo',
                                               '-O3',
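
The setup.py change is what makes the new binding link: csrc/multi_tensor_l2norm_kernel.cu must define the multi_tensor_l2norm_cuda function declared in amp_C_frontend.cpp. A minimal skeleton of that obligation (placeholder body only, not the kernel from this commit):

#include <ATen/ATen.h>
#include <vector>

at::Tensor multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists)
{
  // Placeholder: the real file launches a fused reduction over every chunk of
  // every tensor in tensor_lists and returns the resulting norm as a tensor.
  return at::zeros({1}, tensor_lists[0][0].options().dtype(at::kFloat));
}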
