Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

e2e matmul test improvements #19016

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 33 additions & 39 deletions tools/testing/e2e/iree-e2e-matmul-test.cc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive by: I'd still really like to delete all the in-tree code for this test suite and migrate to https://github.com/iree-org/iree-test-suites/tree/main/linalg_ops. Different sets of improvements have been made to both places at this point.

The tests hit a complexity level that I'm not comfortable supporting in-tree, and we should continue transitioning to package-based testing (decoupled from the core project CMake build, runnable from dev/nightly/stable package builds on a wide range of systems) if possible.

Copy link
Contributor Author

@bjacob bjacob Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ScottTodd Let's chat next week. I really want to help with the test changes you're making and I would like to offer directly taking care of tihs to take this off your plate. But there's something important here -- these tests aren't just "regression" tests that would matter for QA but maybe not so much for day to day development.

These tests are are the center of day to day development, so changes to these tests need to be made as frequently as we add codegen support for a new compilation path, such as when we implement GPU data tiling on another GPU target or supporting another data type.

I think you caught this particular PR because it touches the files under tools/testing/ while the more frequent case with my day-to-day PRs is that they only touch the code under tests/e2e/matmul/. The latter is really changing at a high frequency (see git history). The former, a bit less high frequency, but still, just today I noticed I needed to make changes there for f64, #19093, for instance (which I need to test F64 MFMA on CDNA3, which I need for completeness because if the hw supports f64, it's easier to just support it than to refrain from it even if no immediate need).

That doesn't preclude moving these tests out of tree but that does mean that:

  1. The out-of-tree move needs to be atomic --- if it means I can't land test changes concurrently, that's literally blocking my work. If other test changes/improvements are planned together with the move, it's really important to keep them separate from the move itself, so the atomic op here can have low latency.
  2. The out-of-tree move will incur a permanent penalty on my development velocity. I think I'm OK with that because I really care about test infrastructure health and trust your judgement here, but that's part of why I'd like to be involved (the other part being that I really want to help with that!).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still thinking through some of the design questions here, I don't have a final state in mind that I'm totally happy with. Here are some responses to your individual points.


These tests are are the center of day to day development, so changes to these tests need to be made as frequently as we add codegen support for a new compilation path, such as when we implement GPU data tiling on another GPU target or supporting another data type.

It's my (perhaps naive) hope that we could somewhat exhaustively enumerate the data types we could support in the test suite, independent of what we currently do support on any given target. I'd like to get out of the habit of only testing what we already know works. For that I want to support XFAIL tests, but I had trouble building that into the in-tree tests. To be fair... the out of tree tests don't yet have XFAIL either, and I'm not sure if CMake/ctest is the right framework for that level of configuration. I've been much happier with pytest for handling test filtering and expected outcomes.

As for things like feature flag flips and choosing compilation pipelines, I like that putting an air gap between the compiler+runtime and the tests keeps us honest about default compiler behavior and user-friendly optimization settings. The air gap / github repository boundary does introduce friction for daily development though. I can think of a few ways we can mitigate that friction, but any of them would require some changes to developer workflows:

  • Include iree-test-suites as a subproject using CMake FetchContent (not a git submodule, so random users cloning the repo don't pay the cost for test suites)
  • Add scripts that go through common steps like checking out the test suite repo, building the tools, running tests, getting test results, collecting benchmarks, etc.
    • Having the tests in their own project gives us more freedom here to deviate from IREE's core CMake setup

RE: atomic changes, I started https://github.com/iree-org/iree-test-suites/tree/main/linalg_ops as a relatively simple fork at first and then started making deeper changes to take advantage of the new decoupling:

Leaving things half migrated is unfortunate, but my time was pulled elsewhere. @erman-gurses has also been helping get convolution and attention tests moved into the test suite repo recently. If we come to a decision we are all happy with then we can schedule some time on the roadmap to finish the migration work.

Copy link
Contributor Author

@bjacob bjacob Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the details! Let's find time this week for a video call. For now just recording a couple data points. One is:

~/iree-build$ find . -name '*e2e_matmul*mlir' -printf "%s\n" | awk '{sum+=$1} END {print sum}'
26158552

Another data point: pace of introduction of new element types in CDNA architectures:

  1. CDNA1: f32, f16, i8
  2. CDNA2: f64, bf16
  3. CDNA3: xf32, f8E5M2FNUZ, f8E4M3FNUZ
  4. Near future: f8E5M2 (non-FNUZ), f8E4M3 (non-FNUZ), FP6, FP4, INT6, INT4, presumably all what's being standardized in Microscaling 1.0.
  5. Meanwhile, NVIDIA architectures support some of these types already in production, as well as altogether different types such as TF32.

Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,9 @@ static void matmul_results_deinitialize(matmul_results_t* results) {
}

// Returns the largest number of characters to print any matrix element.
static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
iree_hal_dim_t row_start, iree_hal_dim_t row_end,
iree_hal_dim_t cols, iree_hal_dim_t col_start,
iree_hal_dim_t col_end,
static int get_max_elem_width(iree_hal_dim_t rows, iree_hal_dim_t row_start,
iree_hal_dim_t row_end, iree_hal_dim_t cols,
iree_hal_dim_t col_start, iree_hal_dim_t col_end,
iree_hal_element_type_t element_type,
const uint8_t* matrix) {
int max_elem_width = 0;
Expand All @@ -428,15 +427,14 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
// NOTE: iree_max is a macro and may evaluate its args twice.
char buf[64];
int this_elem_width =
iree_test_utils_snprintf_value(buf, sizeof(buf), elem, precision);
iree_test_utils_snprintf_value(buf, sizeof(buf), elem);
max_elem_width = iree_max(max_elem_width, this_elem_width);
}
}
return max_elem_width;
}

// Prints |matrix| to |file|, with |label| as caption.
// |precision| controls how many decimals are printed for float values.
//
// If |other_matrix| is not NULL, then any matrix entries that disagree
// between |matrix| and |other_matrix| (according to
Expand All @@ -453,22 +451,21 @@ static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows,
// characters. According to
// https://www.unicode.org/reports/tr11/#Recommendations, a single emoji
// character should meet that requirement.
static void print_matrix(FILE* file, const char* label, precision_t precision,
iree_hal_dim_t rows, iree_hal_dim_t row_start,
iree_hal_dim_t row_end, iree_hal_dim_t cols,
iree_hal_dim_t col_start, iree_hal_dim_t col_end,
static void print_matrix(FILE* file, const char* label, iree_hal_dim_t rows,
iree_hal_dim_t row_start, iree_hal_dim_t row_end,
iree_hal_dim_t cols, iree_hal_dim_t col_start,
iree_hal_dim_t col_end,
iree_hal_element_type_t element_type,
const uint8_t* matrix, const uint8_t* other_matrix,
const char* highlight) {
IREE_ASSERT((other_matrix == NULL) == (highlight == NULL));
int max_elem_width =
get_max_elem_width(precision, rows, row_start, row_end, cols, col_start,
col_end, element_type, matrix);
int max_elem_width = get_max_elem_width(
rows, row_start, row_end, cols, col_start, col_end, element_type, matrix);
if (other_matrix) {
// NOTE: iree_max is a macro and may evaluate its args twice.
int other_matrix_max_elem_width =
get_max_elem_width(precision, rows, row_start, row_end, cols, col_start,
col_end, element_type, other_matrix);
get_max_elem_width(rows, row_start, row_end, cols, col_start, col_end,
element_type, other_matrix);
max_elem_width = iree_max(max_elem_width, other_matrix_max_elem_width);
}

Expand All @@ -491,7 +488,7 @@ static void print_matrix(FILE* file, const char* label, precision_t precision,
!iree_test_utils_result_elements_agree(element, other_element);
}
char buf[64];
iree_test_utils_snprintf_value(buf, sizeof(buf), element, precision);
iree_test_utils_snprintf_value(buf, sizeof(buf), element);
fprintf(file, "%*s", max_elem_width, buf);
// See comment on |highlight| function parameter for why 2 spaces.
// A 3rd space is added unconditionally to make it clear that a highlight
Expand Down Expand Up @@ -525,13 +522,13 @@ static iree_status_t check_matmul_failure(
char actual_value_buf[32];
char expected_value_buf[32];
iree_test_utils_snprintf_value(actual_value_buf, sizeof(actual_value_buf),
actual_value, PRECISION_HIGH);
actual_value);
iree_test_utils_snprintf_value(expected_value_buf, sizeof(expected_value_buf),
expected_value, PRECISION_HIGH);
expected_value);
fprintf(file, "actual value: %s\n", actual_value_buf);
fprintf(file, "expected value: %s\n", expected_value_buf);

iree_hal_dim_t context = 8;
iree_hal_dim_t context = 16;
const char* context_env = getenv("IREE_MATMUL_TEST_SHOW_CONTEXT");
if (context_env) {
if (1 != sscanf(context_env, "%" PRIdim, &context)) {
Expand All @@ -542,39 +539,36 @@ static iree_status_t check_matmul_failure(
}
}
iree_hal_dim_t m_start =
(iree_hal_dim_t)iree_max(0, (int64_t)row - (int64_t)context);
iree_hal_dim_t m_end = iree_min(results->m, row + context);
(iree_hal_dim_t)iree_max(0, (int64_t)row - (int64_t)context / 2);
iree_hal_dim_t m_end = iree_min(results->m, m_start + context);
iree_hal_dim_t n_start =
(iree_hal_dim_t)iree_max(0, (int64_t)col - (int64_t)context);
iree_hal_dim_t n_end = iree_min(results->n, col + context);
(iree_hal_dim_t)iree_max(0, (int64_t)col - (int64_t)context / 2);
iree_hal_dim_t n_end = iree_min(results->n, n_start + context);
iree_hal_dim_t k_start = 0;
iree_hal_dim_t k_end = iree_min(results->k, 2 * context);
// [k_start, k_end) could be arbitrarily long at this point. Constrain it a
// bit to avoid huge output.
k_end = iree_min(k_end, k_start + 4 * context);
iree_hal_dim_t k_end = iree_min(results->k, context);

fprintf(file, "\n");
print_matrix(file, "left-hand side", PRECISION_LOW, results->m, m_start,
m_end, results->k, k_start, k_end, results->lhs_type,
results->lhs_contents.data, NULL, NULL);
print_matrix(file, "left-hand side", results->m, m_start, m_end, results->k,
k_start, k_end, results->lhs_type, results->lhs_contents.data,
NULL, NULL);
fprintf(file, "\n");
print_matrix(file, "right-hand side", PRECISION_LOW, results->k, k_start,
k_end, results->n, n_start, n_end, results->rhs_type,
results->rhs_contents.data, NULL, NULL);
print_matrix(file, "right-hand side", results->k, k_start, k_end, results->n,
n_start, n_end, results->rhs_type, results->rhs_contents.data,
NULL, NULL);
fprintf(file, "\n");
if (results->acc_contents.data) {
print_matrix(file, "input accumulator", PRECISION_LOW, results->m, m_start,
m_end, results->n, n_start, n_end, results->acc_type,
print_matrix(file, "input accumulator", results->m, m_start, m_end,
results->n, n_start, n_end, results->acc_type,
results->acc_contents.data, NULL, NULL);
fprintf(file, "\n");
}
print_matrix(file, "expected result", PRECISION_LOW, results->m, m_start,
m_end, results->n, n_start, n_end, results->result_type,
print_matrix(file, "expected result", results->m, m_start, m_end, results->n,
n_start, n_end, results->result_type,
results->expected_contents.data, results->actual_contents.data,
iree_test_utils_emoji(true));
fprintf(file, "\n");
print_matrix(file, "actual result", PRECISION_LOW, results->m, m_start, m_end,
results->n, n_start, n_end, results->result_type,
print_matrix(file, "actual result", results->m, m_start, m_end, results->n,
n_start, n_end, results->result_type,
results->actual_contents.data, results->expected_contents.data,
iree_test_utils_emoji(false));
fprintf(file, "\n");
Expand Down
35 changes: 15 additions & 20 deletions tools/testing/e2e/test_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ int32_t iree_test_utils_max_elements_to_check(void) {
return FLAG_max_elements_to_check;
}

const char* iree_test_utils_emoji(bool good) { return good ? "🦄" : "🐞"; }
const char* iree_test_utils_emoji(bool good) { return good ? "🦄" : "🎃"; }

int iree_test_utils_calculate_check_every(iree_hal_dim_t tot_elements,
iree_hal_dim_t no_div_of) {
Expand Down Expand Up @@ -173,9 +173,13 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
return iree_test_utils_value_make_none();
}

// Important: print all floating point values to FULL precision.
// The audience is debugging low-level numerical bugs.
// Since the values used in most tests are small and integral, these will
// normally print just as concisely, while the extra precision requested here
// will only kick in when it's needed, when there is a numerical bug.
int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
iree_test_utils_e2e_value_t value,
precision_t precision) {
iree_test_utils_e2e_value_t value) {
switch (value.type) {
case IREE_TEST_UTILS_VALUE_TYPE_I8:
return snprintf(buf, bufsize, "%" PRIi8, value.i8);
Expand All @@ -186,36 +190,27 @@ int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
case IREE_TEST_UTILS_VALUE_TYPE_I64:
return snprintf(buf, bufsize, "%" PRIi64, value.i64);
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e5m2_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e4m3_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e5m2fnuz_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
return snprintf(buf, bufsize, "%.3g",
iree_math_f8e4m3fnuz_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F16:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
return snprintf(buf, bufsize, "%.5g",
iree_math_f16_to_f32(value.f16_u16));
case IREE_TEST_UTILS_VALUE_TYPE_BF16:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
return snprintf(buf, bufsize, "%.5g",
iree_math_bf16_to_f32(value.bf16_u16));
case IREE_TEST_UTILS_VALUE_TYPE_F32:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.8g" : "%.4g", value.f32);
return snprintf(buf, bufsize, "%.8g", value.f32);
case IREE_TEST_UTILS_VALUE_TYPE_F64:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.16g" : "%.4g",
value.f64);
return snprintf(buf, bufsize, "%.16g", value.f64);
default:
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled value type"));
Expand Down
9 changes: 1 addition & 8 deletions tools/testing/e2e/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ typedef struct iree_test_utils_value_t {
};
} iree_test_utils_e2e_value_t;

// Enum controlling how many decimals to print floats with.
typedef enum iree_test_utils_precision_e {
PRECISION_LOW,
PRECISION_HIGH,
} precision_t;

// Reads an element from a buffer given index.
iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
iree_hal_dim_t index, iree_hal_element_type_t result_type,
Expand All @@ -90,8 +84,7 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
// Prints a iree_e2e_test_value_t to a string buffer. Returns the number of
// characters written. Like snprintf.
int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
iree_test_utils_e2e_value_t value,
precision_t precision);
iree_test_utils_e2e_value_t value);

// Returns true if |expected| and |actual| agree to tolerable accuracy.
bool iree_test_utils_result_elements_agree(iree_test_utils_e2e_value_t expected,
Expand Down
Loading