Improve operator-level qb4w test coverage
GregoryComer committed Jul 25, 2024
1 parent dbac1ea commit d3d40e1
Showing 3 changed files with 558 additions and 98 deletions.
292 changes: 292 additions & 0 deletions src/operators/fully-connected-nc.c
@@ -951,6 +951,298 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w(
/*weights_cache=*/weights_cache, fully_connected_op_out);
}

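// Creates a fully-connected operator with dynamically quantized 8-bit inputs
// (qd8), fp32 outputs, and blockwise-quantized 4-bit weights (qb4w): each
// group of block_size weights along the input-channel dimension shares one
// fp32 scale from kernel_scale.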
enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w(
size_t input_channels,
size_t output_channels,
size_t input_stride,
size_t output_stride,
size_t block_size,
uint8_t kernel_zero_point,
const float* kernel_scale,
const void* kernel,
const float* bias,
float output_min,
float output_max,
uint32_t flags,
xnn_code_cache_t code_cache,
xnn_weights_cache_t weights_cache,
xnn_operator_t* fully_connected_op_out)
{
if (isnan(output_min)) {
xnn_log_error(
"failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w));
return xnn_status_invalid_parameter;
}

if (isnan(output_max)) {
xnn_log_error(
"failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w));
return xnn_status_invalid_parameter;
}

if (output_min > output_max) {
xnn_log_error(
"failed to create %s operator with [%.7g, %.7g] output range: lower bound must be less than or equal to upper bound",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), output_min, output_max);
return xnn_status_invalid_parameter;
}

// TODO: Better way to do this?
size_t max_kr = 8; // MAX_KR
if (round_up_po2(input_channels, max_kr) % block_size != 0) {
xnn_log_error(
"failed to create %s operator with input_channels: %zu, and block_size: %zu: expecting input_channels, rounded up to a multiple of %zu, to be divisible by block_size.",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), input_channels, block_size, max_kr);
return xnn_status_invalid_parameter;
}

if (block_size < max_kr || block_size % max_kr != 0) {
xnn_log_error(
"failed to create %s operator with block_size: %zu: expecting block_size to be at least %zu and a multiple of it.",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), block_size, max_kr);
return xnn_status_invalid_parameter;
}
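
// Illustration of the two checks above, with max_kr == 8: for
// input_channels = 72, round_up_po2(72, 8) == 72, so block_size = 8, 24, or
// 72 passes both checks; block_size = 16 fails the divisibility check
// (72 % 16 != 0), and block_size = 4 fails the multiple-of-max_kr check.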

// Assuming kernel_scale.size() is output_channels * num_blocks
size_t num_blocks = round_up_po2(input_channels, max_kr) / block_size;
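// e.g. input_channels = 72 and block_size = 24 give num_blocks = 3, so
// kernel_scale is expected to hold output_channels * 3 floats.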
for (size_t output_channel = 0; output_channel < output_channels; output_channel++) {
for (size_t block_index = 0; block_index < num_blocks; block_index++) {
size_t scale_index = output_channel * num_blocks + block_index;
if (kernel_scale[scale_index] <= 0.0f || !isnormal(kernel_scale[scale_index])) {
xnn_log_error(
"failed to create %s operator with %.7g kernel scale in output channel #%zu, block #%zu: scale must be finite and positive",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w),
kernel_scale[scale_index], output_channel, block_index);
return xnn_status_invalid_parameter;
}
}
}

const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qb4w_gemm_config();
if (gemm_config == NULL) {
xnn_log_error("failed to create %s operator: unsupported hardware configuration",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w));
return xnn_status_unsupported_hardware;
}

const struct gemm_fused_ukernels* gemm_ukernels = &gemm_config->minmax;
const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
if (linear_activation && gemm_config->linear.gemm[gemm_config->mr-1].function[XNN_UARCH_DEFAULT] != NULL) {
gemm_ukernels = &gemm_config->linear;
}
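// An output range of (-INFINITY, INFINITY) means no clamping is needed, so
// prefer the linear (non-minmax) ukernel variant when one is available.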

union xnn_f32_qb4w_minmax_params params;
if XNN_LIKELY(gemm_config->init.f32_qb4w != NULL) {
gemm_config->init.f32_qb4w(&params, output_min, output_max, kernel_zero_point, block_size);
}

// We don't know the input zero point until runtime; the row sum is multiplied by it during packing, so set it to 1.
const struct xnn_qs8_qc4w_packing_params packing_params = { /*input_zero_point=*/1, kernel_zero_point };

xnn_operator_t fully_connected_op = NULL;
enum xnn_status status = xnn_status_uninitialized;

if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w));
goto error;
}

status = xnn_status_invalid_parameter;

if (input_channels == 0) {
xnn_log_error(
"failed to create %s operator with %zu input channels: number of channels must be non-zero",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), input_channels);
goto error;
}

if (output_channels == 0) {
xnn_log_error(
"failed to create %s operator with %zu output channels: number of channels must be non-zero",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), output_channels);
goto error;
}

if (input_stride < input_channels) {
xnn_log_error(
"failed to create %s operator with input element stride of %zu: "
"stride must be at least as large as the number of input channels (%zu)",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), input_stride, input_channels);
goto error;
}

if (output_stride < output_channels) {
xnn_log_error(
"failed to create %s operator with output element stride of %zu: "
"stride must be at least as large as the number of output channels (%zu)",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), output_stride, output_channels);
goto error;
}

status = xnn_status_out_of_memory;

fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
if (fully_connected_op == NULL) {
xnn_log_error(
"failed to allocate %zu bytes for %s operator descriptor",
sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w));
goto error;
}

fully_connected_op->weights_cache = weights_cache;
fully_connected_op->code_cache = code_cache;

const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const uint32_t planes = gemm_config->planes;

const size_t n_stride = round_up(output_channels, nr);

size_t k_stride = round_up_po2(input_channels, kr * sr);

const bool filter_is_nibble = true;
if (filter_is_nibble) {
input_channels = round_up_po2(input_channels, planes);

if (planes < 1 || planes > 2) {
xnn_log_error(
"planes is %u but expected to be 1 or 2 for 4 bit", planes);
goto error;
}
k_stride = round_up_po2(input_channels, kr * sr * planes);

// If the filter is 4-bit, halve k_stride (since we will scale k_stride by log2_filter_element_size, and we pass 0 for qc4).
k_stride = round_up_po2(k_stride, 2) >> 1;
}
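
// Worked example, assuming kr = 4, sr = 1, planes = 2 (illustrative values):
// input_channels = 72 is already a multiple of planes, so k_stride =
// round_up_po2(72, 4 * 1 * 2) = 72 nibbles, then 72 / 2 = 36 bytes.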

// Blockwise scales: num_blocks floats per output channel.
size_t block_scale_bytes = 0;
if (block_size != 0) {
block_scale_bytes += num_blocks * sizeof(float);
}

const size_t packed_weights_size =
n_stride * (sizeof(float) + k_stride + sizeof(float) + block_scale_bytes);
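// Per output channel this reserves: one float for the row sum (ksum),
// k_stride weight bytes, num_blocks scale floats, and one float for bias;
// n_stride rounds output_channels up to the kernel's nr tile.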

size_t aligned_total_weights_size = round_up_po2(packed_weights_size, XNN_ALLOCATION_ALIGNMENT);

uint32_t cache_seed = output_channels ^ input_channels ^ nr ^ kr ^ sr ^ sizeof(float) ^ xnn_operator_type_fully_connected_nc_qd8_f32_qb4w;
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
cache_seed = ~cache_seed;
}
size_t cache_offset = XNN_CACHE_NOT_FOUND;
struct xnn_weights_cache_look_up_key cache_key;
cache_key.seed = cache_seed;
cache_key.kernel = kernel;
cache_key.bias = bias;
if (use_weights_cache(fully_connected_op)) {
cache_offset = xnn_weights_cache_look_up(
fully_connected_op->weights_cache, &cache_key);
}

if (cache_offset == XNN_CACHE_NOT_FOUND) {
void* weights_ptr = xnn_get_pointer_to_write_weights(
fully_connected_op, aligned_total_weights_size, /* packed_weights_padding_byte = */ 0);
if (weights_ptr == NULL) {
xnn_log_error(
"failed to allocate %zu bytes for %s operator packed weights",
packed_weights_size, xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w));
goto error;
}
xnn_log_debug("allocated %zu bytes for packed weights in %s operator",
aligned_total_weights_size, xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w));

// Pack weights
gemm_config->pack_gemm_goi_bl(
/*groups=*/1, output_channels, input_channels,
nr, kr, sr,
block_size,
kernel, /*bias=*/NULL, /*scale=*/kernel_scale,
weights_ptr,
/* extra_bytes_bl */ gemm_config->nr * sizeof(float), /* TODO: get this from the op */
/* extra_bytes_n */ gemm_config->nr * sizeof(float),
&packing_params);

// Fill in the blockwise kernel scales.
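// The offset below skips, per group of nr output channels, the nr ksum
// floats plus the first block of packed nibbles (block_size / 2 bytes per
// channel), landing on the first scale slot of the interleaved block layout
// (see block_stride below).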
void* weights = (void*) ((uintptr_t) weights_ptr +
gemm_config->nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2)));

const size_t block_stride = /* weights */ block_size / 2 + sizeof(float);
const size_t weights_stride = /* weights */ k_stride * sizeof(int8_t) +
/* scales= */ num_blocks * sizeof(float) +
/* ksum= */ sizeof(float) +
/* bias= */ sizeof(float);
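// e.g. with k_stride = 36 and num_blocks = 3 (values from the example
// above): weights_stride = 36 + 3 * 4 + 4 + 4 = 56 bytes per output channel.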

xnn_init_blockwise_scale_fp32_params(
output_channels, gemm_config->nr, gemm_config->nr,
gemm_config->nr * weights_stride,
gemm_config->nr * weights_stride,
/* num_blocks=*/ num_blocks,
/* block_stride=*/ gemm_config->nr * block_stride,
0,
kernel_scale, weights);

// Fill in bias
if (bias != NULL) {
weights = (void*) ((uintptr_t) weights_ptr + gemm_config->nr * (weights_stride - sizeof(float)));
xnn_init_qs8_qc8w_scale_fp32_params(
output_channels, gemm_config->nr, gemm_config->nr,
gemm_config->nr * weights_stride, gemm_config->nr * weights_stride, 0,
bias, weights);
}

if (use_weights_cache(fully_connected_op)) {
fully_connected_op->packed_weights.offset = xnn_look_up_or_insert_weights_cache(
fully_connected_op->weights_cache, &cache_key, weights_ptr, aligned_total_weights_size);
}
} else {
fully_connected_op->packed_weights.offset = cache_offset;
}

fully_connected_op->group_input_channels = input_channels;
fully_connected_op->group_output_channels = output_channels;
fully_connected_op->input_pixel_stride = input_stride;
fully_connected_op->output_pixel_stride = output_stride;
fully_connected_op->k_block_size = block_size;

memcpy(&fully_connected_op->params, &params, sizeof(params));
fully_connected_op->type = xnn_operator_type_fully_connected_nc_qd8_f32_qb4w;
fully_connected_op->flags = flags;

const size_t mr = gemm_config->mr;
fully_connected_op->ukernel.type = xnn_microkernel_type_gemm;
fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
.mr = mr,
.nr = nr,
.kr = kr,
.sr = sr,
.kp = planes,
};
assert(XNN_MAX_MR >= mr);
for (size_t i = 0; i < mr; i++) {
fully_connected_op->ukernel.gemm.gemm_cases[i] = gemm_ukernels->gemm[i];
}

#if XNN_PLATFORM_JIT
xnn_generate_gemms_up_to_max_mr(
mr, gemm_config->generator, /*jit_gemm_params=*/NULL, output_channels, nr,
input_channels << sizeof(int8_t), fully_connected_op);
#endif // XNN_PLATFORM_JIT

fully_connected_op->state = xnn_run_state_invalid;

*fully_connected_op_out = fully_connected_op;
return xnn_status_success;

error:
xnn_delete_operator(fully_connected_op);
return status;
}
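
// A minimal creation sketch (illustrative shapes; `scales`, `weights4`, and
// `bias` are hypothetical caller-owned buffers sized per the checks above):
//
//   xnn_operator_t fc = NULL;
//   enum xnn_status status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
//       /*input_channels=*/72, /*output_channels=*/16,
//       /*input_stride=*/72, /*output_stride=*/16,
//       /*block_size=*/24, /*kernel_zero_point=*/8,
//       scales, weights4, bias,
//       /*output_min=*/-INFINITY, /*output_max=*/INFINITY,
//       /*flags=*/0, /*code_cache=*/NULL, /*weights_cache=*/NULL, &fc);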

enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w(
size_t input_channels,
size_t output_channels,