Skip to content

Commit

Permalink
Refactor subgraph checks for matching quantization params
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 480112000
  • Loading branch information
ngzhian authored and xnnpack-bot committed Oct 10, 2022
1 parent f29cf7a commit 480eaae
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 119 deletions.
21 changes: 4 additions & 17 deletions src/subgraph/clamp.c
Original file line number Diff line number Diff line change
Expand Up @@ -245,23 +245,10 @@ enum xnn_status xnn_define_clamp(
}

#if !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)
if (compute_type == xnn_datatype_qint8 || compute_type == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(xnn_node_type_clamp), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
status = xnn_subgraph_check_quantization_parameter_matches(
xnn_node_type_clamp, input_id, input_value, output_id, output_value);
if (status != xnn_status_success) {
return status;
}
#endif // !defined(XNN_NO_U8_OPERATORS) || !defined(XNN_NO_S8_OPERATORS)

Expand Down
21 changes: 4 additions & 17 deletions src/subgraph/copy.c
Original file line number Diff line number Diff line change
Expand Up @@ -224,23 +224,10 @@ enum xnn_status xnn_define_copy(
}

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(xnn_node_type_copy), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(xnn_node_type_copy), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
status = xnn_subgraph_check_quantization_parameter_matches(
xnn_node_type_copy, input_id, input_value, output_id, output_value);
if (status != xnn_status_success) {
return status;
}
#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

Expand Down
21 changes: 4 additions & 17 deletions src/subgraph/depth-to-space.c
Original file line number Diff line number Diff line change
Expand Up @@ -248,23 +248,10 @@ enum xnn_status xnn_define_depth_to_space(
}

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
if (compute_type == xnn_datatype_qint8 || compute_type == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(xnn_node_type_depth_to_space), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(xnn_node_type_depth_to_space), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
status = xnn_subgraph_check_quantization_parameter_matches(
xnn_node_type_depth_to_space, input_id, input_value, output_id, output_value);
if (status != xnn_status_success) {
return status;
}
#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

Expand Down
21 changes: 4 additions & 17 deletions src/subgraph/max-pooling-2d.c
Original file line number Diff line number Diff line change
Expand Up @@ -365,23 +365,10 @@ enum xnn_status xnn_define_max_pooling_2d(
}

#if !defined(XNN_NO_S8_OPERATORS) || !defined(XNN_NO_U8_OPERATORS)
if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(xnn_node_type_max_pooling_2d), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(xnn_node_type_max_pooling_2d), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
status = xnn_subgraph_check_quantization_parameter_matches(
xnn_node_type_max_pooling_2d, input_id, input_value, output_id, output_value);
if (status != xnn_status_success) {
return status;
}
#endif // !defined(XNN_NO_S8_OPERATORS) || !defined(XNN_NO_U8_OPERATORS)

Expand Down
21 changes: 4 additions & 17 deletions src/subgraph/static-constant-pad.c
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,10 @@ enum xnn_status xnn_define_static_constant_pad(
}

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
status = xnn_subgraph_check_quantization_parameter_matches(
xnn_node_type_static_constant_pad, input_id, input_value, output_id, output_value);
if (status != xnn_status_success) {
return status;
}
#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

Expand Down
21 changes: 4 additions & 17 deletions src/subgraph/static-reshape.c
Original file line number Diff line number Diff line change
Expand Up @@ -223,23 +223,10 @@ enum xnn_status xnn_define_static_reshape(
}

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
status = xnn_subgraph_check_quantization_parameter_matches(
xnn_node_type_static_reshape, input_id, input_value, output_id, output_value);
if (status != xnn_status_success) {
return status;
}
#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

Expand Down
21 changes: 4 additions & 17 deletions src/subgraph/static-resize-bilinear-2d.c
Original file line number Diff line number Diff line change
Expand Up @@ -297,23 +297,10 @@ enum xnn_status xnn_define_static_resize_bilinear_2d(
}

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
status = xnn_subgraph_check_quantization_parameter_matches(
xnn_node_type_static_resize_bilinear_2d, input_id, input_value, output_id, output_value);
if (status != xnn_status_success) {
return status;
}
#endif // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

Expand Down
28 changes: 28 additions & 0 deletions src/subgraph/validation.c
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,31 @@ enum xnn_status xnn_subgraph_check_output_min_max(enum xnn_node_type node_type,
}
return xnn_status_success;
}

enum xnn_status xnn_subgraph_check_quantization_parameter_matches(
enum xnn_node_type node_type,
uint32_t input_id,
const struct xnn_value* input_value,
uint32_t output_id,
const struct xnn_value* output_value)
{
if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
xnn_node_type_to_string(node_type), input_id, output_id,
input_value->quantization.zero_point, output_value->quantization.zero_point);
return xnn_status_invalid_parameter;
}
if (input_value->quantization.scale != output_value->quantization.scale) {
xnn_log_error(
"failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
xnn_node_type_to_string(node_type), input_id, output_id,
input_value->quantization.scale, output_value->quantization.scale);
return xnn_status_invalid_parameter;
}
}
return xnn_status_success;
}
7 changes: 7 additions & 0 deletions src/xnnpack/subgraph-validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ enum xnn_status xnn_subgraph_check_datatype_matches_two_inputs(
const struct xnn_value* output_value);
enum xnn_status xnn_subgraph_check_output_min_max(enum xnn_node_type node_type, float output_min, float output_max);

enum xnn_status xnn_subgraph_check_quantization_parameter_matches(
enum xnn_node_type node_type,
uint32_t input_id,
const struct xnn_value* input_value,
uint32_t output_id,
const struct xnn_value* output_value);

#ifdef __cplusplus
} // extern "C"
#endif

0 comments on commit 480eaae

Please sign in to comment.