From 47dcfcf43908584d453f63008c9d68b5e7dae9c3 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Thu, 19 Sep 2024 20:22:02 +0000 Subject: [PATCH] Parameterize elemental_ir_emitter_test.cc float tests --- xla/service/elemental_ir_emitter_test.cc | 609 +++++++---------------- 1 file changed, 181 insertions(+), 428 deletions(-) diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index 9ee2680065a26..5e61932ae665a 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/strings/str_replace.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" @@ -77,6 +78,23 @@ class ElementalIrEmitterExecutionTestWithoutFastMinMax } }; +template +class ElementalIrEmitterExecutionTypedTest + : public ElementalIrEmitterExecutionTest { + protected: + const std::string& TypeName() { + return primitive_util::LowercasePrimitiveTypeName( + primitive_util::NativeToPrimitiveType()); + } +}; + +using FloatTypes = + ::testing::Types; + +TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); + XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { const std::string hlo_text = R"( HloModule FusedDot @@ -229,473 +247,208 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{(0.)})); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 - (f16_ f16[], f32_ f32[], f64_ f64[]) -> (bf16[], bf16[], bf16[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - converted_f16 = bf16[] convert(f16[] f16_) - converted_f32 = bf16[] convert(f32[] f32_) - converted_f64 = bf16[] convert(f64[] f64_) - ROOT tuple = (bf16[], bf16[], bf16[]) tuple(converted_f16, converted_f32, - converted_f64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (bf16[], bf16[], bf16[], bf16[]) { - s8_ = s8[] parameter(0) - s16_ = s16[] parameter(1) - s32_ = s32[] parameter(2) - s64_ = s64[] parameter(3) - converted_s8 = bf16[] convert(s8[] s8_) - converted_s16 = bf16[] convert(s16[] s16_) - converted_s32 = bf16[] convert(s32[] s32_) - converted_s64 = bf16[] convert(s64[] s64_) - ROOT tuple = (bf16[], bf16[], bf16[], bf16[]) tuple( - converted_s8, converted_s16, converted_s32, converted_s64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (bf16[], bf16[], bf16[], bf16[]) { - u8_ = u8[] parameter(0) - u16_ = u16[] parameter(1) - u32_ = u32[] parameter(2) - u64_ = u64[] parameter(3) - converted_u8 = bf16[] convert(u8[] u8_) - converted_u16 = bf16[] convert(u16[] u16_) - converted_u32 = bf16[] convert(u32[] u32_) - converted_u64 = bf16[] convert(u64[] u64_) - ROOT tuple = (bf16[], bf16[], bf16[], bf16[]) tuple( - converted_u8, converted_u16, converted_u32, converted_u64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16 - (to_f16 bf16[], to_f32 bf16[], to_f64 bf16[]) -> (f16[], f32[], f64[]) { - to_f16 = bf16[] parameter(0) - to_f32 = bf16[] parameter(1) - to_f64 = bf16[] parameter(2) - f16_ = f16[] convert(bf16[] to_f16) - f32_ = f32[] convert(bf16[] to_f32) - f64_ = f64[] convert(bf16[] to_f64) - ROOT tuple = (f16[], f32[], f64[]) tuple(f16_, f32_, f64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16(to_s8 bf16[], to_s16 bf16[], to_s32 bf16[], - to_s64 bf16[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = bf16[] parameter(0) - to_s16 = bf16[] parameter(1) - to_s32 = bf16[] parameter(2) - to_s64 = bf16[] parameter(3) - s8_ = s8[] convert(bf16[] to_s8) - s16_ = s16[] convert(bf16[] to_s16) - s32_ = s32[] convert(bf16[] to_s32) - s64_ = s64[] convert(bf16[] to_s64) - ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16(to_u8 bf16[], to_u16 bf16[], to_u32 bf16[], - to_u64 bf16[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = bf16[] parameter(0) - to_u16 = bf16[] parameter(1) - to_u32 = bf16[] parameter(2) - to_u64 = bf16[] parameter(3) - u8_ = u8[] convert(bf16[] to_u8) - u16_ = u16[] convert(bf16[] to_u16) - u32_ = u32[] convert(bf16[] to_u32) - u64_ = u64[] convert(bf16[] to_u64) - ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16 - (to_c64 bf16[], to_c128 bf16[]) -> (c64[], c128[]) { - to_c64 = bf16[] parameter(0) - to_c128 = bf16[] parameter(1) - c64_ = c64[] convert(bf16[] to_c64) - c128_ = c128[] convert(bf16[] to_c128) - ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareBF16) { - constexpr char hlo_text[] = R"( - HloModule compareBF16 - ENTRY main { - p0 = bf16[4] parameter(0) - p1 = bf16[4] parameter(1) - ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToBF16(lhs); - rhs = LiteralUtil::ConvertF32ToBF16(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaBF16) { - constexpr char hlo_text[] = R"( - HloModule IotaBF16 - ENTRY main { - ROOT iota_ = bf16[4] iota(), iota_dimension=0 +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatsToFloat) { + auto tname = this->TypeName(); + if (std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; } - )"; - - RunTest(hlo_text, {}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, BatchDotBF16) { - const char* const hlo_text = R"( - HloModule matmul - + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - x = bf16[8,16] parameter(0) - y = bf16[8,16,32] parameter(1) - ROOT dot = bf16[8,32] dot(x, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} + f16_ = f16[] parameter(0) + f32_ = f32[] parameter(1) + f64_ = f64[] parameter(2) + bf16_ = bf16[] parameter(3) + converted_f16 = ${tname}[] convert(f16_) + converted_f32 = ${tname}[] convert(f32_) + converted_f64 = ${tname}[] convert(f64_) + converted_bf16 = ${tname}[] convert(bf16_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( + converted_f16, converted_f32, converted_f64, converted_bf16) } - )"; - HloModuleConfig config; - DebugOptions debug_options = GetDebugOptionsForTest(); - config.set_debug_options(debug_options); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text, config)); - EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ - (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - bf16_ = bf16[] parameter(3) - converted_f16 = f8e4m3fnuz[] convert(f16[] f16_) - converted_f32 = f8e4m3fnuz[] convert(f32[] f32_) - converted_f64 = f8e4m3fnuz[] convert(f64[] f64_) - converted_bf16 = f8e4m3fnuz[] convert(bf16[] bf16_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( - converted_f16, converted_f32, converted_f64, converted_bf16) - } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertSignedToFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { s8_ = s8[] parameter(0) s16_ = s16[] parameter(1) s32_ = s32[] parameter(2) s64_ = s64[] parameter(3) - converted_s8 = f8e4m3fnuz[] convert(s8[] s8_) - converted_s16 = f8e4m3fnuz[] convert(s16[] s16_) - converted_s32 = f8e4m3fnuz[] convert(s32[] s32_) - converted_s64 = f8e4m3fnuz[] convert(s64[] s64_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_s8 = ${tname}[] convert(s8_) + converted_s16 = ${tname}[] convert(s16_) + converted_s32 = ${tname}[] convert(s32_) + converted_s64 = ${tname}[] convert(s64_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( converted_s8, converted_s16, converted_s32, converted_s64) } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertUnsignedToFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { u8_ = u8[] parameter(0) u16_ = u16[] parameter(1) u32_ = u32[] parameter(2) u64_ = u64[] parameter(3) - converted_u8 = f8e4m3fnuz[] convert(u8[] u8_) - converted_u16 = f8e4m3fnuz[] convert(u16[] u16_) - converted_u32 = f8e4m3fnuz[] convert(u32[] u32_) - converted_u64 = f8e4m3fnuz[] convert(u64[] u64_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_u8 = ${tname}[] convert(u8_) + converted_u16 = ${tname}[] convert(u16_) + converted_u32 = ${tname}[] convert(u32_) + converted_u64 = ${tname}[] convert(u64_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( converted_u8, converted_u16, converted_u32, converted_u64) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ - (to_f16 f8e4m3fnuz[], to_f32 f8e4m3fnuz[], to_f64 f8e4m3fnuz[], to_bf16 f8e4m3fnuz[]) -> (f16[], f32[], f64[], bf16[]) { - to_f16 = f8e4m3fnuz[] parameter(0) - to_f32 = f8e4m3fnuz[] parameter(1) - to_f64 = f8e4m3fnuz[] parameter(2) - to_bf16 = f8e4m3fnuz[] parameter(3) - f16_ = f16[] convert(f8e4m3fnuz[] to_f16) - f32_ = f32[] convert(f8e4m3fnuz[] to_f32) - f64_ = f64[] convert(f8e4m3fnuz[] to_f64) - bf16_ = bf16[] convert(f8e4m3fnuz[] to_f64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToFloats) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_f16 = ${tname}[] parameter(0) + to_f32 = ${tname}[] parameter(1) + to_f64 = ${tname}[] parameter(2) + to_bf16 = ${tname}[] parameter(3) + f16_ = f16[] convert(to_f16) + f32_ = f32[] convert(to_f32) + f64_ = f64[] convert(to_f64) + bf16_ = bf16[] convert(to_f64) ROOT tuple = (f16[], f32[], f64[], bf16[]) tuple(f16_, f32_, f64_, bf16_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ(to_s8 f8e4m3fnuz[], to_s16 f8e4m3fnuz[], to_s32 f8e4m3fnuz[], - to_s64 f8e4m3fnuz[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = f8e4m3fnuz[] parameter(0) - to_s16 = f8e4m3fnuz[] parameter(1) - to_s32 = f8e4m3fnuz[] parameter(2) - to_s64 = f8e4m3fnuz[] parameter(3) - s8_ = s8[] convert(f8e4m3fnuz[] to_s8) - s16_ = s16[] convert(f8e4m3fnuz[] to_s16) - s32_ = s32[] convert(f8e4m3fnuz[] to_s32) - s64_ = s64[] convert(f8e4m3fnuz[] to_s64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToSigned) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_s8 = ${tname}[] parameter(0) + to_s16 = ${tname}[] parameter(1) + to_s32 = ${tname}[] parameter(2) + to_s64 = ${tname}[] parameter(3) + s8_ = s8[] convert(to_s8) + s16_ = s16[] convert(to_s16) + s32_ = s32[] convert(to_s32) + s64_ = s64[] convert(to_s64) ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ(to_u8 f8e4m3fnuz[], to_u16 f8e4m3fnuz[], to_u32 f8e4m3fnuz[], - to_u64 f8e4m3fnuz[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = f8e4m3fnuz[] parameter(0) - to_u16 = f8e4m3fnuz[] parameter(1) - to_u32 = f8e4m3fnuz[] parameter(2) - to_u64 = f8e4m3fnuz[] parameter(3) - u8_ = u8[] convert(f8e4m3fnuz[] to_u8) - u16_ = u16[] convert(f8e4m3fnuz[] to_u16) - u32_ = u32[] convert(f8e4m3fnuz[] to_u32) - u64_ = u64[] convert(f8e4m3fnuz[] to_u64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToUnsigned) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_u8 = ${tname}[] parameter(0) + to_u16 = ${tname}[] parameter(1) + to_u32 = ${tname}[] parameter(2) + to_u64 = ${tname}[] parameter(3) + u8_ = u8[] convert(to_u8) + u16_ = u16[] convert(to_u16) + u32_ = u32[] convert(to_u32) + u64_ = u64[] convert(to_u64) ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ - (to_c64 f8e4m3fnuz[], to_c128 f8e4m3fnuz[]) -> (c64[], c128[]) { - to_c64 = f8e4m3fnuz[] parameter(0) - to_c128 = f8e4m3fnuz[] parameter(1) - c64_ = c64[] convert(f8e4m3fnuz[] to_c64) - c128_ = c128[] convert(f8e4m3fnuz[] to_c128) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToComplex) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_c64 = ${tname}[] parameter(0) + to_c128 = ${tname}[] parameter(1) + c64_ = c64[] convert(to_c64) + c128_ = c128[] convert(to_c128) ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E4FNUZ) { - constexpr char hlo_text[] = R"( - HloModule compareF8E4FNUZ +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, CompareFloat) { + auto tname = this->TypeName(); + if (std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - p0 = f8e4m3fnuz[4] parameter(0) - p1 = f8e4m3fnuz[4] parameter(1) + p0 = ${tname}[4] parameter(0) + p1 = ${tname}[4] parameter(1) ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(lhs); - rhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E4FNUZ) { - constexpr char hlo_text[] = R"( - HloModule IotaF8E4FNUZ +})", + {{"${tname}", tname}}); + Literal lhs = LiteralUtil::CreateR1( + {TypeParam(1.), TypeParam(2.), TypeParam(3.), TypeParam(4.)}); + Literal rhs = LiteralUtil::CreateR1( + {TypeParam(4.), TypeParam(4.), TypeParam(2.), TypeParam(1.)}); + ElementalIrEmitterExecutionTest::RunTest(hlo_text, {&lhs, &rhs}); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { + auto tname = this->TypeName(); + if (std::is_same() || + std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - ROOT iota_ = f8e4m3fnuz[4] iota(), iota_dimension=0 + ROOT iota_ = ${tname}[4] iota(), iota_dimension=0 } - )"; - - RunTest(hlo_text, {}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ - (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - bf16_ = bf16[] parameter(3) - converted_f16 = f8e5m2fnuz[] convert(f16[] f16_) - converted_f32 = f8e5m2fnuz[] convert(f32[] f32_) - converted_f64 = f8e5m2fnuz[] convert(f64[] f64_) - converted_bf16 = f8e5m2fnuz[] convert(bf16[] bf16_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_f16, converted_f32, converted_f64, converted_bf16) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - s8_ = s8[] parameter(0) - s16_ = s16[] parameter(1) - s32_ = s32[] parameter(2) - s64_ = s64[] parameter(3) - converted_s8 = f8e5m2fnuz[] convert(s8[] s8_) - converted_s16 = f8e5m2fnuz[] convert(s16[] s16_) - converted_s32 = f8e5m2fnuz[] convert(s32[] s32_) - converted_s64 = f8e5m2fnuz[] convert(s64[] s64_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_s8, converted_s16, converted_s32, converted_s64) - } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTest(hlo_text, {}); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - u8_ = u8[] parameter(0) - u16_ = u16[] parameter(1) - u32_ = u32[] parameter(2) - u64_ = u64[] parameter(3) - converted_u8 = f8e5m2fnuz[] convert(u8[] u8_) - converted_u16 = f8e5m2fnuz[] convert(u16[] u16_) - converted_u32 = f8e5m2fnuz[] convert(u32[] u32_) - converted_u64 = f8e5m2fnuz[] convert(u64[] u64_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_u8, converted_u16, converted_u32, converted_u64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ - (to_f16 f8e5m2fnuz[], to_f32 f8e5m2fnuz[], to_f64 f8e5m2fnuz[]) -> (f16[], f32[], f64[]) { - to_f16 = f8e5m2fnuz[] parameter(0) - to_f32 = f8e5m2fnuz[] parameter(1) - to_f64 = f8e5m2fnuz[] parameter(2) - f16_ = f16[] convert(f8e5m2fnuz[] to_f16) - f32_ = f32[] convert(f8e5m2fnuz[] to_f32) - f64_ = f64[] convert(f8e5m2fnuz[] to_f64) - ROOT tuple = (f16[], f32[], f64[]) tuple(f16_, f32_, f64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ(to_s8 f8e5m2fnuz[], to_s16 f8e5m2fnuz[], to_s32 f8e5m2fnuz[], - to_s64 f8e5m2fnuz[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = f8e5m2fnuz[] parameter(0) - to_s16 = f8e5m2fnuz[] parameter(1) - to_s32 = f8e5m2fnuz[] parameter(2) - to_s64 = f8e5m2fnuz[] parameter(3) - s8_ = s8[] convert(f8e5m2fnuz[] to_s8) - s16_ = s16[] convert(f8e5m2fnuz[] to_s16) - s32_ = s32[] convert(f8e5m2fnuz[] to_s32) - s64_ = s64[] convert(f8e5m2fnuz[] to_s64) - ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ(to_u8 f8e5m2fnuz[], to_u16 f8e5m2fnuz[], to_u32 f8e5m2fnuz[], - to_u64 f8e5m2fnuz[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = f8e5m2fnuz[] parameter(0) - to_u16 = f8e5m2fnuz[] parameter(1) - to_u32 = f8e5m2fnuz[] parameter(2) - to_u64 = f8e5m2fnuz[] parameter(3) - u8_ = u8[] convert(f8e5m2fnuz[] to_u8) - u16_ = u16[] convert(f8e5m2fnuz[] to_u16) - u32_ = u32[] convert(f8e5m2fnuz[] to_u32) - u64_ = u64[] convert(f8e5m2fnuz[] to_u64) - ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ - (to_c64 f8e5m2fnuz[], to_c128 f8e5m2fnuz[]) -> (c64[], c128[]) { - to_c64 = f8e5m2fnuz[] parameter(0) - to_c128 = f8e5m2fnuz[] parameter(1) - c64_ = c64[] convert(f8e5m2fnuz[] to_c64) - c128_ = c128[] convert(f8e5m2fnuz[] to_c128) - ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E5FNUZ) { - constexpr char hlo_text[] = R"( - HloModule compareF8E5FNUZ - ENTRY main { - p0 = f8e5m2fnuz[4] parameter(0) - p1 = f8e5m2fnuz[4] parameter(1) - ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(lhs); - rhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule matmul -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E5FNUZ) { - constexpr char hlo_text[] = R"( - HloModule IotaF8E5FNUZ ENTRY main { - ROOT iota_ = f8e5m2fnuz[4] iota(), iota_dimension=0 + x = ${tname}[8,16] parameter(0) + y = ${tname}[8,16,32] parameter(1) + ROOT dot = ${tname}[8,32] dot(x, y), lhs_batch_dims={0}, + rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} } - )"; + )", + {{"${tname}", tname}}); + HloModuleConfig config; + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + config.set_debug_options(debug_options); - RunTest(hlo_text, {}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloTestBase::ParseAndReturnVerifiedModule(hlo_text, config)); + EXPECT_TRUE( + HloTestBase::RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); } XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax,