diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD
index 22a97478a4f70e..1757e27a904fff 100644
--- a/xla/service/gpu/transforms/BUILD
+++ b/xla/service/gpu/transforms/BUILD
@@ -1671,6 +1671,7 @@ cc_library(
     deps = [
         "//xla:literal",
         "//xla:literal_util",
+        "//xla:permutation_util",
         "//xla:shape_util",
         "//xla:status_macros",
         "//xla:types",
diff --git a/xla/service/gpu/transforms/gemm_rewriter.cc b/xla/service/gpu/transforms/gemm_rewriter.cc
index dc0ae0bfdd1b35..7889c3055c4e6e 100644
--- a/xla/service/gpu/transforms/gemm_rewriter.cc
+++ b/xla/service/gpu/transforms/gemm_rewriter.cc
@@ -46,8 +46,10 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout.h"
+#include "xla/layout_util.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
+#include "xla/permutation_util.h"
 #include "xla/primitive_util.h"
 #include "xla/service/algorithm_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
@@ -362,27 +364,61 @@ std::optional<MatchedFp8Param> MatchFp8Param(HloInstruction *instr) {
 // dimension. Keeps the layout the same.
 HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
                                 absl::Span<const int64_t> batch_dims) {
+  auto input_shape = instr->shape();
   // Identify the dimensional order which describes a transpose of the
   // contracting and non-contracting dimensions of the GEMM.
-  std::vector<int64_t> permutation(instr->shape().dimensions_size(), -1);
+  std::vector<int64_t> permutation(input_shape.dimensions_size(), -1);
   // Discard the batch dimensions.
   for (int64_t batch_dim : batch_dims) {
     permutation[batch_dim] = batch_dim;
   }
   // Identify the non-contracting dimension.
   int non_contracting_dim;
-  for (int i = 0; i < instr->shape().dimensions_size(); ++i) {
+  for (int i = 0; i < input_shape.dimensions_size(); ++i) {
     if (permutation[i] == -1 && contracting_dim != i) {
       non_contracting_dim = i;
     }
   }
-  permutation[non_contracting_dim] = contracting_dim;
-  permutation[contracting_dim] = non_contracting_dim;
-  Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape());
-  *new_shape.mutable_layout() = instr->shape().layout();
-  return instr->AddInstruction(
-      HloInstruction::CreateTranspose(new_shape, instr, permutation));
+  if (Layout::Equal()(input_shape.layout(),
+                      LayoutUtil::GetDefaultLayoutForShape(input_shape))) {
+    permutation[non_contracting_dim] = contracting_dim;
+    permutation[contracting_dim] = non_contracting_dim;
+
+    Shape new_shape = ShapeUtil::PermuteDimensions(permutation, input_shape);
+    *new_shape.mutable_layout() = input_shape.layout();
+
+    return instr->AddInstruction(
+        HloInstruction::CreateTranspose(new_shape, instr, permutation));
+  }
+
+  Shape normalized_input_shape =
+      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+          input_shape);
+  auto a0 = MakeBitcastHlo(instr, normalized_input_shape);
+
+  std::vector<int64_t> layout_permutation(
+      input_shape.layout().minor_to_major().begin(),
+      input_shape.layout().minor_to_major().end());
+  absl::c_reverse(layout_permutation);
+  auto inv_perm = InversePermutation(layout_permutation);
+
+  int new_contracting_dim = inv_perm[contracting_dim];
+  int new_non_contracting_dim = inv_perm[non_contracting_dim];
+  absl::c_iota(permutation, 0);
+  std::swap(permutation[new_contracting_dim],
+            permutation[new_non_contracting_dim]);
+
+  Shape transpose_shape =
+      ShapeUtil::PermuteDimensions(permutation, a0->shape());
+  *transpose_shape.mutable_layout() = a0->shape().layout();
+
+  HloInstruction *normalized_transpose = instr->AddInstruction(
+      HloInstruction::CreateTranspose(transpose_shape, a0, permutation));
+
+  Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape);
+  *final_shape.mutable_layout() = input_shape.layout();
+  return MakeBitcastHlo(normalized_transpose, final_shape);
 }
 
 // If the bias is a sequence of ops that depend only on broadcasts of
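
Note on the gemm_rewriter.cc change above: when the operand does not have the default descending layout, the new path (1) bitcasts it to a shape with descending layout and the same physical bytes, (2) swaps the contracting and non-contracting dimensions in that normalized dimension order, and (3) bitcasts back to a shape with the original layout, so only one physical transpose is emitted. Below is a minimal standalone sketch of the permutation arithmetic, using the [2,64,32]{1,2,0} lhs from the new test; it is plain C++ and not part of the patch, and apart from mirroring what InversePermutation and absl::c_reverse/c_iota compute, all names here are illustrative.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Shape [2,64,32], layout {1,2,0} (minor-to-major), batch dim 0,
// contracting dim 2, non-contracting dim 1.
int main() {
  const std::vector<int64_t> minor_to_major = {1, 2, 0};
  const int64_t contracting_dim = 2;
  const int64_t non_contracting_dim = 1;

  // Reverse minor_to_major to get the major-to-minor dimension order, then
  // invert it; this mirrors absl::c_reverse + InversePermutation above.
  std::vector<int64_t> layout_permutation(minor_to_major.rbegin(),
                                          minor_to_major.rend());  // {0,2,1}
  std::vector<int64_t> inv_perm(layout_permutation.size());
  for (size_t i = 0; i < layout_permutation.size(); ++i) {
    inv_perm[layout_permutation[i]] = static_cast<int64_t>(i);
  }

  // Positions of the contracting/non-contracting dims after the bitcast to
  // the normalized (descending-layout) shape [2,32,64]{2,1,0}.
  const int64_t new_contracting = inv_perm[contracting_dim];          // 1
  const int64_t new_non_contracting = inv_perm[non_contracting_dim];  // 2

  // The transpose emitted on the normalized shape swaps exactly those two.
  std::vector<int64_t> permutation(minor_to_major.size());
  std::iota(permutation.begin(), permutation.end(), 0);
  std::swap(permutation[new_contracting], permutation[new_non_contracting]);

  for (int64_t d : permutation) std::cout << d << " ";  // prints: 0 2 1
  std::cout << "\n";
  return 0;
}

The printed {0, 2, 1} is exactly the dimensions={0,2,1} on the transpose that the FileCheck pattern in the test below expects.
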
diff --git a/xla/service/gpu/transforms/gemm_rewriter_test.cc b/xla/service/gpu/transforms/gemm_rewriter_test.cc
index 140787413d0f67..721f262822fb46 100644
--- a/xla/service/gpu/transforms/gemm_rewriter_test.cc
+++ b/xla/service/gpu/transforms/gemm_rewriter_test.cc
@@ -5032,6 +5032,58 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {
 )");
 }
 
+TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) {
+  const char* hlo_text = R"(
+HloModule test
+    ENTRY test {
+      x = <<F8E4M3>>[2,64,32]{1,2,0} parameter(0)
+      y = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
+      x_scale = f32[] parameter(2)
+      y_scale = f32[] parameter(3)
+      dq_scale = f32[] multiply(x_scale, y_scale)
+      dq_scale_bcast = f32[2,64,16] broadcast(dq_scale), dimensions={}
+      out.0 = f32[2,64,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
+      ROOT out = f32[2,64,16] multiply(out.0, dq_scale_bcast)
+    }
+)";
+
+  CheckFp8IfSupported(hlo_text);
+  RunAndFilecheckHloRewrite(
+      hlo_text,
+      GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
+                   GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
+      R"(
+; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[2,64,32], {{.*}}: <<F8E4M3>>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,64,16] {
+; CHECK-NEXT:    [[P0:%[^ ]+]] = <<F8E4M3>>[2,64,32]{1,2,0} parameter(0)
+; CHECK-NEXT:    [[P0_BT:%[^ ]+]] = <<F8E4M3>>[2,32,64]{2,1,0} bitcast([[P0]])
+; CHECK-NEXT:    [[P0_TR:%[^ ]+]] = <<F8E4M3>>[2,64,32]{2,1,0} transpose([[P0_BT]]), dimensions={0,2,1}
+; CHECK-NEXT:    [[P0_BT1:%[^ ]+]] = <<F8E4M3>>[2,32,64]{1,2,0} bitcast([[P0_TR]])
+; CHECK-NEXT:    [[P1:%[^ ]+]] = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
+; CHECK-NEXT:    [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
+; CHECK-NEXT:    [[P2:%[^ ]+]] = f32[] parameter(2)
+; CHECK-NEXT:    [[P3:%[^ ]+]] = f32[] parameter(3)
+; CHECK-NEXT:    [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]])
+; CHECK-NEXT:    [[C1:%[^ ]+]] = f32[] constant(1)
+; CHECK-NEXT:    [[OUT:%[^ ]+]] = (f32[2,64,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]]),
+; CHECK:           custom_call_target="__cublas$lt$matmul$f8",
+; CHECK:           backend_config={
+; CHECK-DAG:         "alpha_real":1
+; CHECK-DAG:         "alpha_imag":0
+; CHECK-DAG:         "beta":0
+; CHECK-DAG:         "dot_dimension_numbers":{
+; CHECK-DAG:           "lhs_contracting_dimensions":["1"]
+; CHECK-DAG:           "rhs_contracting_dimensions":["2"]
+; CHECK-DAG:           "lhs_batch_dimensions":["0"]
+; CHECK-DAG:           "rhs_batch_dimensions":["0"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "precision_config":{
+; CHECK-DAG:           "operand_precision":["DEFAULT","DEFAULT"]
+; CHECK-DAG:         }
+; CHECK-DAG:         "epilogue":"DEFAULT"
+; CHECK:           }
+      )");
+}
+
 TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) {
   const char* hlo_text = R"(
 HloModule test
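
As a sanity check on the expected HLO above, here is a small self-contained program (illustrative only; the Strides helper is a hypothetical stand-in, not an XLA function) that replays the bitcast -> transpose -> bitcast chain with plain index arithmetic and verifies it equals a logical swap of the contracting and non-contracting dimensions while preserving the original {1,2,0} layout.

#include <cstdint>
#include <iostream>
#include <vector>

// Strides implied by a minor-to-major layout: the minor-most dimension has
// stride 1, and each subsequent dimension's stride is the running product.
std::vector<int64_t> Strides(const std::vector<int64_t>& dims,
                             const std::vector<int64_t>& m2m) {
  std::vector<int64_t> strides(dims.size());
  int64_t running = 1;
  for (int64_t d : m2m) {  // minor-most first
    strides[d] = running;
    running *= dims[d];
  }
  return strides;
}

int main() {
  // x: the lhs from the test, shape [2,64,32] (B, M, K), layout {1,2,0}.
  const std::vector<int64_t> x_dims = {2, 64, 32};
  const auto x_strides = Strides(x_dims, {1, 2, 0});
  std::vector<int> buf(2 * 64 * 32);

  // Fill the physical buffer through x's logical view: x(b,m,k) = unique id.
  for (int64_t b = 0; b < 2; ++b)
    for (int64_t m = 0; m < 64; ++m)
      for (int64_t k = 0; k < 32; ++k)
        buf[b * x_strides[0] + m * x_strides[1] + k * x_strides[2]] =
            static_cast<int>((b * 64 + m) * 32 + k);

  // First bitcast: the same buffer viewed as [2,32,64]{2,1,0} (row-major).
  // The transpose with dimensions={0,2,1} then writes out(b,j,i) = in(b,i,j);
  // this is the only step that actually moves data.
  std::vector<int> out(buf.size());
  for (int64_t b = 0; b < 2; ++b)
    for (int64_t i = 0; i < 32; ++i)
      for (int64_t j = 0; j < 64; ++j)
        out[(b * 64 + j) * 32 + i] = buf[(b * 32 + i) * 64 + j];

  // Final bitcast: out viewed as [2,32,64] (B, K, M) with layout {1,2,0}.
  // Check that t(b,k,m) == x(b,m,k), i.e. a logical transpose of M and K.
  const std::vector<int64_t> t_dims = {2, 32, 64};
  const auto t_strides = Strides(t_dims, {1, 2, 0});
  for (int64_t b = 0; b < 2; ++b)
    for (int64_t k = 0; k < 32; ++k)
      for (int64_t m = 0; m < 64; ++m)
        if (out[b * t_strides[0] + k * t_strides[1] + m * t_strides[2]] !=
            static_cast<int>((b * 64 + m) * 32 + k)) {
          std::cout << "mismatch\n";
          return 1;
        }
  std::cout << "bitcast-transpose-bitcast == logical transpose\n";
  return 0;
}

Only the middle loop moves bytes; the two bitcasts are pure reinterpretations of the same buffer, which is why the rewriter's chain costs a single physical transpose even for a non-default-layout operand.
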