just use gemv for bias propagation
which will be a bit slow, but correct; the previous one is incorrect.
liuliu committed Jun 9, 2014
1 parent 90b0360 commit f031b56
Showing 3 changed files with 57 additions and 76 deletions.
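
The change this commit makes, sketched here outside the diff for clarity: the custom reduction kernel for the bias gradient is replaced by a single cublasSgemv against a vector of ones. The output gradient is stored as count contiguous blocks of out_rows * out_cols * batch floats (one block per filter, the layout the removed kernel walked), so the bias gradient of filter k is just the sum of block k; viewing the buffer as a column-major (out_rows * out_cols * batch) x count matrix A, the whole bias gradient is A^T * ones. The snippet below is a minimal standalone sketch of that idea with illustrative sizes and names (out_grad, ones, bias); it is not code from the tree.

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <stdlib.h>

/* Minimal sketch of the gemv-based bias gradient: out_grad holds count blocks
 * of m = out_rows * out_cols * batch floats each, so viewed column-major it is
 * an m x count matrix A, and the bias gradient is A^T * ones -- one column sum
 * per filter. Names and sizes are illustrative. */
int main(void)
{
    const int out_rows = 13, out_cols = 13, batch = 128, count = 256;
    const int m = out_rows * out_cols * batch;
    float *out_grad = 0, *ones = 0, *bias = 0;
    cudaMalloc(&out_grad, sizeof(float) * m * count);
    cudaMalloc(&ones, sizeof(float) * m);
    cudaMalloc(&bias, sizeof(float) * count);
    cudaMemset(out_grad, 0, sizeof(float) * m * count); /* placeholder gradient */
    /* the "unit" buffer in cwc_convnet.cu plays the role of this ones vector */
    float* host_ones = (float*)malloc(sizeof(float) * m);
    for (int i = 0; i < m; i++)
        host_ones[i] = 1;
    cudaMemcpy(ones, host_ones, sizeof(float) * m, cudaMemcpyHostToDevice);
    cublasHandle_t handle;
    cublasCreate(&handle);
    float alpha = 1, beta = 0;
    /* bias = alpha * A^T * ones + beta * bias, where A is m x count with lda = m */
    cublasSgemv(handle, CUBLAS_OP_T, m, count, &alpha, out_grad, m, ones, 1, &beta, bias, 1);
    cudaDeviceSynchronize();
    cublasDestroy(handle);
    cudaFree(out_grad), cudaFree(ones), cudaFree(bias);
    free(host_ones);
    return 0;
}

With beta = 0 the call overwrites configuration->bias with the whole batch's bias gradient, matching what the removed kernel did; the unit buffer enlarged in _cwc_convnet_make_unit below is what serves as the ones vector.
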
60 changes: 50 additions & 10 deletions bin/cuda/cwc-bench-runtime.cu
@@ -169,9 +169,9 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
cudaMallocHost(&seventh_back, sizeof(float) * sixth_out_rows * sixth_out_cols * sixth_out_channels * batch);
cudaMemcpy(seventh_back, GPU(convnet)->backwards[6], sizeof(float) * sixth_out_rows * sixth_out_cols * sixth_out_channels * batch, cudaMemcpyDeviceToHost);
float* seventh_grad = 0;
cudaMallocHost(&seventh_grad, sizeof(float) * seventh_gpu_layer->wnum);
cudaMallocHost(&seventh_grad, sizeof(float) * (seventh_gpu_layer->wnum + seventh_gpu_layer->net.convolutional.count));
assert(seventh_grad);
cudaMemcpy(seventh_grad, seventh_gpu_configuration->w, sizeof(float) * seventh_gpu_layer->wnum, cudaMemcpyDeviceToHost);
cudaMemcpy(seventh_grad, seventh_gpu_configuration->w, sizeof(float) * (seventh_gpu_layer->wnum + seventh_gpu_layer->net.convolutional.count), cudaMemcpyDeviceToHost);
printf("finished backward propagate seventh convolutional layer on GPU\n");

// sixth convolutional layer backward propagate
@@ -194,9 +194,9 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
cudaMallocHost(&sixth_back, sizeof(float) * fifth_out_rows * fifth_out_cols * fifth_out_channels * batch);
cudaMemcpy(sixth_back, GPU(convnet)->backwards[5], sizeof(float) * fifth_out_rows * fifth_out_cols * fifth_out_channels * batch, cudaMemcpyDeviceToHost);
float* sixth_grad = 0;
cudaMallocHost(&sixth_grad, sizeof(float) * sixth_gpu_layer->wnum);
cudaMallocHost(&sixth_grad, sizeof(float) * (sixth_gpu_layer->wnum + sixth_gpu_layer->net.convolutional.count));
assert(sixth_grad);
cudaMemcpy(sixth_grad, sixth_gpu_configuration->w, sizeof(float) * sixth_gpu_layer->wnum, cudaMemcpyDeviceToHost);
cudaMemcpy(sixth_grad, sixth_gpu_configuration->w, sizeof(float) * (sixth_gpu_layer->wnum + sixth_gpu_layer->net.convolutional.count), cudaMemcpyDeviceToHost);
printf("finished backward propagate sixth convolutional layer on GPU\n");

// fifth convolutional layer backward propagate
@@ -219,9 +219,9 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
cudaMallocHost(&fifth_back, sizeof(float) * forth_out_rows * forth_out_cols * forth_out_channels * batch);
cudaMemcpy(fifth_back, GPU(convnet)->backwards[4], sizeof(float) * forth_out_rows * forth_out_cols * forth_out_channels * batch, cudaMemcpyDeviceToHost);
float* fifth_grad = 0;
cudaMallocHost(&fifth_grad, sizeof(float) * fifth_gpu_layer->wnum);
cudaMallocHost(&fifth_grad, sizeof(float) * (fifth_gpu_layer->wnum + fifth_gpu_layer->net.convolutional.count));
assert(fifth_grad);
cudaMemcpy(fifth_grad, fifth_gpu_configuration->w, sizeof(float) * fifth_gpu_layer->wnum, cudaMemcpyDeviceToHost);
cudaMemcpy(fifth_grad, fifth_gpu_configuration->w, sizeof(float) * (fifth_gpu_layer->wnum + fifth_gpu_layer->net.convolutional.count), cudaMemcpyDeviceToHost);
printf("finished backward propagate fifth convolutional layer on GPU\n");

// third convolutional layer backward propagate
@@ -245,9 +245,9 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
cudaMallocHost(&third_back, sizeof(float) * second_out_rows * second_out_cols * second_out_channels * batch);
cudaMemcpy(third_back, GPU(convnet)->backwards[2], sizeof(float) * second_out_rows * second_out_cols * second_out_channels * batch, cudaMemcpyDeviceToHost);
float* third_grad = 0;
cudaMallocHost(&third_grad, sizeof(float) * third_gpu_layer->wnum);
cudaMallocHost(&third_grad, sizeof(float) * (third_gpu_layer->wnum + third_gpu_layer->net.convolutional.count));
assert(third_grad);
cudaMemcpy(third_grad, third_gpu_configuration->w, sizeof(float) * third_gpu_layer->wnum, cudaMemcpyDeviceToHost);
cudaMemcpy(third_grad, third_gpu_configuration->w, sizeof(float) * (third_gpu_layer->wnum + third_gpu_layer->net.convolutional.count), cudaMemcpyDeviceToHost);
printf("finished backward propagate third convolutional layer on GPU\n");

// second average pool layer backward propagate
@@ -273,9 +273,9 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
cudaStreamSynchronize(context->device.stream);
assert(cudaGetLastError() == cudaSuccess);
float* first_grad = 0;
cudaMallocHost(&first_grad, sizeof(float) * first_gpu_layer->wnum);
cudaMallocHost(&first_grad, sizeof(float) * (first_gpu_layer->wnum + first_gpu_layer->net.convolutional.count));
assert(first_grad);
cudaMemcpy(first_grad, first_gpu_configuration->w, sizeof(float) * first_gpu_layer->wnum, cudaMemcpyDeviceToHost);
cudaMemcpy(first_grad, first_gpu_configuration->w, sizeof(float) * (first_gpu_layer->wnum + first_gpu_layer->net.convolutional.count), cudaMemcpyDeviceToHost);
printf("finished backward propagate first convolutional layer on GPU\n");
cudaEventDestroy(start);
cudaEventDestroy(stop);
@@ -472,6 +472,14 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
if (delta > 1e-4)
printf("conv bprop 7: %d %d %d %d: |%g - %g| = %g\n", x, y, k, c, p, q, delta);
}
for (k = 0; k < seventh_filter_count; k++)
{
float p = seventh_cpu_configuration->bias[k];
float q = seventh_grad[seventh_gpu_layer->wnum + k];
float delta = fabs(p - q) / ccv_max(ccv_max(fabs(p), fabs(q)), 1);
if (delta > 1e-4)
printf("conv bprop 7 bias: %d: |%g - %g| = %g\n", k, p, q, delta);
}

ccv_convnet_layer_t* sixth_cpu_configuration = update_params->layers + 5;
int sixth_filter_rows = sixth_gpu_layer->net.convolutional.rows;
@@ -489,6 +497,14 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
if (delta > 1e-3)
printf("conv bprop 6: %d %d %d %d: |%g - %g| = %g\n", x, y, k, c, p, q, delta);
}
for (k = 0; k < sixth_filter_count; k++)
{
float p = sixth_cpu_configuration->bias[k];
float q = sixth_grad[sixth_gpu_layer->wnum + k];
float delta = fabs(p - q) / ccv_max(ccv_max(fabs(p), fabs(q)), 1);
if (delta > 1e-4)
printf("conv bprop 6 bias: %d: |%g - %g| = %g\n", k, p, q, delta);
}

ccv_convnet_layer_t* fifth_cpu_configuration = update_params->layers + 4;
int fifth_filter_rows = fifth_gpu_layer->net.convolutional.rows;
@@ -506,6 +522,14 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
if (delta > 1e-2)
printf("conv bprop 5: %d %d %d %d: |%g - %g| = %g\n", x, y, k, c, p, q, delta);
}
for (k = 0; k < fifth_filter_count; k++)
{
float p = fifth_cpu_configuration->bias[k];
float q = fifth_grad[fifth_gpu_layer->wnum + k];
float delta = fabs(p - q) / ccv_max(ccv_max(fabs(p), fabs(q)), 1);
if (delta > 1e-4)
printf("conv bprop 5 bias: %d: |%g - %g| = %g\n", k, p, q, delta);
}

ccv_convnet_layer_t* third_cpu_configuration = update_params->layers + 2;
int third_filter_rows = third_gpu_layer->net.convolutional.rows;
@@ -523,6 +547,14 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
if (delta > 1e-4)
printf("conv bprop 3: %d %d %d %d: |%g - %g| = %g\n", x, y, k, c, p, q, delta);
}
for (k = 0; k < third_filter_count; k++)
{
float p = third_cpu_configuration->bias[k];
float q = third_grad[third_gpu_layer->wnum + k];
float delta = fabs(p - q) / ccv_max(ccv_max(fabs(p), fabs(q)), 1);
if (delta > 1e-4)
printf("conv bprop 3 bias: %d: |%g - %g| = %g\n", k, p, q, delta);
}

ccv_convnet_layer_t* first_cpu_configuration = update_params->layers;
int first_filter_rows = first_gpu_layer->net.convolutional.rows;
@@ -540,4 +572,12 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
if (delta > 1e-3)
printf("conv bprop 1: %d %d %d %d: |%g - %g| = %g\n", x, y, k, c, p, q, delta);
}
for (k = 0; k < first_filter_count; k++)
{
float p = first_cpu_configuration->bias[k];
float q = first_grad[first_gpu_layer->wnum + k];
float delta = fabs(p - q) / ccv_max(ccv_max(fabs(p), fabs(q)), 1);
if (delta > 1e-4)
printf("conv bprop 1 bias: %d: |%g - %g| = %g\n", k, p, q, delta);
}
}
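
The bias checks added above all use the same relative-error test as the existing weight checks. A small host-side helper with that formula (hypothetical name, not part of this commit) would look like:

#include <math.h>
#include <stdio.h>

/* Hypothetical helper equivalent to the inline bias checks above: flag the CPU
 * value p and the GPU value q when their relative difference exceeds tol, with
 * the denominator clamped at 1 so near-zero gradients do not inflate the ratio. */
static int bias_grad_mismatch(const char* tag, int k, float p, float q, float tol)
{
    float denom = fmaxf(fmaxf(fabsf(p), fabsf(q)), 1);
    float delta = fabsf(p - q) / denom;
    if (delta > tol)
    {
        printf("%s bias: %d: |%g - %g| = %g\n", tag, k, p, q, delta);
        return 1;
    }
    return 0;
}

/* e.g. bias_grad_mismatch("conv bprop 7", k,
 *        seventh_cpu_configuration->bias[k],
 *        seventh_grad[seventh_gpu_layer->wnum + k], 1e-4); */
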
2 changes: 1 addition & 1 deletion lib/ccv_convnet.c
@@ -869,7 +869,7 @@ static void _ccv_convnet_convolutional_backward_propagate(ccv_convnet_layer_t* l
np += n->cols * count;
mp += m->cols * ch * (ccv_max((i + 1) * strides - border, 0) - ccv_max(i * strides - border, 0));
}
update_params->bias[k] = bias;
update_params->bias[k] += bias;
} parallel_endfor
if (b)
{
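
The one-character fix above matters because the CPU reference walks the mini-batch one image at a time and keeps running gradient sums in update_params; assigning instead of adding kept only the last image's bias contribution. A minimal sketch of the intended accumulation, with illustrative names (this is not code from ccv_convnet.c):

/* Accumulate one image's per-filter bias gradient into the running sums.
 * out_grad holds this image's output gradient as count maps of
 * out_rows * out_cols floats; bias_grad carries the sum across the batch. */
static void accumulate_bias_grad(const float* out_grad, float* bias_grad,
                                 int count, int out_rows, int out_cols)
{
    int k, i;
    for (k = 0; k < count; k++)
    {
        float bias = 0;
        for (i = 0; i < out_rows * out_cols; i++)
            bias += out_grad[k * out_rows * out_cols + i];
        bias_grad[k] += bias; /* "+=": add to the batch total, do not overwrite */
    }
}
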
71 changes: 6 additions & 65 deletions lib/cuda/cwc_convnet.cu
@@ -466,6 +466,7 @@ static void _cwc_convnet_make_unit(ccv_convnet_t* convnet, int batch)
_ccv_convnet_layer_derive_output(layers + i, layers[i].input.matrix.rows, layers[i].input.matrix.cols, &out_rows, &out_cols, &out_partition);
if (_cwc_convnet_layer_use_rows(layers + i))
unit_size = ccv_max(unit_size, out_rows * (batch / BATCH_PER_BLOCK));
unit_size = ccv_max(unit_size, out_rows * out_cols * batch);
}
float* unit = 0;
cudaMallocHost(&unit, sizeof(float) * unit_size);
@@ -1199,60 +1200,6 @@ __global__ static void _cwc_kern_convolutional_backward_propagate_coefficient_ro
coeff[(i + threadIdx.y * channel_per_thread) * cocnt + j * filter_cols * count + k + threadIdx.x * filter_per_thread] = prod[j][i][k];
}

template <int out_per_thread>
__global__ static void _cwc_kern_convolutional_backward_propagate_bias(const int batch,
float* out_grad, const int out_rows, const int out_cols,
float* bias, const int count)
{
assert(gridDim.x == count);
const int skip_pixels = blockDim.y;
extern __shared__ float shared[];
float* shared_bias = &shared[0];
float* shared_grad = &shared[1];
int i, x;
const int thidx = threadIdx.x + threadIdx.y * blockDim.x;
const int thcnt = blockDim.x * blockDim.y;
const int out_loads = (batch * skip_pixels + thcnt - 1) / thcnt;
assert(thcnt % batch == 0);
out_grad += blockIdx.x * out_rows * out_cols * batch + thidx;
const int out_load_factor = thcnt;
const int out_load_pixels = thcnt / batch;
if (thidx == 0)
shared_bias[0] = 0;
for (x = 0; x < out_rows * out_cols; x += skip_pixels)
{
for (i = 0; i < out_loads; i++)
if (i * thcnt + thidx < batch * skip_pixels && x + i * out_load_pixels < out_rows * out_cols)
shared_grad[i * thcnt + thidx] = out_grad[x * batch + i * out_load_factor];
__syncthreads();
// because I branched out with threadIdx, therefore, synchronization must happen outside of the if clause
if (threadIdx.y + x < out_rows * out_cols)
{
#pragma unroll
for (i = 1; i < out_per_thread; i++)
shared_grad[threadIdx.y * batch + threadIdx.x * out_per_thread] += shared_grad[threadIdx.y * batch + threadIdx.x * out_per_thread + i];
}
__syncthreads();
// I can do better here, but bias computation is not the bottleneck
if (threadIdx.y + x < out_rows * out_cols && threadIdx.x == 0)
#pragma unroll
for (i = 1; i < blockDim.x; i++)
shared_grad[threadIdx.y * batch] += shared_grad[threadIdx.y * batch + i * out_per_thread];
__syncthreads();
// because I branched out with threadIdx, therefore, synchronization must happen outside of the if clause, thus, this if clause appeared repeatedly
if (threadIdx.y + x < out_rows * out_cols && thidx == 0)
{
#pragma unroll
for (i = 1; i < blockDim.y && i + x < out_rows * out_cols; i++)
shared_grad[0] += shared_grad[i * batch];
shared_bias[0] += shared_grad[0];
}
__syncthreads();
}
if (thidx == 0)
bias[blockIdx.x] = shared_bias[0];
}

template <int input_per_thread, int channel_per_thread, int channel_per_block, int strides>
__global__ static void _cwc_kern_convolutional_backward_propagate_error(const int border, const int batch,
float* input_grad, const int rows, const int cols, const int channels,
@@ -1337,6 +1284,7 @@ __global__ static void _cwc_kern_reorder_matrix_major(float* a, float* b, const
a += blockIdx.z * count * channels_per_partition * batch;
b[(threadIdx.x * count + blockIdx.x) * channels_per_partition + blockIdx.y] = a[(blockIdx.y * count + blockIdx.x) * batch + threadIdx.x];
}

// this method rewinds a matrix
__global__ static void _cwc_kern_reorder_matrix_major_parted(float* a, float* b, const int count, const int channels, const int batch, const int channels_per_partition, const int batch_per_partition, const int partition)
{
@@ -1504,27 +1452,20 @@ static void _cwc_convnet_convolutional_backward_propagate(ccv_convnet_layer_t* l
{
assert(layer->net.convolutional.count % 4 == 0);
assert(batch % BATCH_PER_BLOCK == 0);
int out_rows, out_cols, out_partition, shared_memory_size;
int out_rows, out_cols, out_partition;
_ccv_convnet_layer_derive_output(layer, layer->input.matrix.rows, layer->input.matrix.cols, &out_rows, &out_cols, &out_partition);
// it turns out that applying relu first saves us a lot of computation because there is no need to load both out and out_grad any more
_cwc_kern_relu_backward_propagate
<<<dim3(out_cols, out_rows, layer->net.convolutional.count), batch, 0, stream>>>
(batch, n, a, out_rows, out_cols, layer->net.convolutional.count);
assert(cudaGetLastError() == cudaSuccess);
float alpha = 1, beta = 0;
if (_cwc_convnet_layer_use_rows(layer))
_cwc_convnet_convolutional_backward_propagate_coefficient_rows(layer, batch, a, n, m, b, configuration, scratch, unit, stream, handle);
else
_cwc_convnet_convolutional_backward_propagate_coefficient_default(layer, batch, a, n, m, b, configuration, scratch, unit, stream, handle);
dim3 threads_per_block_for_bias(batch / 16, 16);
assert(threads_per_block_for_bias.x * threads_per_block_for_bias.y <= 1024);
dim3 num_blocks_for_bias(layer->net.convolutional.count);
shared_memory_size = sizeof(float) * (1 + batch * 16);
_cwc_kern_convolutional_backward_propagate_bias
<16>
<<<num_blocks_for_bias, threads_per_block_for_bias, shared_memory_size, stream>>>
(batch,
a, out_rows, out_cols,
configuration->bias, layer->net.convolutional.count);
// compute the bias directly using the gemv routine
cublasSgemv(handle, CUBLAS_OP_T, out_rows * out_cols * batch, layer->net.convolutional.count, &alpha, a, out_rows * out_cols * batch, unit, 1, &beta, configuration->bias, 1);
assert(cudaGetLastError() == cudaSuccess);
if (b)
_cwc_convnet_convolutional_backward_propagate_error(layer, batch, a, n, m, b, configuration, scratch, unit, stream, handle);
