Skip to content

Commit

Permalink
Improve the efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
lancerts committed Apr 11, 2024
1 parent 00010c7 commit d6bcc12
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions dev/cuda/softmax_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ void softmax_forward_online_cpu(float* out, float* inp, int N, int C) {

float maxval = -INFINITY;
float sum = 0.0f;
for (int j = 0; j < C; j++) {
float maxval_prev = maxval;
if (inp_row[j] > maxval) {
maxval = inp_row[j];
sum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);
} else {
for (int j = 0; j < C; j++) {
float maxval_prev = maxval;
if (inp_row[j] > maxval) {
maxval = inp_row[j];
sum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);
}
else {
sum += expf(inp_row[j] - maxval);
}
}
}

for (int j = 0; j < C; j++) {
out_row[j] = expf(inp_row[j] - maxval) / sum;
Expand Down Expand Up @@ -332,14 +333,14 @@ __global__ void softmax_forward_online_kernel1(float* out, float* inp, int N, in
float sum = 0.0f;
for (int j = 0; j < C; j++) {
float maxval_prev = maxval;
if (inp_row[j] > maxval) {
maxval = inp_row[j];
sum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);
}
else {
sum += expf(inp_row[j] - maxval);
}
}
if (inp_row[j] > maxval) {
maxval = inp_row[j];
sum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);
}
else {
sum += expf(inp_row[j] - maxval);
}
}

for (int j = 0; j < C; j++) {
out_row[j] = expf(inp_row[j] - maxval) / sum;
Expand Down

0 comments on commit d6bcc12

Please sign in to comment.