
spell explicitly uint to unsigned int
lancerts committed May 27, 2024
1 parent f27ca4d commit c4c985c
Showing 1 changed file with 18 additions and 18 deletions.
dev/cuda/layernorm_backward.cu
@@ -20,7 +20,7 @@ version 2 moves a lot of reduction to shared memory over global memory

#define ENABLE_BF16
#include "common.h"
-// typedef unsigned int uint; // required on windows

// ----------------------------------------------------------------------------
// CPU code reference
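
A note on the line removed above: `uint` is a POSIX typedef, not a standard C or C++ type. On Linux it compiles only because glibc's <sys/types.h> defines it and gets pulled in transitively; MSVC on Windows ships no such typedef, which is why the file used to carry the workaround hint. A two-line illustration (the variable is hypothetical):

// `uint` is POSIX, not ISO C/C++: it exists under glibc headers but is an error on MSVC.
// typedef unsigned int uint;       // the Windows workaround the removed comment suggested
unsigned int mask = 0xFFFFFFFFu;    // portable spelling, now used throughout the file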

@@ -428,33 +428,33 @@ __global__ void layernorm_backward_kernel4(Tdinp* dinp, Tparams* dweight, Tparam
__nv_bfloat162 new_dweight = add_dweight + current_dweight;

// Write the result back to L2 cache using 32-bit integer atomic compare and exchange
-uint current_dbias32b = *reinterpret_cast<uint*>(&current_dbias);
-uint current_dweight32b = *reinterpret_cast<uint*>(&current_dweight);
+unsigned int current_dbias32b = *reinterpret_cast<unsigned int*>(&current_dbias);
+unsigned int current_dweight32b = *reinterpret_cast<unsigned int*>(&current_dweight);

-uint new_dbias32b = *reinterpret_cast<uint*>(&new_dbias);
-uint new_dweight32b = *reinterpret_cast<uint*>(&new_dweight);
+unsigned int new_dbias32b = *reinterpret_cast<unsigned int*>(&new_dbias);
+unsigned int new_dweight32b = *reinterpret_cast<unsigned int*>(&new_dweight);

-uint old_dbias32b = atomicCAS((uint*)&dbiasVec2[i], current_dbias32b, new_dbias32b);
-uint old_dweight32b = atomicCAS((uint*)&dweightVec2[i], current_dweight32b, new_dweight32b);
+unsigned int old_dbias32b = atomicCAS((unsigned int*)&dbiasVec2[i], current_dbias32b, new_dbias32b);
+unsigned int old_dweight32b = atomicCAS((unsigned int*)&dweightVec2[i], current_dweight32b, new_dweight32b);

// If the value has changed between read and atomic, we need to try again
while (old_dbias32b != current_dbias32b) {
current_dbias32b = old_dbias32b;
new_dbias = *reinterpret_cast<__nv_bfloat162*>(&current_dbias32b) + add_dbias;
-new_dbias32b = *reinterpret_cast<uint*>(&new_dbias);
-old_dbias32b = atomicCAS((uint*)&dbiasVec2[i], current_dbias32b, new_dbias32b);
+new_dbias32b = *reinterpret_cast<unsigned int*>(&new_dbias);
+old_dbias32b = atomicCAS((unsigned int*)&dbiasVec2[i], current_dbias32b, new_dbias32b);
}

while (old_dweight32b != current_dweight32b) {
current_dweight32b = old_dweight32b;
new_dweight = *reinterpret_cast<__nv_bfloat162*>(&current_dweight32b) + add_dweight;
-new_dweight32b = *reinterpret_cast<uint*>(&new_dweight);
-old_dweight32b = atomicCAS((uint*)&dweightVec2[i], current_dweight32b, new_dweight32b);
+new_dweight32b = *reinterpret_cast<unsigned int*>(&new_dweight);
+old_dweight32b = atomicCAS((unsigned int*)&dweightVec2[i], current_dweight32b, new_dweight32b);
}
}
}

-// FP32 scratchpad per threadgroup, zero atomics except atomicAdd on uint for the flag (based on kernel3)
+// FP32 scratchpad per threadgroup, zero atomics except atomicAdd on unsigned int for the flag (based on kernel3)
template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
__global__ void layernorm_backward_kernel5(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
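
For readers skimming the rename: the hunk above is a 32-bit compare-and-swap loop that updates a pair of bf16 values atomically as one word. A minimal self-contained sketch of the same pattern, folded into a helper (atomicAddBF16x2 is a hypothetical name, not in the file; assumes cuda_bf16.h and bf16 arithmetic support, i.e. sm_80 or newer):

#include <cuda_bf16.h>

// Hypothetical helper: atomically add a __nv_bfloat162 pair by treating it
// as one 32-bit word and retrying atomicCAS until no other thread raced us
// between the read and the swap.
__device__ void atomicAddBF16x2(__nv_bfloat162* addr, __nv_bfloat162 val) {
    unsigned int* addr32 = reinterpret_cast<unsigned int*>(addr);
    unsigned int old32 = *addr32;  // optimistic non-atomic read
    unsigned int assumed32;
    do {
        assumed32 = old32;
        __nv_bfloat162 current = *reinterpret_cast<__nv_bfloat162*>(&assumed32);
        __nv_bfloat162 updated = current + val;  // two bf16 adds in one go
        unsigned int new32 = *reinterpret_cast<unsigned int*>(&updated);
        // atomicCAS returns what was actually stored; a mismatch with assumed32
        // means another thread won the race, so loop and retry with the fresh value.
        old32 = atomicCAS(addr32, assumed32, new32);
    } while (old32 != assumed32);
}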
Expand All @@ -476,7 +476,7 @@ __global__ void layernorm_backward_kernel5(Tdinp* dinp, Tparams* dweight, Tparam
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
-uint *tmp_flag = (uint*)(shared + C*2);
+unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
__syncthreads();

int warps_in_grid = gridDim.x * warp.meta_group_size();
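
tmp_flag above carves a 32-bit word out of the same dynamic shared-memory block that holds the two C-sized float buffers, so a launch has to reserve room for all three. A sketch of the assumed layout and a matching launch; grid_size and block_size are placeholders, and the trailing B, T, C arguments are inferred from the truncated signature:

// Assumed dynamic shared-memory layout for kernels 5 through 7:
//   shared[0 .. C)   dbias_shared   (float)
//   shared[C .. 2C)  dweight_shared (float)
//   shared + 2*C     tmp_flag       (one unsigned int in a float-sized slot)
size_t shared_bytes = (2 * C + 1) * sizeof(float);
layernorm_backward_kernel5<<<grid_size, block_size, shared_bytes>>>(
    dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);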
@@ -526,7 +526,7 @@ __global__ void layernorm_backward_kernel5(Tdinp* dinp, Tparams* dweight, Tparam

float* scratch_dbias = scratch;
float* scratch_dweight = scratch + C * gridDim.x;
-uint* scratchFlag = (uint*)(scratch + (2 * C * gridDim.x));
+unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C * gridDim.x));

for(int i = threadIdx.x; i < C; i+= blockDim.x) {
scratch_dbias[i + C*blockIdx.x] = dbias_shared[i];
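
scratchFlag is the global-memory counterpart of tmp_flag: each block writes its partial dbias/dweight into its own C-wide slice of scratch, bumps the flag, and only the last block to arrive folds every slice into the final result. A minimal sketch of that arrival pattern with hypothetical names (the file's actual code differs in detail):

// Hypothetical sketch of the last-block reduction behind scratchFlag.
__device__ void reduce_if_last_block(float* scratch, unsigned int* flag,
                                     float* out, int C) {
    __shared__ unsigned int arrived;
    __threadfence();                   // make this block's scratch writes visible
    if (threadIdx.x == 0) {
        arrived = atomicAdd(flag, 1);  // how many blocks finished before us
    }
    __syncthreads();
    if (arrived == gridDim.x - 1) {    // last block performs the reduction
        for (int i = threadIdx.x; i < C; i += blockDim.x) {
            float sum = 0.0f;
            for (int b = 0; b < gridDim.x; b++) {
                sum += scratch[i + C * b];  // fold block b's slice
            }
            out[i] = sum;
        }
    }
}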
@@ -576,7 +576,7 @@ __global__ void layernorm_backward_kernel6(Tdinp* dinp, Tparams* dweight, Tparam
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
-uint *tmp_flag = (uint*)(shared + C*2);
+unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
__syncthreads();

int warps_in_grid = gridDim.x * warp.meta_group_size();
@@ -628,7 +628,7 @@ __global__ void layernorm_backward_kernel6(Tdinp* dinp, Tparams* dweight, Tparam
__syncthreads();
float* scratch_dbias = scratch;
float* scratch_dweight = scratch + C;
-uint* scratchFlag = (uint*)(scratch + (2 * C));
+unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
for(int i = threadIdx.x; i < C; i+= blockDim.x) {
atomicAdd(&scratch_dbias[i], dbias_shared[i]);
atomicAdd(&scratch_dweight[i], dweight_shared[i]);
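
Unlike kernel5, kernel6 drops the per-block slices: every block accumulates straight into a single 2*C-float scratch via float atomicAdd, which shrinks the scratch but means the buffer, flag included, must start zeroed. A one-line sketch of the assumed host-side preparation:

// Assumed host-side setup: atomicAdd accumulates into scratch, so it must start at zero.
cudaMemset(scratch, 0, (2 * C + 1) * sizeof(float));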
@@ -669,7 +669,7 @@ __global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
-uint *tmp_flag = (uint*)(shared + C*2);
+unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
__syncthreads();

for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {
@@ -721,7 +721,7 @@ __global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX
__syncthreads();
float* scratch_dbias = scratch;
float* scratch_dweight = scratch + C;
-uint* scratchFlag = (uint*)(scratch + (2 * C));
+unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
for(int i = threadIdx.x; i < C; i+= blockDim.x) {
atomicAdd(&scratch_dbias[i], dbias_shared[i]);
atomicAdd(&scratch_dweight[i], dweight_shared[i]);
