Skip to content
/ iree Public
forked from iree-org/iree

Commit

Permalink
e2e matmul test improvements (iree-org#19016)
Browse files Browse the repository at this point in the history
Working on iree-org#18980 let me spend
quality time with e2e matmul tests and suggested some changes.

The main change is to simplify the printing of numerical values to
always use high precision, meaning print all significant digits of
floating point values.

Since our tests generate small integral values and the intent is
generally to be testing mostly the exact arithmetic that happens on
small integral values, in most cases this doesn't make any difference.

But I found that RDNA3 float arithmetic produces non-exact results even
on those values. As a result, I got values like 1+epsilon where 1 was
expected, causing a test to fail (since we didn't know we needed to opt
out from requiring exact results) and the test output cryptically
printed both values as "1".

The other change is to more consistently print the same number of rows
and columns regardless of whether we are at the start or in the middle
of a dimension, and to have that number be what we call "context"
(before, it was "2 * context").

Also a seasonal emoji change.

Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob authored Nov 16, 2024
1 parent e3b6cc3 commit e10342d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 67 deletions.
72 changes: 33 additions & 39 deletions tools/testing/e2e/iree-e2e-matmul-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,9 @@ static void matmul_results_deinitialize(matmul_results_t* results) {
}

// Returns the largest number of characters to print any matrix element.
static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
iree_hal_dim_t row_start, iree_hal_dim_t row_end,
iree_hal_dim_t cols, iree_hal_dim_t col_start,
iree_hal_dim_t col_end,
static int get_max_elem_width(iree_hal_dim_t rows, iree_hal_dim_t row_start,
iree_hal_dim_t row_end, iree_hal_dim_t cols,
iree_hal_dim_t col_start, iree_hal_dim_t col_end,
iree_hal_element_type_t element_type,
const uint8_t* matrix) {
int max_elem_width = 0;
Expand All @@ -426,15 +425,14 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
// NOTE: iree_max is a macro and may evaluate its args twice.
char buf[64];
int this_elem_width =
iree_test_utils_snprintf_value(buf, sizeof(buf), elem, precision);
iree_test_utils_snprintf_value(buf, sizeof(buf), elem);
max_elem_width = iree_max(max_elem_width, this_elem_width);
}
}
return max_elem_width;
}

// Prints |matrix| to |file|, with |label| as caption.
// |precision| controls how many decimals are printed for float values.
//
// If |other_matrix| is not NULL, then any matrix entries that disagree
// between |matrix| and |other_matrix| (according to
Expand All @@ -451,22 +449,21 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
// characters. According to
// https://www.unicode.org/reports/tr11/#Recommendations, a single emoji
// character should meet that requirement.
static void print_matrix(FILE* file, const char* label, precision_t precision,
iree_hal_dim_t rows, iree_hal_dim_t row_start,
iree_hal_dim_t row_end, iree_hal_dim_t cols,
iree_hal_dim_t col_start, iree_hal_dim_t col_end,
static void print_matrix(FILE* file, const char* label, iree_hal_dim_t rows,
iree_hal_dim_t row_start, iree_hal_dim_t row_end,
iree_hal_dim_t cols, iree_hal_dim_t col_start,
iree_hal_dim_t col_end,
iree_hal_element_type_t element_type,
const uint8_t* matrix, const uint8_t* other_matrix,
const char* highlight) {
IREE_ASSERT((other_matrix == NULL) == (highlight == NULL));
int max_elem_width =
get_max_elem_width(precision, rows, row_start, row_end, cols, col_start,
col_end, element_type, matrix);
int max_elem_width = get_max_elem_width(
rows, row_start, row_end, cols, col_start, col_end, element_type, matrix);
if (other_matrix) {
// NOTE: iree_max is a macro and may evaluate its args twice.
int other_matrix_max_elem_width =
get_max_elem_width(precision, rows, row_start, row_end, cols, col_start,
col_end, element_type, other_matrix);
get_max_elem_width(rows, row_start, row_end, cols, col_start, col_end,
element_type, other_matrix);
max_elem_width = iree_max(max_elem_width, other_matrix_max_elem_width);
}

Expand All @@ -489,7 +486,7 @@ static void print_matrix(FILE* file, const char* label, precision_t precision,
!iree_test_utils_result_elements_agree(element, other_element);
}
char buf[64];
iree_test_utils_snprintf_value(buf, sizeof(buf), element, precision);
iree_test_utils_snprintf_value(buf, sizeof(buf), element);
fprintf(file, "%*s", max_elem_width, buf);
// See comment on |highlight| function parameter for why 2 spaces.
// A 3rd space is added unconditionally to make it clear that a highlight
Expand Down Expand Up @@ -523,13 +520,13 @@ static iree_status_t check_matmul_failure(
char actual_value_buf[32];
char expected_value_buf[32];
iree_test_utils_snprintf_value(actual_value_buf, sizeof(actual_value_buf),
actual_value, PRECISION_HIGH);
actual_value);
iree_test_utils_snprintf_value(expected_value_buf, sizeof(expected_value_buf),
expected_value, PRECISION_HIGH);
expected_value);
fprintf(file, "actual value: %s\n", actual_value_buf);
fprintf(file, "expected value: %s\n", expected_value_buf);

iree_hal_dim_t context = 8;
iree_hal_dim_t context = 16;
const char* context_env = getenv("IREE_MATMUL_TEST_SHOW_CONTEXT");
if (context_env) {
if (1 != sscanf(context_env, "%" PRIdim, &context)) {
Expand All @@ -540,39 +537,36 @@ static iree_status_t check_matmul_failure(
}
}
iree_hal_dim_t m_start =
(iree_hal_dim_t)iree_max(0, (int64_t)row - (int64_t)context);
iree_hal_dim_t m_end = iree_min(results->m, row + context);
(iree_hal_dim_t)iree_max(0, (int64_t)row - (int64_t)context / 2);
iree_hal_dim_t m_end = iree_min(results->m, m_start + context);
iree_hal_dim_t n_start =
(iree_hal_dim_t)iree_max(0, (int64_t)col - (int64_t)context);
iree_hal_dim_t n_end = iree_min(results->n, col + context);
(iree_hal_dim_t)iree_max(0, (int64_t)col - (int64_t)context / 2);
iree_hal_dim_t n_end = iree_min(results->n, n_start + context);
iree_hal_dim_t k_start = 0;
iree_hal_dim_t k_end = iree_min(results->k, 2 * context);
// [k_start, k_end) could be arbitrarily long at this point. Constrain it a
// bit to avoid huge output.
k_end = iree_min(k_end, k_start + 4 * context);
iree_hal_dim_t k_end = iree_min(results->k, context);

fprintf(file, "\n");
print_matrix(file, "left-hand side", PRECISION_LOW, results->m, m_start,
m_end, results->k, k_start, k_end, results->lhs_type,
results->lhs_contents.data, NULL, NULL);
print_matrix(file, "left-hand side", results->m, m_start, m_end, results->k,
k_start, k_end, results->lhs_type, results->lhs_contents.data,
NULL, NULL);
fprintf(file, "\n");
print_matrix(file, "right-hand side", PRECISION_LOW, results->k, k_start,
k_end, results->n, n_start, n_end, results->rhs_type,
results->rhs_contents.data, NULL, NULL);
print_matrix(file, "right-hand side", results->k, k_start, k_end, results->n,
n_start, n_end, results->rhs_type, results->rhs_contents.data,
NULL, NULL);
fprintf(file, "\n");
if (results->acc_contents.data) {
print_matrix(file, "input accumulator", PRECISION_LOW, results->m, m_start,
m_end, results->n, n_start, n_end, results->acc_type,
print_matrix(file, "input accumulator", results->m, m_start, m_end,
results->n, n_start, n_end, results->acc_type,
results->acc_contents.data, NULL, NULL);
fprintf(file, "\n");
}
print_matrix(file, "expected result", PRECISION_LOW, results->m, m_start,
m_end, results->n, n_start, n_end, results->result_type,
print_matrix(file, "expected result", results->m, m_start, m_end, results->n,
n_start, n_end, results->result_type,
results->expected_contents.data, results->actual_contents.data,
iree_test_utils_emoji(true));
fprintf(file, "\n");
print_matrix(file, "actual result", PRECISION_LOW, results->m, m_start, m_end,
results->n, n_start, n_end, results->result_type,
print_matrix(file, "actual result", results->m, m_start, m_end, results->n,
n_start, n_end, results->result_type,
results->actual_contents.data, results->expected_contents.data,
iree_test_utils_emoji(false));
fprintf(file, "\n");
Expand Down
35 changes: 15 additions & 20 deletions tools/testing/e2e/test_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ int32_t iree_test_utils_max_elements_to_check(void) {
return FLAG_max_elements_to_check;
}

const char* iree_test_utils_emoji(bool good) { return good ? "🦄" : "🐞"; }
const char* iree_test_utils_emoji(bool good) { return good ? "🦄" : "🎃"; }

int iree_test_utils_calculate_check_every(iree_hal_dim_t tot_elements,
iree_hal_dim_t no_div_of) {
Expand Down Expand Up @@ -182,9 +182,13 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
return iree_test_utils_value_make_none();
}

// Important: print all floating point values to FULL precision.
// The audience is debugging low-level numerical bugs.
// Since the values used in most tests are small and integral, these will
// normally print just as concisely, while the extra precision requested here
// will only kick in when it's needed, when there is a numerical bug.
int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
iree_test_utils_e2e_value_t value,
precision_t precision) {
iree_test_utils_e2e_value_t value) {
switch (value.type) {
case IREE_TEST_UTILS_VALUE_TYPE_I8:
return snprintf(buf, bufsize, "%" PRIi8, value.i8);
Expand All @@ -195,36 +199,27 @@ int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
case IREE_TEST_UTILS_VALUE_TYPE_I64:
return snprintf(buf, bufsize, "%" PRIi64, value.i64);
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e5m2_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e4m3_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e5m2fnuz_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e4m3fnuz_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F16:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
return snprintf(buf, bufsize, "%.5g",
iree_math_f16_to_f32(value.f16_u16));
case IREE_TEST_UTILS_VALUE_TYPE_BF16:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
return snprintf(buf, bufsize, "%.5g",
iree_math_bf16_to_f32(value.bf16_u16));
case IREE_TEST_UTILS_VALUE_TYPE_F32:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.8g" : "%.4g", value.f32);
return snprintf(buf, bufsize, "%.8g", value.f32);
case IREE_TEST_UTILS_VALUE_TYPE_F64:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.16g" : "%.4g",
value.f64);
return snprintf(buf, bufsize, "%.16g", value.f64);
default:
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled value type"));
Expand Down
9 changes: 1 addition & 8 deletions tools/testing/e2e/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ typedef struct iree_test_utils_value_t {
};
} iree_test_utils_e2e_value_t;

// Enum controlling how many decimals to print floats with.
typedef enum iree_test_utils_precision_e {
PRECISION_LOW,
PRECISION_HIGH,
} precision_t;

// Reads an element from a buffer given index.
iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
iree_hal_dim_t index, iree_hal_element_type_t result_type,
Expand All @@ -90,8 +84,7 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
// Prints a iree_e2e_test_value_t to a string buffer. Returns the number of
// characters written. Like snprintf.
int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
iree_test_utils_e2e_value_t value,
precision_t precision);
iree_test_utils_e2e_value_t value);

// Returns true if |expected| and |actual| agree to tolerable accuracy.
bool iree_test_utils_result_elements_agree(iree_test_utils_e2e_value_t expected,
Expand Down

0 comments on commit e10342d

Please sign in to comment.