Skip to content

Commit

Permalink
Refactor global average pooling into reshape+setup
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544558444
  • Loading branch information
ngzhian authored and xnnpack-bot committed Jun 30, 2023
1 parent df12900 commit f6768ed
Show file tree
Hide file tree
Showing 36 changed files with 1,504 additions and 492 deletions.
40 changes: 32 additions & 8 deletions bench/global-average-pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,17 @@ static void global_average_pooling_qu8(benchmark::State& state) {
state.SkipWithError("failed to create Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_qu8(
status = xnn_reshape_global_average_pooling_nwc_qu8(
global_pooling_op,
batch_size, input_height * input_width,
input.data(), output.data(),
nullptr /* thread pool */);
if (status != xnn_status_success) {
state.SkipWithError("failed to reshape Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_qu8(
global_pooling_op,
input.data(), output.data());
if (status != xnn_status_success) {
state.SkipWithError("failed to setup Global Average Pooling operator");
}
Expand Down Expand Up @@ -111,11 +117,17 @@ static void global_average_pooling_qs8(benchmark::State& state) {
state.SkipWithError("failed to create Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_qs8(
status = xnn_reshape_global_average_pooling_nwc_qs8(
global_pooling_op,
batch_size, input_height * input_width,
input.data(), output.data(),
nullptr /* thread pool */);
if (status != xnn_status_success) {
state.SkipWithError("failed to reshape Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_qs8(
global_pooling_op,
input.data(), output.data());
if (status != xnn_status_success) {
state.SkipWithError("failed to setup Global Average Pooling operator");
}
Expand Down Expand Up @@ -170,11 +182,17 @@ static void global_average_pooling_f16(benchmark::State& state) {
state.SkipWithError("failed to create Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_f16(
status = xnn_reshape_global_average_pooling_nwc_f16(
global_pooling_op,
batch_size, input_height * input_width,
input.data(), output.data(),
nullptr /* thread pool */);
if (status != xnn_status_success) {
state.SkipWithError("failed to reshape Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_f16(
global_pooling_op,
input.data(), output.data());
if (status != xnn_status_success) {
state.SkipWithError("failed to setup Global Average Pooling operator");
}
Expand Down Expand Up @@ -228,11 +246,17 @@ static void global_average_pooling_f32(benchmark::State& state) {
state.SkipWithError("failed to create Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_f32(
status = xnn_reshape_global_average_pooling_nwc_f32(
global_pooling_op,
batch_size, input_height * input_width,
input.data(), output.data(),
nullptr /* thread pool */);
if (status != xnn_status_success) {
state.SkipWithError("failed to reshape Global Average Pooling operator");
}

status = xnn_setup_global_average_pooling_nwc_f32(
global_pooling_op,
input.data(), output.data());
if (status != xnn_status_success) {
state.SkipWithError("failed to setup Global Average Pooling operator");
}
Expand Down
72 changes: 48 additions & 24 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -2179,14 +2179,17 @@ enum xnn_status xnn_create_global_average_pooling_nwc_f32(
uint32_t flags,
xnn_operator_t* global_average_pooling_op_out);

enum xnn_status xnn_setup_global_average_pooling_nwc_f32(
enum xnn_status xnn_reshape_global_average_pooling_nwc_f32(
xnn_operator_t global_average_pooling_op,
size_t batch_size,
size_t width,
const float* input,
float* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_average_pooling_nwc_f32(
xnn_operator_t global_average_pooling_op,
const float* input,
float* output);

enum xnn_status xnn_create_global_sum_pooling_nwc_f32(
size_t channels,
size_t input_stride,
Expand All @@ -2196,14 +2199,17 @@ enum xnn_status xnn_create_global_sum_pooling_nwc_f32(
uint32_t flags,
xnn_operator_t* global_sum_pooling_op_out);

enum xnn_status xnn_setup_global_sum_pooling_nwc_f32(
enum xnn_status xnn_reshape_global_sum_pooling_nwc_f32(
xnn_operator_t global_sum_pooling_op,
size_t batch_size,
size_t width,
const float* input,
float* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_sum_pooling_nwc_f32(
xnn_operator_t global_sum_pooling_op,
const float* input,
float* output);

enum xnn_status xnn_create_hardswish_nc_f32(
size_t channels,
size_t input_stride,
Expand Down Expand Up @@ -2765,14 +2771,17 @@ enum xnn_status xnn_create_global_average_pooling_ncw_f32(
uint32_t flags,
xnn_operator_t* global_average_pooling_op_out);

enum xnn_status xnn_setup_global_average_pooling_ncw_f32(
enum xnn_status xnn_reshape_global_average_pooling_ncw_f32(
xnn_operator_t global_average_pooling_op,
size_t batch_size,
size_t width,
const float* input,
float* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_average_pooling_ncw_f32(
xnn_operator_t global_average_pooling_op,
const float* input,
float* output);

enum xnn_status xnn_create_resize_bilinear2d_nchw_f32(
size_t channels,
size_t input_pixel_stride,
Expand Down Expand Up @@ -3290,14 +3299,17 @@ enum xnn_status xnn_create_global_average_pooling_nwc_f16(
uint32_t flags,
xnn_operator_t* global_average_pooling_op_out);

enum xnn_status xnn_setup_global_average_pooling_nwc_f16(
enum xnn_status xnn_reshape_global_average_pooling_nwc_f16(
xnn_operator_t global_average_pooling_op,
size_t batch_size,
size_t width,
const void* input,
void* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_average_pooling_nwc_f16(
xnn_operator_t global_average_pooling_op,
const void* input,
void* output);

enum xnn_status xnn_create_global_sum_pooling_nwc_f16(
size_t channels,
size_t input_stride,
Expand All @@ -3307,14 +3319,17 @@ enum xnn_status xnn_create_global_sum_pooling_nwc_f16(
uint32_t flags,
xnn_operator_t* global_sum_pooling_op_out);

enum xnn_status xnn_setup_global_sum_pooling_nwc_f16(
enum xnn_status xnn_reshape_global_sum_pooling_nwc_f16(
xnn_operator_t global_sum_pooling_op,
size_t batch_size,
size_t width,
const void* input,
void* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_sum_pooling_nwc_f16(
xnn_operator_t global_sum_pooling_op,
const void* input,
void* output);

enum xnn_status xnn_create_hardswish_nc_f16(
size_t channels,
size_t input_stride,
Expand Down Expand Up @@ -3717,14 +3732,17 @@ enum xnn_status xnn_create_global_average_pooling_ncw_f16(
uint32_t flags,
xnn_operator_t* global_average_pooling_op_out);

enum xnn_status xnn_setup_global_average_pooling_ncw_f16(
enum xnn_status xnn_reshape_global_average_pooling_ncw_f16(
xnn_operator_t global_average_pooling_op,
size_t batch_size,
size_t width,
const void* input,
void* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_average_pooling_ncw_f16(
xnn_operator_t global_average_pooling_op,
const void* input,
void* output);

enum xnn_status xnn_create_resize_bilinear2d_nchw_f16(
size_t channels,
size_t input_pixel_stride,
Expand Down Expand Up @@ -4107,14 +4125,17 @@ enum xnn_status xnn_create_global_average_pooling_nwc_qs8(
uint32_t flags,
xnn_operator_t* global_average_pooling_op_out);

enum xnn_status xnn_setup_global_average_pooling_nwc_qs8(
enum xnn_status xnn_reshape_global_average_pooling_nwc_qs8(
xnn_operator_t global_average_pooling_op,
size_t batch_size,
size_t width,
const int8_t* input,
int8_t* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_average_pooling_nwc_qs8(
xnn_operator_t global_average_pooling_op,
const int8_t* input,
int8_t* output);

enum xnn_status xnn_create_multiply_nd_qs8(
int8_t input1_zero_point,
float input1_scale,
Expand Down Expand Up @@ -4485,14 +4506,17 @@ enum xnn_status xnn_create_global_average_pooling_nwc_qu8(
uint32_t flags,
xnn_operator_t* global_average_pooling_op_out);

enum xnn_status xnn_setup_global_average_pooling_nwc_qu8(
enum xnn_status xnn_reshape_global_average_pooling_nwc_qu8(
xnn_operator_t global_average_pooling_op,
size_t batch_size,
size_t width,
const uint8_t* input,
uint8_t* output,
pthreadpool_t threadpool);

enum xnn_status xnn_setup_global_average_pooling_nwc_qu8(
xnn_operator_t global_average_pooling_op,
const uint8_t* input,
uint8_t* output);

enum xnn_status xnn_create_leaky_relu_nc_qu8(
size_t channels,
size_t input_stride,
Expand Down
13 changes: 10 additions & 3 deletions models/fp16-mobilenet-v1.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,15 @@ ExecutionPlan FP16MobileNetV1(bool use_jit, pthreadpool_t threadpool) {
return ExecutionPlan();
}

status = xnn_reshape_global_average_pooling_nwc_f16(
op27,
/*batch_size=*/1, 49 /* width */,
/*threadpool=*/threadpool);
if (status != xnn_status_success) {
std::cerr << "failed to reshape operation #27" << std::endl;
return ExecutionPlan();
}

status = xnn_reshape_convolution2d_nhwc_f16(
op28,
/*batch_size=*/1, /*input_height=*/1, /*input_width=*/1,
Expand Down Expand Up @@ -1402,9 +1411,7 @@ ExecutionPlan FP16MobileNetV1(bool use_jit, pthreadpool_t threadpool) {

status = xnn_setup_global_average_pooling_nwc_f16(
op27,
/*batch_size=*/1, 49 /* width */,
/*input=*/v27.data(), /*output=*/v28.data(),
/*threadpool=*/threadpool);
/*input=*/v27.data(), /*output=*/v28.data());
if (status != xnn_status_success) {
std::cerr << "failed to setup operation #27" << std::endl;
return ExecutionPlan();
Expand Down
13 changes: 10 additions & 3 deletions models/fp16-mobilenet-v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2434,6 +2434,15 @@ ExecutionPlan FP16MobileNetV2(bool use_jit, pthreadpool_t threadpool) {
return ExecutionPlan();
}

status = xnn_reshape_global_average_pooling_nwc_f16(
op62,
/*batch_size=*/1, 49 /* width */,
/*threadpool=*/threadpool);
if (status != xnn_status_success) {
std::cerr << "failed to reshape operation #62" << std::endl;
return ExecutionPlan();
}

status = xnn_reshape_convolution2d_nhwc_f16(
op63,
/*batch_size=*/1, /*input_height=*/1, /*input_width=*/1,
Expand Down Expand Up @@ -2942,9 +2951,7 @@ ExecutionPlan FP16MobileNetV2(bool use_jit, pthreadpool_t threadpool) {

status = xnn_setup_global_average_pooling_nwc_f16(
op62,
/*batch_size=*/1, 49 /* width */,
/*input=*/v62.data(), /*output=*/v63.data(),
/*threadpool=*/threadpool);
/*input=*/v62.data(), /*output=*/v63.data());
if (status != xnn_status_success) {
std::cerr << "failed to setup operation #62" << std::endl;
return ExecutionPlan();
Expand Down
Loading

0 comments on commit f6768ed

Please sign in to comment.