diff --git a/tests/e2e/attention/CMakeLists.txt b/tests/e2e/attention/CMakeLists.txt new file mode 100644 index 000000000000..805b9fbe0240 --- /dev/null +++ b/tests/e2e/attention/CMakeLists.txt @@ -0,0 +1,80 @@ +# TODO: (#17751) Add the arm_64 tests when the bug resolved. See: +# https://github.com/iree-org/iree/actions/runs/10468944505/job/28990909321#step:4:9815 +if(IREE_ARCH STREQUAL "arm_64") + return() +endif() + +iree_generated_e2e_runner_test( + NAME + e2e_attention_cpu_f16_f16_f16_small + TEST_TYPE + attention + GENERATOR + "generate_e2e_attention_tests.py" + GENERATOR_ARGS + "--query_type=f16" + "--key_type=f16" + "--value_type=f16" + "--shapes=small" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-attention-test + TARGET_BACKENDS + "llvm-cpu" + DRIVERS + "local-task" + LABELS + "hostonly" + "local" + TARGET_CPU_FEATURES_VARIANTS + "default" +) + +iree_generated_e2e_runner_test( + NAME + e2e_attention_cpu_f16_f16_f16_medium + TEST_TYPE + attention + GENERATOR + "generate_e2e_attention_tests.py" + GENERATOR_ARGS + "--query_type=f16" + "--key_type=f16" + "--value_type=f16" + "--shapes=medium" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-attention-test + TARGET_BACKENDS + "llvm-cpu" + DRIVERS + "local-task" + LABELS + "hostonly" + "local" + TARGET_CPU_FEATURES_VARIANTS + "default" +) + +iree_generated_e2e_runner_test( + NAME + e2e_attention_cpu_f16_f16_f16_large + TEST_TYPE + attention + GENERATOR + "generate_e2e_attention_tests.py" + GENERATOR_ARGS + "--query_type=f16" + "--key_type=f16" + "--value_type=f16" + "--shapes=large" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-attention-test + TARGET_BACKENDS + "llvm-cpu" + DRIVERS + "local-task" + LABELS + "hostonly" + "local" + TARGET_CPU_FEATURES_VARIANTS + "default" +) diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py new file mode 100644 index 000000000000..8af76f03079d --- /dev/null +++ b/tests/e2e/attention/generate_e2e_attention_tests.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python3 +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Generator for e2e attention tests. +""" + +import argparse +import enum +import dataclasses +import typing +import math + + +# Data type of kernel entries. The string values must match MLIR data types. +@enum.unique +class QueryElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Data type of input entries. The string values must match MLIR data types. +@enum.unique +class KeyElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Data type of input entries. The string values must match MLIR data types. +@enum.unique +class ValueElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Data type of input entries. The string values must match MLIR data types. +@enum.unique +class ResultElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Enumerates of the collections of shapes that we can generate tests for. +# The values are the accepted values for the --shapes= flag. +@enum.unique +class ShapesId(enum.Enum): + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + + +# batch: Batch dimension +# m: M dimension of first and second matmul +# n: N dimension of second matmul +# k1: K dimension of first matmul +# k2: K dimension of second matmul +@dataclasses.dataclass +class TestShapeAndScale: + batch: int + m: int + k1: int + k2: int + n: int + scale: float + + +# Returns the list of TestShape's to use for the collection of shapes +# identified by shapes_id. +def get_test_shapes(shapes_id: ShapesId): + if shapes_id == ShapesId.SMALL: + return [ + TestShapeAndScale(batch=2, m=256, k1=64, k2=32, n=16, scale=1.0), + ] + if shapes_id == ShapesId.MEDIUM: + return [ + TestShapeAndScale(batch=2, m=512, k1=128, k2=64, n=32, scale=1.0), + ] + if shapes_id == ShapesId.LARGE: + return [ + TestShapeAndScale(batch=2, m=1024, k1=256, k2=128, n=64, scale=1.0), + ] + + raise ValueError(shapes_id) + + +# Determines the shape of input and kernel tensors. +@dataclasses.dataclass +class TestInputTensorShapes: + batch: int + m: int + k1: int + k2: int + n: int + scale: float + + +# Helper for generate_function. Generates TestInputTensorShapes, i.e. +# converts from the runtime shape dimensions in TestShape and given dynamicity to +# the set of shapes to be used in a test function's input tensors. +def generate_shapes_and_scale(shape: TestShapeAndScale): + batch = shape.batch + m = shape.m + k1 = shape.k1 + k2 = shape.k2 + n = shape.n + scale = shape.scale + + shapes_scale = TestInputTensorShapes( + batch=batch, + m=m, + k1=k1, + k2=k2, + n=n, + scale=scale, + ) + return shapes_scale + + +# Helper to return input, kernel and output shapes based on the layout and the Attention Params. +def get_tensor_shapes( + shapes_scale: TestShapeAndScale, +): + batch = shapes_scale.batch + m = shapes_scale.m + k1 = shapes_scale.k1 + k2 = shapes_scale.k2 + n = shapes_scale.n + scale = shapes_scale.scale + + query_tensor_shape = [batch, m, k1] + key_tensor_shape = [batch, k2, k1] + value_tensor_shape = [batch, k2, n] + result_tensor_shape = [batch, m, n] + + return query_tensor_shape, key_tensor_shape, value_tensor_shape, result_tensor_shape + + +# Helper for generate_function. +# Generates a name for a test function in the generated MLIR code. +def generate_function_name( + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shapes_scale: TestInputTensorShapes, +): + query_t = query_type.value + key_t = key_type.value + value_t = value_type.value + result_t = value_type.value + + batch = shapes_scale.batch + m = shapes_scale.m + k1 = shapes_scale.k1 + k2 = shapes_scale.k2 + n = shapes_scale.n + + attention = "attention" + return ( + f"{attention}_{batch}_{m}_{k1}_{k2}_{n}" + + f"_dtype_{query_t}_{key_t}_{value_t}_{result_t}" + ) + + +# Represents a generated test function. +@dataclasses.dataclass +class MLIRFunction: + name: str + signature: str + import_declaration: str + definition: str + + +# Generates a test function in the generated MLIR code. +# The generated function will take the same arguments as iree_linalg_ext.attention variants +# and will just call iree_linalg_ext.attention variants with them, returning its result. +def generate_function( + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shape_scale: TestShapeAndScale, +): + shapes_scale = generate_shapes_and_scale(shape_scale) + func_name = generate_function_name( + query_type, + key_type, + value_type, + shapes_scale, + ) + + query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(shapes_scale) + query_tensor_type = ( + f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>" + ) + key_tensor_type = ( + f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>" + ) + value_tensor_type = ( + f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>" + ) + result_tensor_type = f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>" + F32 = "f32" + F16 = "f16" + op_name = "iree_linalg_ext.attention" + + # Compilation info is optional; prints empty string by default. + func_definition = "" + + signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {result_tensor_type}) -> {result_tensor_type}" + import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view" + func_definition = func_definition + ( + f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n" + f" %result0 = tensor.empty(): {result_tensor_type}\n" + f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n" + f" %result1 = {op_name} {{\n" + f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n" + f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n" + f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n" + f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}" + f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n" + f" outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n" + f" return %result1: {result_tensor_type}\n" + f"}}\n" + ) + return MLIRFunction( + name=func_name, + signature=signature, + import_declaration=import_declaration, + definition=func_definition, + ) + + +# Represents a call to a generated test function. +@dataclasses.dataclass +class TestCall: + function: MLIRFunction + op: str + + +# Enumerates ways to initialize tensor buffer contents. +@enum.unique +class TensorGenerator(enum.Enum): + ZERO = "zero" # Fill with zeros + RANDOM = "random" # Fill with (deterministic) pseudorandom values. + + +# Intentionally fixed seed! We want full reproducibility here, both across runs +# and across machines. +# Intentionally not shared with local_pseudorandom_state to limit the ways +# in which shuffling testcases changes which random values are generated. +pseudorandom_generator_seed = 1 + + +def contents_generator_tag(generator: TensorGenerator): + if generator == TensorGenerator.ZERO: + return "" + elif generator == TensorGenerator.RANDOM: + global pseudorandom_generator_seed + pseudorandom_generator_seed = pseudorandom_generator_seed + 1 + return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}" + else: + raise ValueError(generator) + + +# Generate a 3d tensor function argument of the given size as `%name`. +def generate_random_3d_tensor( + name: str, + tensor_shape: list, + element_type: typing.Union[QueryElemTypeId, ResultElemTypeId], +): + global pseudorandom_generator_seed + pseudorandom_generator_seed = pseudorandom_generator_seed + 1 + return ( + f" %{name}_dim0 = arith.constant {tensor_shape[0]} : i64\n" + f" %{name}_dim1 = arith.constant {tensor_shape[1]} : i64\n" + f" %{name}_dim2 = arith.constant {tensor_shape[2]} : i64\n" + f" %{name}_element_type = hal.element_type<{element_type.value}> : i32\n" + f" %{name}_seed = arith.constant {pseudorandom_generator_seed} : i32\n" + f" %{name} = call @attention_test.generate_random_tensor(%device, %{name}_dim0, %{name}_dim1, %{name}_dim2, %{name}_element_type, %{name}_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view\n" + ) + + +call_id = 0 + + +def generate_call( + function: MLIRFunction, + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shapes_scale: TestShapeAndScale, +): + global call_id + func_name = f"{function.name}_{shapes_scale.batch}_{shapes_scale.m}_{shapes_scale.k1}_{shapes_scale.k2}_{shapes_scale.n}_{shapes_scale.k1}_{shapes_scale.scale}" + func_name = f"{func_name}_{call_id}" + call_id = call_id + 1 + + description = f"Attention shape (BATCHxMxK1xK2xN): {shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}x{shapes_scale.k2}x{shapes_scale.k1}x{shapes_scale.n}" + op = ( + f"func.func @{func_name}() attributes {{\n" + f' iree.reflection = {{description = "{description}"}}\n' + "} {\n" + " %device_index = arith.constant 0 : index\n" + " %device = hal.devices.get %device_index : !hal.device\n" + ) + + query_shape, key_shape, value_shape, result_shape = get_tensor_shapes( + shapes_scale, + ) + + op = op + generate_random_3d_tensor("query", query_shape, query_type) + op = op + generate_random_3d_tensor("key", key_shape, key_type) + op = op + generate_random_3d_tensor("value", value_shape, value_type) + + global pseudorandom_generator_seed + pseudorandom_generator_seed = pseudorandom_generator_seed - 1 + op = op + ( + f" %scale = arith.constant {shapes_scale.scale} : f32\n" + f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n" + ) + + op = op + ( + f" %batch = arith.constant {shapes_scale.batch} : i64 \n" + f" %m = arith.constant {shapes_scale.m} : i64 \n" + f" %k1 = arith.constant {shapes_scale.k1} : i64 \n" + f" %k2 = arith.constant {shapes_scale.k2} : i64 \n" + f" %n = arith.constant {shapes_scale.n} : i64 \n" + f" %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> \n" + f" %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> \n" + f" %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> \n" + f" %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> \n" + f" %queryExt = arith.extf %queryTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> \n" + f" %keyExt = arith.extf %keyTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> \n" + f" %valueExt = arith.extf %valueTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> \n" + f" %resultExt = arith.extf %resultTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> \n" + f" %queryExtBufferView = hal.tensor.export %queryExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n" + f" %keyExtBufferView = hal.tensor.export %keyExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n" + f" %valueExtBufferView = hal.tensor.export %valueExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> -> !hal.buffer_view \n" + f" %resultExtBufferView = hal.tensor.export %resultExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> -> !hal.buffer_view \n" + f" call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n" + ) + + op = op + " return\n" + op = op + "}\n" + + return TestCall(function=function, op=op) + + +# Generates all output files' contents as strings. +def generate( + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shapes_id: ShapesId, +): + functions = {} + calls = [] + + for shape in get_test_shapes(shapes_id): + function = generate_function( + query_type, + key_type, + value_type, + shape, + ) + if function.name not in functions: + functions[function.name] = function + calls.append( + generate_call( + function, + query_type, + key_type, + value_type, + shape, + ) + ) + + return (functions, calls) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Generator of e2e Attention tests") + parser.add_argument( + "--output_attention_mlir", + type=str, + help="Path of output .mlir file containing the generated Attention functions", + required=True, + ) + parser.add_argument( + "--output_calls_mlir", + type=str, + help="Path of output .mlir file containing the calls", + required=True, + ) + parser.add_argument( + "--query_type", + type=str, + choices=["f16"], + help="Numeric type of query tensors ", + required=True, + ) + parser.add_argument( + "--key_type", + type=str, + choices=["f16"], + help="Numeric type of key tensors ", + required=True, + ) + parser.add_argument( + "--value_type", + type=str, + choices=["f16"], + help="Numeric type of value tensors ", + required=True, + ) + parser.add_argument( + "--shapes_scale", + type=str, + choices=[s.value for s in ShapesId], + help="Collection of tensor shapes to test", + required=True, + ) + parser.add_argument( + "--requirements", + type=str, + help="Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. If the target device does not meet all of the requirements, the test will be skipped.", + required=False, + ) + return parser.parse_args() + + +def write_code_file(functions, filename): + with open(filename, "w") as file: + for function in functions.values(): + file.write(function.definition + "\n") + + +def write_calls_file(functions, calls, filename, requirements): + # Module-level reflection information used to control the test tool. + reflection = "" + if requirements: + reflection = ( + "iree.reflection = {" + 'target_features = "' + + ",".join([req.lstrip("+") for req in requirements.split(",")]) + + '"' + "}" + ) + module_definition = ( + f"builtin.module @calls attributes {{\n" f" {reflection}\n" f"}} {{\n\n" + ) + + # Declare the custom module that generates arguments. + module_definition = module_definition + ( + "func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n" + "func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n" + "\n" + ) + + # Declare the functions that will be called. + for function in functions.values(): + module_definition = module_definition + function.import_declaration + "\n" + module_definition = module_definition + "\n" + + # Emit the test cases for each call. + for call in calls: + module_definition = module_definition + call.op + "\n" + + module_definition = module_definition + "\n}\n" + + with open(filename, "w") as file: + file.write(module_definition) + + +def main(args): + query_type = QueryElemTypeId(args.query_type) + key_type = KeyElemTypeId(args.key_type) + value_type = ValueElemTypeId(args.value_type) + shapes_id = ShapesId(args.shapes_scale) + + (functions, calls) = generate( + query_type, + key_type, + value_type, + shapes_id, + ) + + write_code_file(functions, args.output_attention_mlir) + write_calls_file( + functions, + calls, + args.output_calls_mlir, + args.requirements, + ) + + +if __name__ == "__main__": + main(parse_arguments()) diff --git a/tools/testing/e2e/BUILD.bazel b/tools/testing/e2e/BUILD.bazel index 0c510a9e0ba6..397627961d20 100644 --- a/tools/testing/e2e/BUILD.bazel +++ b/tools/testing/e2e/BUILD.bazel @@ -68,3 +68,22 @@ iree_runtime_cc_binary( "//runtime/src/iree/vm:cc", ], ) + +iree_runtime_cc_binary( + name = "iree-e2e-attention-test", + srcs = ["iree-e2e-attention-test.cc"], + deps = [ + ":e2e_test_util", + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal", + "//runtime/src/iree/base/internal:cpu", + "//runtime/src/iree/base/internal:flags", + "//runtime/src/iree/base/internal:path", + "//runtime/src/iree/hal", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/tooling:context_util", + "//runtime/src/iree/tooling:device_util", + "//runtime/src/iree/vm", + "//runtime/src/iree/vm:cc", + ], +) diff --git a/tools/testing/e2e/CMakeLists.txt b/tools/testing/e2e/CMakeLists.txt index e4fc8fbf864e..ece0c59d00b0 100644 --- a/tools/testing/e2e/CMakeLists.txt +++ b/tools/testing/e2e/CMakeLists.txt @@ -77,4 +77,24 @@ iree_cc_binary( iree::vm::cc ) +iree_cc_binary( + NAME + iree-e2e-attention-test + SRCS + "iree-e2e-attention-test.cc" + DEPS + ::e2e_test_util + iree::base + iree::base::internal + iree::base::internal::cpu + iree::base::internal::flags + iree::base::internal::path + iree::hal + iree::modules::hal + iree::tooling::context_util + iree::tooling::device_util + iree::vm + iree::vm::cc +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/tools/testing/e2e/iree-e2e-attention-test.cc b/tools/testing/e2e/iree-e2e-attention-test.cc new file mode 100644 index 000000000000..4b0464b13dfb --- /dev/null +++ b/tools/testing/e2e/iree-e2e-attention-test.cc @@ -0,0 +1,486 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include + +#include "iree/base/api.h" +#include "iree/base/internal/cpu.h" +#include "iree/base/internal/flags.h" +#include "iree/base/internal/math.h" +#include "iree/base/internal/path.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/module.h" +#include "iree/tooling/context_util.h" +#include "iree/tooling/device_util.h" +#include "iree/vm/api.h" +#include "iree/vm/native_module_cc.h" +#include "tools/testing/e2e/test_utils.h" + +//===----------------------------------------------------------------------===// +// Reference Attention +//===----------------------------------------------------------------------===// + +// Helper for reference_attention. +// Function to allocate and initialize tensors +float* allocate_tensor(int dim1, int dim2, int dim3) { + const int size = dim1 * dim2 * dim3; + float* tensor = (float*)malloc(size * sizeof(float)); + for (int i = 0; i < size; ++i) { + tensor[i] = 0.0f; + } + return tensor; +} + +// Function to free allocated tensors +void free_tensor(float* tensor) { + if (tensor != nullptr) free(tensor); +} + +// Function to calculate 1D index for a 3D array +int index_3d(int i, int j, int k, int dim2, int dim3) { + return i * dim2 * dim3 + j * dim3 + k; +} + +static void reference_attention_f32_f32_f32_f32( + iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, + iree_hal_dim_t B, const float* query_data, const float* key_data, + const float* value_data, float* result_data, iree_hal_dim_t b, + float* Attention) { + // Compute Q * K^T + for (int m = 0; m < M; ++m) { + for (int k2 = 0; k2 < K2; ++k2) { + float sum = 0.0; + for (int k1 = 0; k1 < K1; ++k1) { + int q_idx = index_3d(b, m, k1, M, K1); + int k_idx = index_3d(b, k2, k1, K2, K1); + + sum += query_data[q_idx] * key_data[k_idx]; + } + int att_idx = index_3d(0, m, k2, M, K2); + Attention[att_idx] = sum / sqrt(K1); // Scale by sqrt(K1) + } + } + + // Compute softmax on Attention + for (int m = 0; m < M; ++m) { + // Find the maximum value for the current sequence + float max_val = -FLT_MAX; + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + max_val = iree_max(max_val, Attention[att_idx]); + } + + // Calculate the softmax denominator + float sum = 0.0f; + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + sum += exp(Attention[att_idx] - max_val); + } + + // Apply softmax + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + Attention[att_idx] = exp(Attention[att_idx]) / sum; + } + } + + // Compute Attention * V + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0; + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + int v_idx = index_3d(b, k2, n, K2, N); + sum += Attention[att_idx] * value_data[v_idx]; + } + int o_idx = index_3d(b, m, n, M, N); + result_data[o_idx] = sum; + } + } +} + +static iree_status_t reference_attention_element( + iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, + iree_hal_dim_t B, iree_hal_element_type_t query_elem_type, + iree_hal_element_type_t key_elem_type, + iree_hal_element_type_t value_elem_type, void* query_data, void* key_data, + void* value_data, void* actual_data, void* result_data, iree_hal_dim_t b, + float* Attention) { + if (query_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && + key_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && + value_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_attention_f32_f32_f32_f32( + M, K1, K2, N, B, (const float*)query_data, (const float*)key_data, + (const float*)value_data, (float*)result_data, b, Attention); + + } else { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "unhandled combination of element types in attention"); + } + return iree_ok_status(); +} + +// Reference attention implementation, used to compare attention results +// against. +static iree_status_t reference_attention( + iree_hal_dim_t B, iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, + iree_hal_dim_t N, iree_hal_element_type_t query_elem_type, + iree_hal_element_type_t key_elem_type, + iree_hal_element_type_t value_elem_type, iree_byte_span_t query_contents, + iree_byte_span_t key_contents, iree_byte_span_t value_contents, + iree_byte_span_t actual_contents, iree_byte_span_t result_contents, + int compute_every) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, B); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, M); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K1); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K2); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, N); + + iree_host_size_t count = 0; + float* Attention = allocate_tensor(1, M, K2); + for (iree_hal_dim_t b = 0; b < B; ++b) { + if (++count < compute_every) continue; + count = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + reference_attention_element( + M, K1, K2, N, B, query_elem_type, key_elem_type, value_elem_type, + query_contents.data, key_contents.data, value_contents.data, + actual_contents.data, result_contents.data, b, Attention)); + } + free_tensor(Attention); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} +//===----------------------------------------------------------------------===// +// Attention comparison/logging +//===----------------------------------------------------------------------===// + +typedef struct { + iree_allocator_t host_allocator; + iree_hal_dim_t b; + iree_hal_dim_t m; + iree_hal_dim_t k1; + iree_hal_dim_t k2; + iree_hal_dim_t n; + iree_hal_element_type_t query_elem_type; + iree_hal_element_type_t key_elem_type; + iree_hal_element_type_t value_elem_type; + iree_hal_element_type_t result_elem_type; + iree_byte_span_t query_contents; + iree_byte_span_t key_contents; + iree_byte_span_t value_contents; + iree_byte_span_t actual_contents; + iree_byte_span_t expected_contents; +} attention_results_t; + +static void attention_results_deinitialize(attention_results_t* results); + +static iree_status_t attention_results_initialize( + iree_hal_device_t* device, iree_hal_dim_t b_size, iree_hal_dim_t m_size, + iree_hal_dim_t k1_size, iree_hal_dim_t k2_size, iree_hal_dim_t n_size, + iree_hal_buffer_view_t* query, iree_hal_buffer_view_t* key, + iree_hal_buffer_view_t* value, iree_hal_buffer_view_t* result, + iree_allocator_t host_allocator, attention_results_t* out_results) { + IREE_TRACE_ZONE_BEGIN(z0); + + memset(out_results, 0, sizeof(*out_results)); + out_results->host_allocator = host_allocator; + + out_results->b = b_size; + out_results->m = m_size; + out_results->k1 = k1_size; + out_results->k2 = k2_size; + out_results->n = n_size; + + out_results->query_elem_type = iree_hal_buffer_view_element_type(query); + out_results->key_elem_type = iree_hal_buffer_view_element_type(key); + out_results->value_elem_type = iree_hal_buffer_view_element_type(value); + out_results->result_elem_type = iree_hal_buffer_view_element_type(result); + + iree_hal_buffer_t* query_buffer = iree_hal_buffer_view_buffer(query); + iree_hal_buffer_t* key_buffer = iree_hal_buffer_view_buffer(key); + iree_hal_buffer_t* value_buffer = iree_hal_buffer_view_buffer(value); + iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result); + + iree_status_t status = iree_ok_status(); + + if (iree_status_is_ok(status)) { + out_results->query_contents.data_length = + iree_hal_buffer_byte_length(query_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->query_contents.data_length, + (void**)&out_results->query_contents.data); + } + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, query_buffer, 0, out_results->query_contents.data, + out_results->query_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->key_contents.data_length = + iree_hal_buffer_byte_length(key_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->key_contents.data_length, + (void**)&out_results->key_contents.data); + } + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, key_buffer, 0, out_results->key_contents.data, + out_results->key_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->value_contents.data_length = + iree_hal_buffer_byte_length(value_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->value_contents.data_length, + (void**)&out_results->value_contents.data); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, value_buffer, 0, out_results->value_contents.data, + out_results->value_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->actual_contents.data_length = + iree_hal_buffer_byte_length(result_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->actual_contents.data_length, + (void**)&out_results->actual_contents.data); + } + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, result_buffer, 0, out_results->actual_contents.data, + out_results->actual_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->expected_contents.data_length = + iree_hal_buffer_byte_length(result_buffer); + status = iree_allocator_malloc( + host_allocator, out_results->expected_contents.data_length, + (void**)&out_results->expected_contents.data); + } + if (!iree_status_is_ok(status)) { + attention_results_deinitialize(out_results); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void attention_results_deinitialize(attention_results_t* results) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_allocator_free(results->host_allocator, results->query_contents.data); + iree_allocator_free(results->host_allocator, results->key_contents.data); + iree_allocator_free(results->host_allocator, results->value_contents.data); + iree_allocator_free(results->host_allocator, results->actual_contents.data); + iree_allocator_free(results->host_allocator, results->expected_contents.data); + + IREE_TRACE_ZONE_END(z0); +} + +// Helper for check_attention_results: the actual interesting part once we've +// obtained and validated the {b,m,k1,k2,n}_size values. On error, detailed +// logging is written to |file| if it is not NULL. +static iree_status_t check_attention_results_impl( + FILE* file, const attention_results_t* results, int check_every) { + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, reference_attention(results->b, results->m, results->k1, results->k2, + results->n, results->query_elem_type, + results->key_elem_type, results->value_elem_type, + results->query_contents, results->key_contents, + results->value_contents, results->actual_contents, + results->expected_contents, check_every)); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Given an actual attention's inputs and output (all host-local), uses a +// reference attention implementation on the same inputs to check if the output +// is correct. On error, detailed logging is written to |file| if it is not +// NULL. +static iree_status_t check_attention_results( + FILE* file, const attention_results_t* results) { + IREE_TRACE_ZONE_BEGIN(z0); + // TODO: Increase the check every param to reduce the number of comparisons. + int check_every = 1; + iree_status_t status = + check_attention_results_impl(file, results, check_every); + if (!iree_status_is_ok(status) && check_every > 1) { + // If we got a failure with check_every>1, that didn't log a useful + // numerical summary, as most of the reference matrix entries hadn't been + // computed. Rerun now with check_every=1 to get that numerical logging. + iree_status_ignore(status); + status = check_attention_results_impl(file, results, 1); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// `attention_test` custom module +//===----------------------------------------------------------------------===// +// This uses the C++ wrapper to keep things simple. Though easier to use it's +// got additional overhead/code-size bloat that doesn't matter in a test like +// this. Making a C module builder API that removes the boilerplate there is TBD +// so this file is written in C besides this module so that we can swap it back +// to being pure C in the future. + +namespace iree { + +class AttentionTestModuleState final { + public: + explicit AttentionTestModuleState(iree_allocator_t host_allocator) + : host_allocator_(host_allocator) {} + ~AttentionTestModuleState() = default; + + // Fills the destination span with pseudorandom values of the given + // |element_type|. The given |seed| is passed to the pseudorandom generator. + // The pseudorandom values are reproducible both across runs and across + // machines. + StatusOr> GenerateRandom3dTensor( + const vm::ref device, int64_t dim0, int64_t dim1, + int64_t dim2, iree_hal_element_type_t element_type, int32_t seed) { + iree_hal_dim_t dims[3] = { + (iree_hal_dim_t)dim0, + (iree_hal_dim_t)dim1, + (iree_hal_dim_t)dim2, + }; + iree_hal_buffer_params_t buffer_params = {0}; + buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT; + buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL; + buffer_params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; + vm::ref result_view; + struct callback_state_t { + iree_hal_element_type_t element_type; + int32_t seed; + } callback_state = { + element_type, + seed, + }; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_generate_buffer( + device.get(), iree_hal_device_allocator(device.get()), + IREE_ARRAYSIZE(dims), dims, element_type, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, + +[](iree_hal_buffer_mapping_t* mapping, void* user_data) { + callback_state_t callback_state = *(callback_state_t*)user_data; + iree_byte_span_t span = mapping->contents; + // Generate "uniform" integer-valued numbers in the range [min, max]. + int32_t min = 0; + int32_t max = 0; + iree_test_utils_get_min_max_for_element_type( + callback_state.element_type, &min, &max); + uint32_t range = (max - min + 1); + iree_host_size_t element_byte_count = + iree_hal_element_dense_byte_count(callback_state.element_type); + uint8_t* data_end = span.data + span.data_length; + uint32_t state = callback_state.seed; + for (uint8_t* data = span.data; data < data_end; + data += element_byte_count) { + int32_t value = + (int32_t)iree_test_utils_pseudorandom_range(&state, range) + + min; + iree_test_utils_write_element(callback_state.element_type, value, + data); + } + return iree_ok_status(); + }, + &callback_state, &result_view)); + return std::move(result_view); + } + + Status CheckAttentionResults( + const vm::ref device, int64_t b, int64_t m, int64_t k1, + int64_t k2, int64_t n, const vm::ref query, + const vm::ref key, + const vm::ref value, + const vm::ref actual_result) { + attention_results_t results = {}; + IREE_RETURN_IF_ERROR(attention_results_initialize( + device.get(), (iree_hal_dim_t)b, (iree_hal_dim_t)m, (iree_hal_dim_t)k1, + (iree_hal_dim_t)k2, (iree_hal_dim_t)n, query.get(), key.get(), + value.get(), actual_result.get(), host_allocator_, &results)); + iree_status_t status = check_attention_results(stderr, &results); + attention_results_deinitialize(&results); + return status; + } + + private: + iree_allocator_t host_allocator_; +}; + +static const vm::NativeFunction + kAttentionTestModuleFunctions[] = { + vm::MakeNativeFunction( + "generate_random_tensor", + &AttentionTestModuleState::GenerateRandom3dTensor), + vm::MakeNativeFunction( + "check_attention_results", + &AttentionTestModuleState::CheckAttentionResults), +}; + +struct AttentionTestModule final + : public vm::NativeModule { + using vm::NativeModule::NativeModule; + StatusOr> CreateState( + iree_allocator_t host_allocator) override { + return std::make_unique(host_allocator); + } +}; + +} // namespace iree + +static iree_status_t attention_test_module_create( + iree_vm_instance_t* instance, iree_allocator_t host_allocator, + iree_vm_module_t** out_module) { + IREE_ASSERT_ARGUMENT(out_module); + *out_module = NULL; + auto module = std::make_unique( + "attention_test", /*version=*/0, instance, host_allocator, + iree::span< + const iree::vm::NativeFunction>( + iree::kAttentionTestModuleFunctions)); + *out_module = module.release()->interface(); + return iree_ok_status(); +} + +int main(int argc, char** argv) { + IREE_TRACE_APP_ENTER(); + + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); + if (argc != 1) { + fprintf(stderr, "use --module= flags to specify the modules to run\n"); + IREE_TRACE_APP_EXIT(EXIT_FAILURE); + return EXIT_FAILURE; + } + + iree_status_t status = iree_test_utils_load_and_run_e2e_tests( + iree_allocator_system(), attention_test_module_create); + int exit_code = EXIT_SUCCESS; + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + bool is_unavailable = iree_status_is_unavailable(status); + iree_status_free(status); + exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE; + } + + IREE_TRACE_APP_EXIT(exit_code); + return exit_code; +}