[GPU] Quantize scale shift opt optimizations (openvinotoolkit#7770)
lznamens authored Oct 7, 2021
1 parent 623117f commit 94d5d81
Showing 2 changed files with 116 additions and 29 deletions.
@@ -44,7 +44,8 @@ CommonDispatchData QuantizeKernelScaleShift::SetDefault(const quantize_params& p
dispatchData.lws[0] = 1;
dispatchData.lws[1] = sub_group_size;
dispatchData.lws[2] = 1;
} else if (output.GetLayout() == DataLayout::bs_fs_yx_bsv32_fsv32) {
} else if (output.GetLayout() == DataLayout::bs_fs_yx_bsv32_fsv32 || output.GetLayout() == DataLayout::bs_fs_yx_bsv16_fsv16 ||
output.GetLayout() == DataLayout::bs_fs_yx_bsv32_fsv16) {
dispatchData.gws[0] = output.Y().v * output.X().v;
dispatchData.gws[1] = Align(output.Feature().v, feature_size);
dispatchData.gws[2] = Align(output.Batch().v, feature_size);
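
As a quick illustration of what this dispatch produces for the newly covered double-blocked layouts: the sketch below is not the kernel_selector code itself; Align() is a hand-written stand-in for the helper used above, and the tensor shape plus feature_size = 16 are assumed purely for the example.

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for kernel_selector's Align(): round value up to a multiple of align.
static size_t Align(size_t value, size_t align) {
    return (value + align - 1) / align * align;
}

int main() {
    // Assumed example: b=2, f=24, y=7, x=7 written as bs_fs_yx_bsv16_fsv16,
    // with feature_size = 16 taken as the block size for this format.
    const size_t batch = 2, feature = 24, y = 7, x = 7, feature_size = 16;

    const size_t gws[3] = { y * x,                        // GWS_YX      (gws[0])
                            Align(feature, feature_size), // GWS_FEATURE (gws[1])
                            Align(batch, feature_size) }; // GWS_BATCH   (gws[2])
    std::printf("gws = { %zu, %zu, %zu }\n", gws[0], gws[1], gws[2]); // { 49, 32, 16 }
}

The feature-dimension padding introduced by Align() is what the FEATURE_BLOCKED_FORMAT guard before the final store in the kernel accounts for.
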
@@ -63,8 +64,9 @@ CommonDispatchData QuantizeKernelScaleShift::SetDefault(const quantize_params& p
JitConstants QuantizeKernelScaleShift::GetJitConstants(const quantize_params& params, const CommonDispatchData& dispatchData) const {
JitConstants jit = Parent::GetJitConstants(params, dispatchData);

if (params.output.GetLayout() == DataLayout::b_fs_yx_fsv16 ||
params.output.GetLayout() == DataLayout::bs_fs_yx_bsv32_fsv32) {
if (params.output.GetLayout() == DataLayout::b_fs_yx_fsv16 || params.output.GetLayout() == DataLayout::bs_fs_yx_bsv32_fsv32 ||
params.output.GetLayout() == DataLayout::bs_fs_yx_bsv16_fsv16 || params.output.GetLayout() == DataLayout::bs_fs_yx_bsv32_fsv16) {
jit.AddConstant(MakeJitConstant("FEATURE_BLOCKED_FORMAT", true));
jit.AddConstant(MakeJitConstant("GWS_BATCH", 2));
jit.AddConstant(MakeJitConstant("GWS_FEATURE", 1));
jit.AddConstant(MakeJitConstant("GWS_YX", 0));
@@ -74,21 +76,31 @@ JitConstants QuantizeKernelScaleShift::GetJitConstants(const quantize_params& pa
jit.Merge(tensor_jits);
}

auto can_use_output_range = params.per_tensor_output_range && params.out_lo < params.out_hi;
auto has_output_range_round = !(params.output.GetDType() == Datatype::INT8 || params.output.GetDType() == Datatype::UINT8);

jit.AddConstant(MakeJitConstant("HAS_POST_SCALE", params.has_post_scale));
jit.AddConstant(MakeJitConstant("HAS_POST_SHIFT", params.has_post_shift));
jit.AddConstant(MakeJitConstant("HAS_PRE_SHIFT", params.has_pre_shift));
jit.AddConstant(MakeJitConstant("HAS_CLAMP", params.has_clamp));
jit.AddConstant(MakeJitConstant("HAS_MIN_CLAMP", params.has_min_clamp));
jit.AddConstant(MakeJitConstant("HAS_MAX_CLAMP", params.has_max_clamp));
jit.AddConstant(MakeJitConstant("PER_TENSOR_INPUT_RANGE", params.per_tensor_input_range));
jit.AddConstant(MakeJitConstant("PER_TENSOR_OUTPUT_RANGE", params.per_tensor_output_range));
jit.AddConstant(MakeJitConstant("PER_TENSOR_INPUT_SCALE", params.per_tensor_input_scale));
jit.AddConstant(MakeJitConstant("PER_TENSOR_INPUT_SHIFT", params.per_tensor_input_shift));
jit.AddConstant(MakeJitConstant("PER_TENSOR_OUTPUT_SCALE", params.per_tensor_output_scale));
jit.AddConstant(MakeJitConstant("PER_TENSOR_OUTPUT_SHIFT", params.per_tensor_output_shift));
jit.AddConstant(MakeJitConstant("IN_LO_VAL", params.in_lo));
jit.AddConstant(MakeJitConstant("IN_HI_VAL", params.in_hi));
jit.AddConstant(MakeJitConstant("OUT_LO_VAL", params.out_lo));
jit.AddConstant(MakeJitConstant("OUT_HI_VAL", params.out_hi));
jit.AddConstant(MakeJitConstant("IN_SCALE_VAL", params.in_scale));
jit.AddConstant(MakeJitConstant("IN_SHIFT_VAL", params.in_shift));
jit.AddConstant(MakeJitConstant("OUT_SCALE_VAL", params.out_scale));
jit.AddConstant(MakeJitConstant("OUT_SHIFT_VAL", params.out_shift));
jit.AddConstant(MakeJitConstant("CAN_USE_OUTPUT_RANGE", can_use_output_range));
jit.AddConstant(MakeJitConstant("HAS_OUTPUT_RANGE_ROUND", has_output_range_round));

return jit;
}
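
The two flags added here drive the kernel restructuring in the .cl file below. A minimal host-side sketch of how they are derived and, roughly, the defines they become in the generated OpenCL source; the struct is a simplified stand-in, not the real quantize_params:

#include <cstdio>

struct FakeQuantizeParams {                  // hypothetical stand-in for quantize_params
    bool per_tensor_output_range = true;
    float out_lo = -128.f, out_hi = 127.f;
    bool output_is_int8_or_uint8 = true;     // Datatype::INT8 / UINT8 in the real code
};

int main() {
    FakeQuantizeParams p;
    // CAN_USE_OUTPUT_RANGE: clamping can be done directly on the output range when the
    // range is per-tensor and non-degenerate (out_lo < out_hi).
    const bool can_use_output_range = p.per_tensor_output_range && p.out_lo < p.out_hi;
    // HAS_OUTPUT_RANGE_ROUND: an explicit round() is only needed when the output type
    // does not already round on conversion, i.e. anything other than INT8/UINT8.
    const bool has_output_range_round = !p.output_is_int8_or_uint8;
    std::printf("#define CAN_USE_OUTPUT_RANGE %d\n", can_use_output_range ? 1 : 0);
    std::printf("#define HAS_OUTPUT_RANGE_ROUND %d\n", has_output_range_round ? 1 : 0);
}
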
quantize_gpu_scale_shift_opt.cl

@@ -5,6 +5,9 @@
#include "include/batch_headers/data_types.cl"
#include "include/batch_headers/fetch_data.cl"

#define TO_OUTPUT_TYPE CAT(convert_, OUTPUT_TYPE)
#define TO_OUTPUT_TYPE_SAT_RTE CAT(TO_OUTPUT_TYPE, _sat_rte)

#ifdef SUB_GROUP_SIZE
__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE)))
#endif
@@ -22,63 +25,77 @@ KERNEL(quantize_gpu_scale_shift_opt)(const __global INPUT0_TYPE* input,
{
const int b = get_global_id(GWS_BATCH);
const int of = get_global_id(GWS_FEATURE);

#if OUTPUT_DIMS <= 4
const int yx = get_global_id(GWS_YX);

const int x = yx % OUTPUT_SIZE_X;
const int y = yx / OUTPUT_SIZE_X;
const int z = 0;

const int output_offset = OUTPUT_GET_INDEX(b, of, y, x);
#elif OUTPUT_DIMS == 5
const int zyx = get_global_id(GWS_YX);
const int zyx_div_x = zyx / OUTPUT_SIZE_X;

const int x = zyx % OUTPUT_SIZE_X;
const int y = (zyx / OUTPUT_SIZE_X) % OUTPUT_SIZE_Y;
const int z = (zyx / OUTPUT_SIZE_X) / OUTPUT_SIZE_Y;
const int y = zyx_div_x % OUTPUT_SIZE_Y;
const int z = zyx_div_x / OUTPUT_SIZE_Y;

const int output_offset = OUTPUT_GET_INDEX(b, of, z, y, x);
#elif OUTPUT_DIMS == 6
const int wzyx = get_global_id(GWS_YX);
const int wzyx_div_x = wzyx / OUTPUT_SIZE_X;
const int wzyx_div_xy = wzyx_div_x / OUTPUT_SIZE_Y;

const int x = wzyx % OUTPUT_SIZE_X;
const int y = (wzyx / OUTPUT_SIZE_X) % OUTPUT_SIZE_Y;
const int z = ((wzyx / OUTPUT_SIZE_X) / OUTPUT_SIZE_Y) % OUTPUT_SIZE_Z;
const int w = ((wzyx / OUTPUT_SIZE_X) / OUTPUT_SIZE_Y) / OUTPUT_SIZE_Z;
const int y = wzyx_div_x % OUTPUT_SIZE_Y;
const int z = wzyx_div_xy % OUTPUT_SIZE_Z;
const int w = wzyx_div_xy / OUTPUT_SIZE_Z;

const int output_offset = OUTPUT_GET_INDEX(b, of, w, z, y, x);
#else
# error quantize_gpu_scale_shift_opt.cl: output tensors with more than 6 dimensions are unsupported
#endif

#if INPUT0_DIMS == 6
const int input_offset = INPUT0_GET_INDEX(b, of, w, z, y, x);
#if INPUT0_DIMS <= 4
const int input_offset = INPUT0_GET_INDEX(b, of, y, x);
#elif INPUT0_DIMS == 5
const int input_offset = INPUT0_GET_INDEX(b, of, z, y, x);
#elif INPUT0_DIMS <= 4
const int input_offset = INPUT0_GET_INDEX(b, of, y, x);
#endif

#if OUTPUT_DIMS == 6
const int output_offset = OUTPUT_GET_INDEX(b, of, w, z, y, x);
#elif OUTPUT_DIMS == 5
const int output_offset = OUTPUT_GET_INDEX(b, of, z, y, x);
#elif OUTPUT_DIMS <= 4
const int output_offset = OUTPUT_GET_INDEX(b, of, y, x);
#elif INPUT0_DIMS == 6
const int input_offset = INPUT0_GET_INDEX(b, of, w, z, y, x);
#else
# error quantize_gpu_scale_shift_opt.cl: input tensors with more than 6 dimensions are unsupported
#endif
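
The rewritten coordinate math above is a straightforward strength reduction: the quotient by OUTPUT_SIZE_X (and then by OUTPUT_SIZE_Y) is computed once and reused rather than recomputed for every coordinate. A small stand-alone check, in plain C++ with assumed extents rather than kernel code, that both forms decompose a flattened work-item id identically:

#include <cassert>

int main() {
    // Assumed example extents for the check.
    const int SIZE_X = 5, SIZE_Y = 4, SIZE_Z = 3, SIZE_W = 2;

    for (int wzyx = 0; wzyx < SIZE_X * SIZE_Y * SIZE_Z * SIZE_W; ++wzyx) {
        // Original form: repeat the divisions for every coordinate.
        const int x0 = wzyx % SIZE_X;
        const int y0 = (wzyx / SIZE_X) % SIZE_Y;
        const int z0 = ((wzyx / SIZE_X) / SIZE_Y) % SIZE_Z;
        const int w0 = ((wzyx / SIZE_X) / SIZE_Y) / SIZE_Z;

        // Patched form: compute the quotients once and reuse them.
        const int wzyx_div_x  = wzyx / SIZE_X;
        const int wzyx_div_xy = wzyx_div_x / SIZE_Y;
        const int x1 = wzyx % SIZE_X;
        const int y1 = wzyx_div_x % SIZE_Y;
        const int z1 = wzyx_div_xy % SIZE_Z;
        const int w1 = wzyx_div_xy / SIZE_Z;

        assert(x0 == x1 && y0 == y1 && z0 == z1 && w0 == w1);
    }
    return 0;
}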

#if HAS_CLAMP && !PER_TENSOR_INPUT_RANGE
#if HAS_CLAMP && !PER_TENSOR_INPUT_RANGE && !CAN_USE_OUTPUT_RANGE
#if INPUT1_DIMS == 4
const int in_range_offset = INPUT1_GET_INDEX_SAFE(b, of, y, x);
#elif INPUT1_DIMS == 5
const int in_range_offset = INPUT1_GET_INDEX_SAFE(b, of, z, y, x);
#elif INPUT1_DIMS == 6
const int in_range_offset = INPUT1_GET_INDEX_SAFE(b, of, w, z, y, x);
#else
# error quantize_gpu_scale_shift_opt.cl: unsupported INPUT1_DIMS size
#endif
#endif
#endif // HAS_CLAMP && !PER_TENSOR_INPUT_RANGE && !CAN_USE_OUTPUT_RANGE

#if INPUT7_DIMS == 4
const int scales_offset = INPUT7_GET_INDEX_SAFE(b, of, y, x);
#elif INPUT7_DIMS == 5
const int scales_offset = INPUT7_GET_INDEX_SAFE(b, of, z, y, x);
#elif INPUT7_DIMS == 6
const int scales_offset = INPUT7_GET_INDEX_SAFE(b, of, w, z, y, x);
#else
# error quantize_gpu_scale_shift_opt.cl: unsupported INPUT7_DIMS size
#endif

#if PER_TENSOR_INPUT_SCALE
INPUT1_TYPE input_scale_val = IN_SCALE_VAL;
#else
INPUT1_TYPE input_scale_val = input_scale[scales_offset];
#endif

#if PER_TENSOR_INPUT_SHIFT
INPUT1_TYPE input_shift_val = IN_SHIFT_VAL;
#else
@@ -97,38 +114,96 @@ KERNEL(quantize_gpu_scale_shift_opt)(const __global INPUT0_TYPE* input,
INPUT1_TYPE output_shift_val = output_shift[scales_offset];
#endif

#if PER_TENSOR_INPUT_RANGE && HAS_CLAMP
#if HAS_CLAMP
#if CAN_USE_OUTPUT_RANGE
INPUT1_TYPE output_low_val = OUT_LO_VAL;
INPUT1_TYPE output_high_val = OUT_HI_VAL;
#else
#if PER_TENSOR_INPUT_RANGE
INPUT1_TYPE input_low_val = IN_LO_VAL;
INPUT1_TYPE input_high_val = IN_HI_VAL;
#elif HAS_CLAMP
#else
INPUT1_TYPE input_low_val = input_low[in_range_offset];
INPUT1_TYPE input_high_val = input_high[in_range_offset];
#endif // PER_TENSOR_INPUT_RANGE
#endif // CAN_USE_OUTPUT_RANGE
#endif // HAS_CLAMP

// ************************************************************* //
// Calculations for optimized branch with the output range usage //
// ************************************************************* //

#if CAN_USE_OUTPUT_RANGE

#if HAS_PRE_SHIFT
INPUT1_TYPE val = TO_INPUT1_TYPE(input[input_offset]) * input_scale_val + input_shift_val;
#else
INPUT1_TYPE val = TO_INPUT1_TYPE(input[input_offset]) * input_scale_val;
#endif

#if HAS_OUTPUT_RANGE_ROUND
val = round(val);
#endif

#if HAS_POST_SCALE
val *= output_scale_val;
#endif

#if HAS_POST_SHIFT
val += output_shift_val;
#endif

#if HAS_CLAMP
INPUT1_TYPE val = min(max(TO_INPUT1_TYPE(input[input_offset]), input_low_val), input_high_val);
#if HAS_MIN_CLAMP && HAS_MAX_CLAMP
val = clamp(val, output_low_val, output_high_val);
#elif HAS_MIN_CLAMP
val = max(val, output_low_val);
#else // HAS_MAX_CLAMP
val = min(val, output_high_val);
#endif
#endif // HAS_CLAMP

// ************************************************************** //
// Calculations for alternative branch with the input range usage //
// ************************************************************** //

#else // CAN_USE_OUTPUT_RANGE

#if HAS_CLAMP
INPUT1_TYPE val = clamp(TO_INPUT1_TYPE(input[input_offset]), input_low_val, input_high_val);
#else
INPUT1_TYPE val = TO_INPUT1_TYPE(input[input_offset]);
#endif

#if HAS_PRE_SHIFT
val = round(val * input_scale_val + input_shift_val);
#else
val = round(val * input_scale_val);
#endif

#if HAS_POST_SCALE
val = val*output_scale_val;
val *= output_scale_val;
#endif

#if HAS_POST_SHIFT
val += output_shift_val;
#endif

#if OUTPUT_LAYOUT_B_FS_YX_FSV16
#endif // CAN_USE_OUTPUT_RANGE

// *********************************** //
// Common section with results writing //
// *********************************** //

#if FEATURE_BLOCKED_FORMAT
if (of < OUTPUT_FEATURE_NUM)
#endif
#if OUTPUT_IS_FP
output[output_offset] = TO_OUTPUT_TYPE_SAT(val);
output[output_offset] = TO_OUTPUT_TYPE_SAT(val);
#else
output[output_offset] = TO_OUTPUT_TYPE_SAT(round(val));
output[output_offset] = TO_OUTPUT_TYPE_SAT_RTE(val);
#endif
}

#undef TO_OUTPUT_TYPE
#undef TO_OUTPUT_TYPE_SAT_RTE
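
Taken together, the kernel now has two computation paths selected by CAN_USE_OUTPUT_RANGE, and integer outputs are stored through convert_<type>_sat_rte instead of an explicit round(). A scalar C++ sketch of the two paths, illustrative only: float stands in for INPUT1_TYPE, to_int8_sat_rte() is a hand-written approximation of OpenCL's convert_char_sat_rte(), and the values in main() are made up.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hand-written stand-in for convert_char_sat_rte(): round half to even, then saturate.
static int8_t to_int8_sat_rte(float v) {
    const float r = std::nearbyint(v);      // default FP environment rounds to nearest even
    return static_cast<int8_t>(std::clamp(r, -128.f, 127.f));
}

// Optimized branch (CAN_USE_OUTPUT_RANGE): scale/shift first, clamp on the output range.
static int8_t quantize_output_range(float in, float in_scale, float in_shift,
                                    float out_scale, float out_shift,
                                    float out_lo, float out_hi) {
    float val = in * in_scale + in_shift;   // HAS_PRE_SHIFT
    // HAS_OUTPUT_RANGE_ROUND is false for INT8/UINT8 outputs, so no round() here.
    val = val * out_scale + out_shift;      // HAS_POST_SCALE / HAS_POST_SHIFT
    val = std::clamp(val, out_lo, out_hi);  // HAS_CLAMP with both HAS_MIN_CLAMP and HAS_MAX_CLAMP
    return to_int8_sat_rte(val);
}

// Fallback branch: clamp on the input range, round, then apply the output scale/shift.
static int8_t quantize_input_range(float in, float in_lo, float in_hi,
                                   float in_scale, float in_shift,
                                   float out_scale, float out_shift) {
    float val = std::clamp(in, in_lo, in_hi); // HAS_CLAMP
    val = std::round(val * in_scale + in_shift);
    val = val * out_scale + out_shift;
    return to_int8_sat_rte(val);
}

int main() {
    // Made-up parameters, just to exercise both paths.
    (void)quantize_output_range(3.7f, 2.f, 0.5f, 1.f, 0.f, -128.f, 127.f);
    (void)quantize_input_range(3.7f, 0.f, 10.f, 2.f, 0.5f, 1.f, 0.f);
    return 0;
}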
