Conversation
```cpp
  rhs_scale_k *= rhs_scale_shape.dimensions(dim);
}
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
```
bug: Integer division may falsely accept non-aligned block sizes
The block-size check k / lhs_scale_k != 32 uses integer division. If k is not perfectly divisible by lhs_scale_k, the integer quotient could still equal 32 even though the actual block size is fractional.
For example: k=65, lhs_scale_k=2 → 65/2 = 32 (integer), but the real block size is 32.5 elements.
Consider adding an explicit divisibility check:
```diff
   rhs_scale_k *= rhs_scale_shape.dimensions(dim);
 }
-if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
+if (lhs_scale_k == 0 || k % lhs_scale_k != 0 || k / lhs_scale_k != 32 ||
+    rhs_scale_k == 0 || k % rhs_scale_k != 0 || k / rhs_scale_k != 32) {
```
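The false-accept can be demonstrated in isolation. A minimal standalone sketch (`IsValidBlockSize` is a hypothetical helper mirroring the suggested check, not the XLA code itself):

```cpp
#include <cstdint>

// Returns true only when k splits into whole blocks of exactly 32 elements.
// Hypothetical helper mirroring the review's suggested fix: the divisibility
// check (k % scale_k == 0) rejects fractional block sizes that integer
// division alone would falsely accept.
bool IsValidBlockSize(int64_t k, int64_t scale_k) {
  return scale_k != 0 && k % scale_k == 0 && k / scale_k == 32;
}
```

With k=65 and scale_k=2, `65 / 2` evaluates to 32 under integer division, so the original `k / scale_k != 32` check passes even though the real block size is 32.5; the added modulo test catches exactly this case.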
```cpp
} else if (IsScaledDotFusion(instr)) {
  const HloInstruction* scaled_dot = hlo_query::GetFirstInstructionWithOpcode(
      *instr.fused_instructions_computation(), HloOpcode::kScaledDot);
  TF_RET_CHECK(scaled_dot != nullptr);
  HloComputation* parent = instr.parent();

  TF_RET_CHECK(instr.operand_count() == 4);
  HloInstruction* lhs = instr.mutable_operand(0);
  HloInstruction* rhs = instr.mutable_operand(1);
  HloInstruction* lhs_scale = instr.mutable_operand(2);
  HloInstruction* rhs_scale = instr.mutable_operand(3);

  const Shape& result_shape = scaled_dot->shape();
  int64_t workspace_size = gemm_key.autotune_workspace_size();
  Shape workspace_shape = ShapeUtil::MakeShape(S8, {workspace_size});
  Shape output_shape =
      ShapeUtil::MakeTupleShape({result_shape, workspace_shape});

  GpuBackendConfig gpu_backend_config;
  GemmBackendConfig& gemm_config =
      *gpu_backend_config.mutable_gemm_backend_config();
  *gemm_config.mutable_dot_dimension_numbers() =
      scaled_dot->dot_dimension_numbers();
  gemm_config.set_alpha_real(1.0);
  gemm_config.set_alpha_imag(0.0);
  gemm_config.set_beta(0.0);
  gemm_config.set_scale_mode(
      static_cast<int32_t>(se::gpu::ScaleMode::kBlockScaling));
  gemm_config.set_selected_algorithm(gemm_key.algorithm());
  gemm_config.set_autotune_workspace_size(workspace_size);

  HloInstruction* custom_call =
      parent->AddInstruction(HloInstruction::CreateCustomCall(
          output_shape, {lhs, rhs, lhs_scale, rhs_scale},
          kCublasLtMatmulMxCallTarget));
  TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
  HloInstruction* gte = parent->AddInstruction(
      HloInstruction::CreateGetTupleElement(result_shape, custom_call, 0));
  return parent->ReplaceInstruction(&instr, gte);
```
concern: Graph transformation in ApplyConfig is unusual
ApplyConfig typically only sets backend config on the existing instruction, but here it replaces the fusion instruction with a new kCublasLtMatmulMxCallTarget custom call via parent->ReplaceInstruction(&instr, gte). After this call, &instr becomes a dangling reference.
Two concerns:
- Verify the autotuner framework handles instruction replacement during `ApplyConfig` (not just config changes).
- The `TF_RET_CHECK(instr.operand_count() == 4)` at line 376 will crash with a fatal error if the count is wrong, but `IsScaledDotFusion` doesn't verify operand count. Consider moving this check into `IsValidMxScaledDot` or `IsScaledDotFusion` so that `GetSupportedConfigs` returns empty instead of crashing in `ApplyConfig`.
```proto
int64 autotune_workspace_size = 20;

// Scale mode: 0=none, 1=tensor_scaling (fp8), 2=block_scaling (MX).
```
nit: Consider using a protobuf enum for type safety
Using a raw int32 with static_cast<int32_t>(se::gpu::ScaleMode::kBlockScaling) throughout the codebase means that if the C++ enum values change, the proto wire values silently change. A proper protobuf enum type would provide type safety, self-documenting wire format, and better backward compatibility.
Not blocking, but worth considering for a follow-up.
```cpp
gemm_backend_config.set_scale_mode(
    static_cast<int32_t>(se::gpu::ScaleMode::kTensorScaling));
```
question: Behavioral change for existing FP8 matmuls
This sets scale_mode=kTensorScaling for existing FP8 matmuls that previously had scale_mode unset (defaulting to kNone). Downstream in hip_blas_lt.cc:GetAlgorithms, the code now uses a switch(scale_mode) instead of the old IsFP8() lambda check.
If any existing FP8 GemmBackendConfig protos are read from cache (rather than re-generated via the gemm rewriter), they would have scale_mode=0 (kNone) and the switch would skip setting scale pointers, causing algorithm lookup to fail.
Can you confirm existing FP8 matmul configs are always regenerated (not cached) so the new field is always present?
```cpp
DeviceAddressBase a_scale = args.a_scale, b_scale = args.b_scale;
if (must_swap_operands_) {
  std::swap(a, b);
```
nit: Inconsistency with #else branch (ROCm < 6.0)
The swapped a_scale/b_scale locals are correctly used in the #if TF_ROCM_VERSION >= 60000 branch, but the #else branch (~line 488 in the original file) still refers to args.a_scale and args.b_scale. While logically equivalent for the null check, this inconsistency makes the code harder to reason about. Consider updating the #else branch to use the swapped locals too.
```cpp
ASSIGN_OR_RETURN(ShapedSlice c, GetShapedSliceForHlo(instr, output_index));
ASSIGN_OR_RETURN(ShapedSlice d, GetShapedSliceForHlo(instr, output_index));
```
nit: c and d alias the same buffer
Both c and d are set to the same output slice. This works because beta=0 (so the c term in alpha*A*B + beta*C vanishes), but it's fragile — if a future change enables non-zero beta for MX matmul, the output would be aliased with the accumulation input. Consider adding a comment clarifying this is intentional because beta is always 0 for MX.
```cpp
#if TF_ROCM_VERSION >= 70000
// MX FP4 (F4E2M1FN) type combinations
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
```
question: OCP vs FNUZ FP8 type mapping
The TYPED_MATMUL entries use HIP_R_8F_E4M3 and HIP_R_8F_E5M2 (OCP/non-FNUZ variants) as mixed-type partners with HIP_R_4F_E2M1. The validator in hipblaslt.cc accepts F8E4M3FN and F8E5M2 as valid input types.
Older ROCm versions used FNUZ variants (HIP_R_8F_E4M3_FNUZ / HIP_R_8F_E5M2_FNUZ). Can you confirm that HIP_R_8F_E4M3 / HIP_R_8F_E5M2 correctly map to XLA's F8E4M3FN / F8E5M2 types on the target ROCm 7.0+ platform? The AsHipblasDataType function in hip_blas_utils.cc should be consistent with these entries.
```cpp
for (int64_t dim : dot_dims.lhs_batch_dimensions()) {
  batch_size *= lhs_shape.dimensions(dim);
}
if (batch_size != 1) {
  VLOG(2) << "hipBLASLt MX: batch_size > 1 not supported, got " << batch_size;
  return false;
}
```
nit: Batch restriction not documented
The batch_size != 1 check rejects batched MX operations, but it only computes batch_size from lhs batch dimensions. Consider adding a brief comment explaining this is a hipBLASLt limitation (or a TODO to revisit when hipBLASLt gains batch support for MX). The test kMxFp8BatchedHlo passes [1,32,256] with batch_size=1, which validates the path but doesn't exercise real batching.
```cpp
switch (op_desc_.scale_mode()) {
  case gpu::ScaleMode::kNone:
    break;
  case gpu::ScaleMode::kTensorScaling: {
    static int64_t dummy_pointer = 0xACEBALL;
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                               &dummy_pointer));
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                               &dummy_pointer));
    break;
  }
  case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
    static int64_t dummy_pointer = 0xACEBALL;
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                               &dummy_pointer));
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                               &dummy_pointer));
    hipblasLtMatmulMatrixScale_t mx_scale =
        HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
    TF_RETURN_IF_ERROR(SetAttr(
        op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, mx_scale));
    TF_RETURN_IF_ERROR(SetAttr(
        op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, mx_scale));
#else
    return absl::InternalError("Block scaling requires ROCm >= 7.0");
#endif
    break;
  }
}
```
The old code auto-detected FP8 inputs via IsFP8(a_desc_) && IsFP8(b_desc_) and unconditionally set dummy scale pointers. Now this is gated on scale_mode == kTensorScaling.
Any pre-existing serialized/cached FP8 matmul configs (e.g., from an autotune cache on disk) that were created before this PR will have scale_mode = 0 (kNone) in their proto, because the scale_mode field didn't exist and proto defaults to 0. When those cached configs are deserialized and used, GetAlgorithms() will skip setting the dummy a/b scale pointers, and algorithm enumeration may fail or return wrong algorithms.
Consider either:
- Treating FP8 + `kNone` the same as `kTensorScaling` for backward compatibility (i.e., fall through from `kNone` to `kTensorScaling` when the input types are FP8), or
- Bumping the autotune cache version to invalidate old caches.
```cpp
#include "xla/service/compiler.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/ir_emission_utils.h"
```
ir_emission_utils.h is included here but the only symbol used from it is kTritonGemmFusionKind. Consider whether a more targeted include (or forward declaration) would be preferable to pulling in the full ir_emission_utils.h header, which is a large header with many dependencies.
| bool IsValidMxScaledDot(const HloInstruction* scaled_dot) { | ||
| const Shape& lhs_shape = scaled_dot->operand(0)->shape(); | ||
| const Shape& rhs_shape = scaled_dot->operand(1)->shape(); | ||
| const Shape& lhs_scale_shape = scaled_dot->operand(2)->shape(); | ||
| const Shape& rhs_scale_shape = scaled_dot->operand(3)->shape(); | ||
| const Shape& output_shape = scaled_dot->shape(); | ||
| const DotDimensionNumbers& dot_dims = scaled_dot->dot_dimension_numbers(); |
The IsValidMxScaledDot function accesses scaled_dot->operand(2) and scaled_dot->operand(3) without first verifying that scaled_dot->operand_count() >= 4. If a malformed kScaledDot instruction is encountered with fewer operands, this would be an out-of-bounds access. Consider adding a guard:
```diff
 bool IsValidMxScaledDot(const HloInstruction* scaled_dot) {
+  if (scaled_dot->operand_count() < 4) {
+    VLOG(2) << "hipBLASLt MX: scaled-dot must have 4 operands, got "
+            << scaled_dot->operand_count();
+    return false;
+  }
   const Shape& lhs_shape = scaled_dot->operand(0)->shape();
   const Shape& rhs_shape = scaled_dot->operand(1)->shape();
   const Shape& lhs_scale_shape = scaled_dot->operand(2)->shape();
   const Shape& rhs_scale_shape = scaled_dot->operand(3)->shape();
   const Shape& output_shape = scaled_dot->shape();
   const DotDimensionNumbers& dot_dims = scaled_dot->dot_dimension_numbers();
```
```cpp
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
    k / rhs_scale_k != 32) {
  VLOG(2) << "hipBLASLt MX: block size must be 32, got lhs="
          << (lhs_scale_k > 0 ? k / lhs_scale_k : 0)
          << " rhs=" << (rhs_scale_k > 0 ? k / rhs_scale_k : 0);
  return false;
}
```
Division by zero risk: if lhs_scale_k or rhs_scale_k is zero, `k / lhs_scale_k` or `k / rhs_scale_k` would be undefined behavior before the `!= 32` comparison. In both expressions, however, the `== 0` guard short-circuits via `||`, so each division is only evaluated when its divisor is non-zero.
On second look, the logic is fine due to `||` short-circuiting. Disregard.
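The short-circuit guarantee being relied on can be checked in isolation. A standalone sketch (`RejectsBadBlockSize` is a hypothetical stand-in for the real predicate, not the XLA code):

```cpp
#include <cstdint>

// Mirrors the guard pattern from the quoted code: with ||, C++ guarantees
// left-to-right evaluation with short-circuiting, so the division is never
// evaluated when the preceding zero check is true -- for lhs and rhs alike.
bool RejectsBadBlockSize(int64_t k, int64_t lhs_scale_k, int64_t rhs_scale_k) {
  return lhs_scale_k == 0 || k / lhs_scale_k != 32 ||
         rhs_scale_k == 0 || k / rhs_scale_k != 32;
}
```

Passing a zero divisor on either side returns true (rejected) without ever performing the division.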
```cpp
HloInstruction* custom_call =
    parent->AddInstruction(HloInstruction::CreateCustomCall(
        output_shape, {lhs, rhs, lhs_scale, rhs_scale},
        kCublasLtMatmulMxCallTarget));
TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
HloInstruction* gte = parent->AddInstruction(
    HloInstruction::CreateGetTupleElement(result_shape, custom_call, 0));
return parent->ReplaceInstruction(&instr, gte);
```
ApplyConfig replaces the fusion instruction instr with a GetTupleElement(custom_call, 0). The original fusion instruction is removed from the computation. If instr was the root instruction of the entry computation, ReplaceInstruction changes the root to the GTE. This means the module's result layout should be updated, similar to how the IsCublasLtMatmul path updates the result layout (lines 360-365). The MX path does not do this.
Additionally, the dead fusion computation (the one that contained the kScaledDot) is not explicitly cleaned up. While HLO DCE may eventually remove it, this could leave orphaned computations in the module during autotuning.
```proto
// Scale mode: 0=none, 1=tensor_scaling (fp8), 2=block_scaling (MX).
int32 scale_mode = 21;
```
The scale_mode field is stored as int32 in the proto rather than a proper proto enum. This means there is no validation at the proto level -- any int32 value can be deserialized into scale_mode. The code then does static_cast<ScaleMode>(config.scale_mode()) in multiple places without bounds checking, which could create invalid enum values if the proto contains unexpected data (e.g., from a newer version).
Consider using a proto enum type here, or adding a validation check at deserialization points.
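Short of switching to a proto enum, the deserialization-time check could look like this. A standalone sketch (`ToScaleMode` is a hypothetical helper; the enum values mirror the comment in the proto):

```cpp
#include <cstdint>

enum class ScaleMode : int32_t {
  kNone = 0,
  kTensorScaling = 1,
  kBlockScaling = 2,
};

// Hypothetical validation helper: reject out-of-range wire values instead of
// static_cast-ing them straight into the enum.
bool ToScaleMode(int32_t wire_value, ScaleMode* out) {
  if (wire_value < static_cast<int32_t>(ScaleMode::kNone) ||
      wire_value > static_cast<int32_t>(ScaleMode::kBlockScaling)) {
    return false;  // e.g. a value written by a newer binary
  }
  *out = static_cast<ScaleMode>(wire_value);
  return true;
}
```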
```cpp
// A call to hipBLASLt for block-scaled matrix multiplication in MX formats.
inline constexpr absl::string_view kCublasLtMatmulMxCallTarget =
    "__cublas$lt$matmul$mx";
```
The naming convention kCublasLtMatmulMxCallTarget = "__cublas$lt$matmul$mx" uses the "cublas" prefix, but the comment says "Calls into hipBLASLt." This is a ROCm-only feature that will never call into cuBLAS/cuBLASLt. Using the cublas prefix here could be confusing. Consider whether a different naming convention would be clearer (though I acknowledge this follows the existing pattern of kCublasLtMatmulF8CallTarget which is also used for hipBLASLt on ROCm).
```cpp
ASSIGN_OR_RETURN(ShapedSlice c, GetShapedSliceForHlo(instr, output_index));
ASSIGN_OR_RETURN(ShapedSlice d, GetShapedSliceForHlo(instr, output_index));
```
The MX thunk emitter uses ASSIGN_OR_RETURN (without TF_ prefix) for ShapedSlice assignments, which is consistent with nearby code. However, the mix of TF_ASSIGN_OR_RETURN (lines 675-676) and ASSIGN_OR_RETURN (lines 683-691) within the same function is worth noting for consistency.
More importantly: lines 690-691 use the same output_index for both c and d, making c and d point to the same buffer. This is intentional for MX (where beta=0, so C is unused), but it's worth a comment explaining why c aliases d here, similar to how the F8 path (line 607) handles this with a comment about has_matrix_bias.
```cpp
bool IsScaledDotFusion(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) return false;
  auto gpu_config = instr.backend_config<GpuBackendConfig>();
  if (!gpu_config.ok()) return false;
  if (gpu_config->fusion_backend_config().kind() != kTritonGemmFusionKind) {
    return false;
  }
  return hlo_query::GetFirstInstructionWithOpcode(
             *instr.fused_instructions_computation(), HloOpcode::kScaledDot) !=
         nullptr;
}
```
The IsScaledDotFusion function calls instr.backend_config<GpuBackendConfig>() and silently returns false if it fails, which is fine for a predicate. The .ok() check on line 89 also guards the subsequent gpu_config->fusion_backend_config() access, so that part is correct.
One concern: this function is called from IsSupported, GetSupportedConfigs, ApplyConfig, and GetDefaultConfig. Each time it re-parses the backend config proto. Consider caching or restructuring to avoid redundant proto parsing.
```cpp
if (must_swap_operands_) {
  std::swap(a, b);
  if (a_scale != nullptr && b_scale != nullptr) {
    std::swap(a_scale, b_scale);
  }
```
When must_swap_operands_ is true, the scale pointers are also swapped. However, this swap only happens when BOTH a_scale and b_scale are non-null. If only one is non-null (which shouldn't happen for MX, but could for other configurations), the scale would NOT be swapped while the matrices ARE swapped, leading to incorrect scale application. Consider swapping unconditionally when must_swap_operands_ is true, or at least asserting that both are null or both are non-null when swapping is required.
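The suggested invariant can be sketched in isolation. A standalone illustration (`SwapKeepsScalesPaired` and the raw pointers are hypothetical stand-ins for the real `DeviceAddressBase` handling, not the XLA code):

```cpp
#include <utility>

// Sketch of the suggested check: swap the scale pointers together with the
// operands, but first require that both scales are null or both are non-null.
// A lone scale would otherwise end up attached to the wrong operand.
bool SwapKeepsScalesPaired(const void*& a, const void*& b,
                           const void*& a_scale, const void*& b_scale) {
  if ((a_scale == nullptr) != (b_scale == nullptr)) {
    return false;  // caller should treat this as an internal error
  }
  std::swap(a, b);
  std::swap(a_scale, b_scale);
  return true;
}
```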
```cpp
absl::Status HandleRaggedDot(HloInstruction* instr) override {
  if (!IsGpublasLtSupportedGroupedMatMul(*instr)) {
    return absl::OkStatus();
  }
  HloRaggedDotInstruction* ragged_dot =
      DynCast<HloRaggedDotInstruction>(instr);
  if (ragged_dot == nullptr) {
    return absl::OkStatus();
  }
  const auto& ragged_dims = ragged_dot->ragged_dot_dimension_numbers();
  const auto& dot_dims = ragged_dims.dot_dimension_numbers();
  if (ragged_dims.lhs_ragged_dimensions().size() != 1) {
    return absl::UnimplementedError("lhs_ragged_dimensions must have size 1");
  }
  int lhs_ragged_dim = ragged_dims.lhs_ragged_dimensions(0);

  auto isLhsRaggedDimInContractingDim = [](int lhs_ragged_dim,
                                           const DotDimensionNumbers& dnums) {
    return std::any_of(dnums.lhs_contracting_dimensions().begin(),
                       dnums.lhs_contracting_dimensions().end(),
                       [&](auto dim) { return dim == lhs_ragged_dim; });
  };

  auto isLhsRaggedDimInBatchDim = [](int lhs_ragged_dim,
                                     const DotDimensionNumbers& dnums) {
    return std::any_of(dnums.lhs_batch_dimensions().begin(),
                       dnums.lhs_batch_dimensions().end(),
                       [&](auto dim) { return dim == lhs_ragged_dim; });
  };

  if (!isLhsRaggedDimInContractingDim(lhs_ragged_dim, dot_dims) &&
      !isLhsRaggedDimInBatchDim(lhs_ragged_dim, dot_dims) &&
      ragged_dims.rhs_group_dimensions().size() != 1) {
    return absl::UnimplementedError(
        "rhs_group_dimensions must have size equal to 1 when lhs ragged "
        "dimension is a non-contracting dimension");
  }
  HloInstruction* grouped_gemm_call =
      instr->AddInstruction(HloInstruction::CreateCustomCall(
          ragged_dot->shape(), ragged_dot->mutable_operands(),
          gpu::kCublasLtGroupedMatmulCallTarget));

  // Create a GroupedGemmBackendConfig based on the instruction.
  TF_ASSIGN_OR_RETURN(
      gpu::GpuBackendConfig gpu_backend_config,
      grouped_gemm_call->backend_config<gpu::GpuBackendConfig>());
  GroupedGemmBackendConfig& grouped_gemm_backend_config =
      *gpu_backend_config.mutable_grouped_gemm_backend_config();
  RaggedDotDimensionNumbers& ragged_dot_dimension_numbers =
      *grouped_gemm_backend_config.mutable_ragged_dot_dimension_numbers();
  ragged_dot_dimension_numbers = ragged_dot->ragged_dot_dimension_numbers();

  // Create a GemmBackendConfig based on the instruction.
  GemmBackendConfig& gemm_backend_config =
      *grouped_gemm_backend_config.mutable_gemm_backend_config();
  gemm_backend_config.set_alpha_real(1.0);
  gemm_backend_config.set_alpha_imag(0.0);
  gemm_backend_config.set_beta(0.0);
  *gemm_backend_config.mutable_dot_dimension_numbers() = dot_dims;

  auto attributes = instr->frontend_attributes().map();
  gemm_backend_config.set_grad_x(attributes["grad_x"] == "true");
  gemm_backend_config.set_grad_y(attributes["grad_y"] == "true");

  TF_RETURN_IF_ERROR(
      grouped_gemm_call->set_backend_config(gpu_backend_config));

  TF_RETURN_IF_ERROR(ReplaceInstruction(instr, grouped_gemm_call));
  return absl::OkStatus();
}
```
bug: No thunk emitter for kCublasLtGroupedMatmulCallTarget
HandleRaggedDot rewrites ragged dots into a kCublasLtGroupedMatmulCallTarget custom call, but thunk_emitter.cc has no handler for this target — EmitCustomCallSwitch does not recognize it. Any ragged dot rewritten by this handler will fail at code generation with an "Unrecognized custom call target" error.
The workspace rewriter (GemmWorkspaceRewriteVisitor) does handle this target (line ~1820), so workspace allocation succeeds, but the resulting instruction cannot be emitted.
Is this intentionally gated behind a follow-up PR that adds the thunk emitter? If so, consider adding a comment or a TODO here noting the dependency.
```cpp
bool IsGpublasLtSupportedGroupedMatMul(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kRaggedDot) {
    switch (instr.shape().element_type()) {
      // Only float 16 and bf 16 are supported by HipBlasLt GroupGemm
      case F16:
      case BF16:
        return (((instr.operand(0)->shape().element_type() == F16) ||
                 (instr.operand(0)->shape().element_type() == BF16)) &&
                ((instr.operand(1)->shape().element_type() == F16) ||
                 (instr.operand(1)->shape().element_type() == BF16)));
      default:
        return false;
    }
  }
  return false;
}
```
bug: Missing platform guard — could trigger on CUDA
IsGpublasLtSupportedGroupedMatMul does not check whether the target platform is ROCm. If this returns true on CUDA, HandleRaggedDot in gemm_rewriter.cc will convert the ragged dot to a kCublasLtGroupedMatmulCallTarget custom call on CUDA, where there is no emitter for it, causing a compilation failure.
This needs a platform guard (e.g., checking GpuComputeCapability is RocmComputeCapability) or documentation that it is ROCm-only and the caller must gate it.
```cpp
// Only float 16 and bf 16 are supported by HipBlasLt GroupGemm
case F16:
case BF16:
  return (((instr.operand(0)->shape().element_type() == F16) ||
           (instr.operand(0)->shape().element_type() == BF16)) &&
          ((instr.operand(1)->shape().element_type() == F16) ||
           (instr.operand(1)->shape().element_type() == BF16)));
```
nit: Mixed input types not validated
The comment says "Only float 16 and bf 16 are supported" but the check allows mixed input types (e.g., lhs=F16, rhs=BF16). Does hipBLASLt grouped GEMM actually support mixed F16/BF16 inputs? If not, this should also verify that lhs and rhs have matching element types.
Re-review Summary

This is a follow-up review. 18 previous inline comments remain open (issues not yet addressed). 8 new findings have been posted inline, focused on the grouped GEMM (ragged dot) additions.

Key new findings:
- Missing thunk emitter for kCublasLtGroupedMatmulCallTarget -- ragged dots rewritten by HandleRaggedDot cannot be emitted, will fail at codegen
- Missing platform guard in IsGpublasLtSupportedGroupedMatMul -- could incorrectly trigger on CUDA
- Magic constant kUserArgsSizeBytes = 196 undocumented
- Style nits: std::any_of to absl::c_any_of, int to int64_t, typos in comments, oneof with single variant
- FP4 test tolerance question

Generated with Claude Code
```cpp
switch (op_desc_.scale_mode()) {
  case gpu::ScaleMode::kNone:
    break;
  case gpu::ScaleMode::kTensorScaling: {
    static int64_t dummy_pointer = 0xACEBALL;
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                               &dummy_pointer));
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                               &dummy_pointer));
    break;
  }
  case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
    static int64_t dummy_pointer = 0xACEBALL;
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                               &dummy_pointer));
    TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                               HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                               &dummy_pointer));
    hipblasLtMatmulMatrixScale_t mx_scale =
        HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
    TF_RETURN_IF_ERROR(SetAttr(
        op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, mx_scale));
    TF_RETURN_IF_ERROR(SetAttr(
        op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, mx_scale));
#else
    return absl::InternalError("Block scaling requires ROCm >= 7.0");
#endif
    break;
  }
}
```
The old code detected FP8 from the matrix layout data types (IsFP8(a_desc_) && IsFP8(b_desc_)), which worked regardless of whether scale_mode was set. The new code relies entirely on op_desc_.scale_mode() being set correctly.
For deserialized or cached HLO modules that predate this PR, the scale_mode proto field will default to 0 (kNone), even though the matmul is FP8 with tensor scaling. This means GetAlgorithms will skip setting the dummy scale pointers, and hipblasLtMatmulAlgoGetHeuristic may return zero algorithms -- breaking autotuning for existing FP8 matmuls.
Consider either:
- Falling back to the old `IsFP8` check when `scale_mode == kNone` (e.g., treat `kNone` + FP8 layouts as `kTensorScaling`), or
- Adding a migration step that sets `scale_mode` on deserialized F8 custom calls before autotuning.
```proto
// Scale mode: 0=none, 1=tensor_scaling (fp8), 2=block_scaling (MX).
int32 scale_mode = 21;
```
Using a raw int32 with a comment documenting the enum values (0=none, 1=tensor_scaling, 2=block_scaling) creates a tight coupling between this proto and the C++ ScaleMode enum that is not enforced at the proto schema level. If either side changes independently, the values will silently mismatch.
Consider defining a proper proto enum:
```proto
enum ScaleMode {
  SCALE_MODE_NONE = 0;
  SCALE_MODE_TENSOR_SCALING = 1;
  SCALE_MODE_BLOCK_SCALING = 2;
}

ScaleMode scale_mode = 21;
```

This would give you wire compatibility (the encoding is identical) while providing self-documenting proto definitions and validation.
```diff
     nullptr;
 }

 }  // namespace

 bool HipblasLtBackend::IsSupported(const HloInstruction& instr) {
-  return IsCublasLtMatmul(instr) || IsCublasLtMatmulF8(instr);
+  if (IsCublasLtMatmul(instr) || IsCublasLtMatmulF8(instr)) {
+    return true;
+  }
+  if (IsScaledDotFusion(instr)) {
+    const auto& gpu_cc =
+        target_config().device_description.gpu_compute_capability();
```
IsScaledDotFusion checks whether a fusion's backend_config has kind == kTritonGemmFusionKind and contains a kScaledDot instruction. However, this function is called from IsSupported(), which does not guard against CUDA platforms. IsSupported() does subsequently check for rocm_compute_capability(), but IsScaledDotFusion is also used directly in GetSupportedConfigs and ApplyConfig, where the platform check is not repeated.
If HipblasLtBackend is only ever instantiated on ROCm (which appears to be the case today), this is safe. But the function name does not make this platform restriction obvious. Consider either:
- Adding a comment documenting the platform assumption, or
- Adding a `DCHECK` that the device is ROCm inside `IsScaledDotFusion`.
```cpp
GpuBackendConfig gpu_backend_config;
GemmBackendConfig& gemm_config =
    *gpu_backend_config.mutable_gemm_backend_config();
*gemm_config.mutable_dot_dimension_numbers() =
    scaled_dot->dot_dimension_numbers();
gemm_config.set_alpha_real(1.0);
gemm_config.set_alpha_imag(0.0);
gemm_config.set_beta(0.0);
gemm_config.set_scale_mode(
    static_cast<int32_t>(se::gpu::ScaleMode::kBlockScaling));
gemm_config.set_selected_algorithm(gemm_key.algorithm());
gemm_config.set_autotune_workspace_size(workspace_size);
```
In ApplyConfig for the IsScaledDotFusion branch, the fusion instruction is replaced with a kCublasLtMatmulMx custom call + GTE. However, the precision_config from the original scaled_dot is not propagated to the new GemmBackendConfig. While the default PrecisionConfig (ALG_UNSET, DEFAULT precision) may be adequate here, silently dropping it could lead to subtle differences in compute behavior if the user's HLO specifies non-default precision settings.
Consider copying the precision config:
```cpp
*gemm_config.mutable_precision_config() = scaled_dot->precision_config();
```

```cpp
  return false;
}

int64_t batch_size = 1;
for (int64_t dim : dot_dims.lhs_batch_dimensions()) {
  batch_size *= lhs_shape.dimensions(dim);
}
if (batch_size != 1) {
  VLOG(2) << "hipBLASLt MX: batch_size > 1 not supported, got " << batch_size;
  return false;
```
IsValidMxScaledDot rejects batch_size > 1 but allows batch_size == 1 (with batch dimensions present). The test MxFp8BatchedCorrectness exercises this with shape [1,32,256] and lhs_batch_dims={0}.
If the intent is that hipBLASLt MX does not support batched operations at all, this check should reject any non-empty lhs_batch_dimensions(). If the intent is only to reject multi-batch cases, the current code is correct but the VLOG message "batch_size > 1 not supported" is misleading -- it should clarify that batch_size=1 (degenerate batch) is allowed.
Also, this only checks lhs_batch_dimensions. If LHS and RHS have different batch dimensions (which would be unusual but valid in the proto), the RHS batch size is not validated.
```cpp
DeviceAddressBase a_scale = args.a_scale, b_scale = args.b_scale;
if (must_swap_operands_) {
  std::swap(a, b);
  if (a_scale != nullptr && b_scale != nullptr) {
```
The scale pointers (a_scale, b_scale) are now swapped together with the operand pointers when must_swap_operands_ is true. This is correct for maintaining the logical correspondence between an operand and its scale.
However, the swap is guarded by a_scale != nullptr && b_scale != nullptr. If only one scale is non-null (which shouldn't happen for MX but could happen in edge cases), the swap would be silently skipped, leading to a scale being applied to the wrong operand. Consider using || instead of &&, or adding a TF_RET_CHECK that both are null or both are non-null.
```diff
 auto IsScaledType = [](xla::PrimitiveType dtype) {
   return xla::primitive_util::IsF8Type(dtype) || dtype == xla::F4E2M1FN;
 };
 if (IsScaledType(lhs_layout.dtype) &&
     lhs_layout.order == gpu::MatrixLayout::Order::kColumnMajor) {
-  return xla::Internal("The F8 LHS must be column-major");
+  return xla::Internal("The F8/MX LHS must be row-major");
 }
-if (xla::primitive_util::IsF8Type(rhs_layout.dtype) &&
+if (IsScaledType(rhs_layout.dtype) &&
     rhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) {
-  return xla::Internal("The F8 RHS must be row-major");
+  return xla::Internal("The F8/MX RHS must be column-major");
 }
```
The IsScaledType lambda includes F4E2M1FN alongside F8 types. The corresponding error messages were updated from "The F8 LHS must be column-major" to "The F8/MX LHS must be row-major" (and similarly for RHS). The corrected phrasing ("row-major" / "column-major") is good and fixes what appears to have been an incorrect error message in the original code.
However, note that this also changes the validation behavior for F4E2M1FN types: previously they would pass through without any layout check (since IsF8Type didn't match them). Now they are subject to the same row-major LHS / column-major RHS constraint. Make sure this is intentional and that hipBLASLt actually requires this layout for MX types.
```cpp
TEST_F(HipblasLtMxExecutionTest, MxFp4BatchedCorrectness) {
  RunMxCorrectnessTest(kMxFp4BatchedHlo,
                       ErrorSpec(/*aabs=*/1e-4, /*arel=*/1e-5));
```
The FP4 tests use the same error tolerances as FP8 tests (aabs=1e-4, arel=1e-5). FP4 (F4E2M1FN) has only 2 exponent bits and 1 mantissa bit, so it is significantly less precise than FP8. These tight tolerances may cause flaky failures depending on matrix values.
If the reference path decomposes the scaled dot into equivalent FP32 arithmetic, the comparison should still be close because both paths effectively compute the same low-precision operation. But if there are any differences in intermediate precision or rounding, FP4 would amplify them. Consider using a slightly wider tolerance for FP4 tests, or document why the tight tolerance is expected to hold.
```cpp
// A call to hipBLASLt for block-scaled matrix multiplication in MX formats.
inline constexpr absl::string_view kCublasLtMatmulMxCallTarget =
    "__cublas$lt$matmul$mx";
```
The naming convention kCublasLtMatmulMxCallTarget = "__cublas$lt$matmul$mx" uses the "cublas" prefix, but this target is specific to hipBLASLt on ROCm. This is consistent with the existing pattern where kCublasLtMatmulCallTarget and kCublasLtMatmulF8CallTarget are also used on ROCm (the naming convention follows the CUDA origin). The comment correctly clarifies "Calls into hipBLASLt." Just noting this for awareness -- the naming may cause confusion if CUDA ever gets native MX support via cuBLASLt.
```cpp
if (IsCublasLtMatmulMx(*hlo)) {
  return EmitCublasLtMatmulThunkMx(custom_call);
}
```
The new EmitCublasLtMatmulThunkMx is registered in EmitCustomCallSwitch, which is the correct integration point. However, IsCublasLtMatmulMx will match on CUDA platforms too, since it only checks the custom call target string. If a CUDA device somehow encounters a __cublas$lt$matmul$mx custom call (e.g., from a portable saved model), this will attempt to emit a CublasLtMatmulThunk on CUDA where ScaleMode::kBlockScaling is not handled in cuda_blas_lt.cc. This would likely fail at GetMatmulPlan time, but the error message would be confusing.
Consider adding a platform check or at minimum a clear error message for the CUDA case.
```cpp
#if TF_ROCM_VERSION >= 70000
// MX FP4 (F4E2M1FN) type combinations
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16F)
#endif

// Other data types:
TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF)
```
The MX TYPED_MATMUL entries cover F4E2M1FN crossed with {F4E2M1FN, F8E4M3, F8E5M2} and vice versa. But the IsValidMxScaledDot in hipblaslt.cc also accepts F8E4M3FN and F8E5M2 as valid MX input types (without F4). This means a scaled dot with (F8E4M3FN, F8E5M2) inputs and ScaleMode::kBlockScaling would pass validation, but there are no new TYPED_MATMUL entries for (HIP_R_8F_E4M3, HIP_R_8F_E5M2) with block scaling here.
The existing FP8-FP8 TYPED_MATMUL entries (e.g., HIP_R_8F_E4M3, HIP_R_8F_E5M2) from above should already handle this case. Verify that hipBLASLt correctly distinguishes tensor scaling vs block scaling for pure FP8 inputs based on the scale mode attributes set on the matmul descriptor, rather than on the data type alone. If it does, the existing entries are fine.
```cpp
for (int64_t dim : dot_dims.lhs_contracting_dimensions()) {
  lhs_scale_k *= lhs_scale_shape.dimensions(dim);
}
int64_t rhs_scale_k = 1;
for (int64_t dim : dot_dims.rhs_contracting_dimensions()) {
  rhs_scale_k *= rhs_scale_shape.dimensions(dim);
}
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
    k / rhs_scale_k != 32) {
  VLOG(2) << "hipBLASLt MX: block size must be 32, got lhs="
          << (lhs_scale_k > 0 ? k / lhs_scale_k : 0)
```
Division `k / lhs_scale_k` and `k / rhs_scale_k` are guarded by zero checks (`lhs_scale_k == 0`), which is good. However, this is integer division, so when k is not a multiple of lhs_scale_k the truncated quotient can still equal 32 (e.g., k=65 with lhs_scale_k=2 gives 65/2 = 32), and the check would falsely accept a fractional block size. Add an explicit `k % lhs_scale_k == 0` check (and likewise for rhs) so that only exact 32-element blocks pass.
```cpp
case gpu::ScaleMode::kTensorScaling: {
  static int64_t dummy_pointer = 0xACEBALL;
  TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                             HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                             &dummy_pointer));
  TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                             HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                             &dummy_pointer));
  break;
}
case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
  static int64_t dummy_pointer = 0xACEBALL;
  TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                             HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                             &dummy_pointer));
  TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                             HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                             &dummy_pointer));
```
Both the kTensorScaling and kBlockScaling branches declare static int64_t dummy_pointer = 0xACEBALL;. Because they are in separate case blocks with their own scopes, this is technically fine (two separate static variables). However, having two identical static variables with the same name at different points in the same function is confusing. Consider extracting the common dummy pointer setup into a helper, or moving the single static int64_t dummy_pointer declaration before the switch and sharing it across cases.
```proto
bool grad_x = 11;
bool grad_y = 12;
xla.BlasComputationTypeProto compute_type = 13;
int32 scale_mode = 14;
```
Adding scale_mode to GemmConfigProto means the autotuner cache key changes. If the autotuner cache is populated with entries that lack this field, they will deserialize with scale_mode = 0 (kNone). This is correct for non-scaled matmuls, but verify that:
- the cache version is still compatible (no version bump is needed if the new field is additive with a sensible default), and
- cache lookups include `scale_mode` in the key, so that a kNone entry is not reused for a kBlockScaling matmul.
```cpp
  }
  return configs;
} else if (IsScaledDotFusion(instr)) {
  const HloInstruction* scaled_dot = hlo_query::GetFirstInstructionWithOpcode(
      *instr.fused_instructions_computation(), HloOpcode::kScaledDot);
  TF_RET_CHECK(scaled_dot != nullptr);

  TF_ASSIGN_OR_RETURN(
      std::vector<BlasLt::MatmulAlgorithm> algorithms,
      plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms,
                          workspace_size));
  int num_algorithms = algorithms.size();
  std::vector<std::unique_ptr<BackendConfig>> configs;
  configs.reserve(num_algorithms);
  for (int i = 0; i < num_algorithms; ++i) {
    HipblasLtBackendConfig gemm_key;
    gemm_key.set_algorithm(i);
    gemm_key.set_autotune_workspace_size(workspace_size);
    auto any = std::make_unique<google::protobuf::Any>();
    any->PackFrom(gemm_key);
    configs.push_back(std::move(any));

  if (!IsValidMxScaledDot(scaled_dot)) {
    return std::vector<std::unique_ptr<BackendConfig>>();
  }

  const Shape& lhs_shape = scaled_dot->operand(0)->shape();
  const Shape& rhs_shape = scaled_dot->operand(1)->shape();
  const DotDimensionNumbers& dot_dims = scaled_dot->dot_dimension_numbers();
  const Shape& output_shape = scaled_dot->shape();

  auto gemm_config_or = GemmConfig::For(
      lhs_shape, dot_dims.lhs_batch_dimensions(),
      dot_dims.lhs_contracting_dimensions(), rhs_shape,
      dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(),
      output_shape,
      /*alpha_real=*/1.0, /*alpha_imag=*/0.0, /*beta=*/0.0,
      PrecisionConfig::ALG_UNSET, /*algorithm=*/std::nullopt,
      se::blas::kDefaultComputePrecision, /*grad_x=*/false,
      /*grad_y=*/false, /*scale_mode=*/se::gpu::ScaleMode::kBlockScaling,
      target_config().device_description.gpu_compute_capability());
```
The GetSupportedConfigs path for scaled dot fusions creates a GemmConfig using the shapes from the fused scaled_dot instruction, but the ApplyConfig path replaces the outer fusion instruction with a custom call using the outer instruction's operands. This works correctly because the outer fusion's operands are forwarded directly to the fused computation's parameters (operand 0 of the fusion is parameter 0 inside, etc.).
However, there is an implicit assumption that the fusion has no intermediate operations between its parameters and the scaled_dot (i.e., the fused computation is just the scaled_dot with pass-through parameters). If a future optimization adds a transpose or reshape inside the fusion before the scaled_dot, the shapes from scaled_dot->operand(i)->shape() (which are the parameter shapes) would still match the outer operand shapes, but the correspondence could break if intermediate ops change the shape.
This is fine for now but worth documenting as a constraint.
Re-review Summary (Claude)

Re-reviewed the latest diff. All previously raised findings remain applicable — no new issues found beyond existing inline comments. Key areas still worth attention are covered by the inline comments above.

No code changes detected since last review that would resolve these items.
```cpp
if (a_scale != nullptr && b_scale != nullptr) {
  std::swap(a_scale, b_scale);
```
The scale pointer swap condition a_scale != nullptr && b_scale != nullptr means the swap is silently skipped when only one scale is non-null. For block-scaled MX matmuls, both scales should always be present, but if a bug or edge case results in only one being set, the operands a and b would be swapped while the single scale stays with the wrong operand, producing silently incorrect results.
Consider either always swapping (unconditionally, after a/b swap) or logging/asserting that both must be non-null when must_swap_operands_ is true and at least one scale is present:
```diff
-    if (a_scale != nullptr && b_scale != nullptr) {
-      std::swap(a_scale, b_scale);
+    if (a_scale != nullptr || b_scale != nullptr) {
+      std::swap(a_scale, b_scale);
+    }
```
```proto
  int32 scale_mode = 21;
}
```
Using a raw int32 with magic values (0=none, 1=tensor_scaling, 2=block_scaling) sacrifices proto type safety and makes the wire format fragile. A proto enum would provide self-documenting field values, validation against unknown values, and proper pretty-printing in debug logs and HLO dumps.
Similarly, the GemmConfigProto in gpu_blas_lt.proto at field 14 uses the same raw int32 pattern. Both should ideally share a single proto enum definition.
If this is a deliberate temporary choice (e.g., to avoid cross-file proto dependencies), documenting that intent would help.
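As a hedged sketch of the enum alternative (names are illustrative, not from the PR), a single shared definition could replace both raw int32 fields:

```proto
// Hypothetical shared definition, importable by both gpu_blas_lt.proto and
// the backend config proto; values would mirror se::gpu::ScaleMode.
enum ScaleModeProto {
  SCALE_MODE_NONE = 0;            // default for legacy serialized entries
  SCALE_MODE_TENSOR_SCALING = 1;  // per-tensor scales (FP8 style)
  SCALE_MODE_BLOCK_SCALING = 2;   // MX block scaling (32-element blocks)
}
```

Keeping 0 as the "none" value preserves the deserialization behavior of old cache entries described above, since missing fields default to the zero value.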
```cpp
TEST_F(HipblasLtMxExecutionTest, MxFp4Correctness) {
  RunMxCorrectnessTest(kMxFp4Hlo, ErrorSpec(/*aabs=*/1e-4, /*arel=*/1e-5));
}
```
FP4 (F4E2M1FN) has only 4 bits of precision (~1.5 mantissa bits), so the tolerances aabs=1e-4, arel=1e-5 are identical to the FP8 tests. This seems overly tight for FP4 and may cause flaky failures depending on the input distribution. Typical FP4 relative errors are on the order of ~5-10%.
Consider using looser tolerances for the FP4 variants (e.g., aabs=0.5, arel=0.1) or documenting why these tight tolerances are appropriate (e.g., the reference also uses FP4 decomposition so the error cancels).
```cpp
    std::vector<BlasLt::MatmulAlgorithm> algorithms,
    plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms,
                        workspace_size));
int num_algorithms = algorithms.size();
```
int num_algorithms is a narrowing conversion from size_t. XLA convention is to use int64_t for sizes and indices. The scaled-dot path below (line ~295) correctly uses int64_t with static_cast. This line should be consistent:
```diff
-int num_algorithms = algorithms.size();
+int64_t num_algorithms = algorithms.size();
```
```cpp
TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
HloInstruction* gte = parent->AddInstruction(
    HloInstruction::CreateGetTupleElement(result_shape, custom_call, 0));
return parent->ReplaceInstruction(&instr, gte);
```
The ApplyConfig path for scaled dot fusions replaces the fusion instruction with a kCublasLtMatmulMxCallTarget custom call + GTE. This is correct per the ApplyConfig contract, but there is a subtle issue: the instr reference passed to ApplyConfig becomes invalid after ReplaceInstruction deletes it. If the autotuner framework accesses instr after ApplyConfig returns, this would be use-after-free.
Please verify that the autotuner framework does not access the instruction after ApplyConfig returns OkStatus(). The existing TritonBackend::ApplyConfig has the same pattern (it also replaces the instruction), so this may already be safe, but worth confirming.
```cpp
bool IsScaledDotFusion(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) return false;
  auto gpu_config = instr.backend_config<GpuBackendConfig>();
  if (!gpu_config.ok()) return false;
  if (gpu_config->fusion_backend_config().kind() != kTritonGemmFusionKind) {
    return false;
  }
  return hlo_query::GetFirstInstructionWithOpcode(
             *instr.fused_instructions_computation(), HloOpcode::kScaledDot) !=
         nullptr;
}
```
The IsScaledDotFusion helper function (defined above in anonymous namespace) checks for __triton_gemm fusions containing kScaledDot, but does not verify that the hardware is ROCm. The IsSupported method does add a ROCm guard, but IsScaledDotFusion is also called independently in GetSupportedConfigs and ApplyConfig without a platform guard. This is not currently harmful because all callers are gated by IsSupported, but it means the helper's name is somewhat misleading -- it identifies the HLO pattern rather than a truly supported fusion. A defensive check inside IsScaledDotFusion (or renaming it to HasScaledDotPattern) would clarify intent and prevent future misuse.
```cpp
      GetShapedSliceForHlo(instr, {instr->shape().tuple_shapes_size() - 1}));
}

ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue,
                 gpublas_lt::AsBlasLtEpilogue(epilogue));
Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(
    instr, ir_emitter_context_->GetNextThunkId());
std::string canonical_hlo = instr->ToString(
    HloPrintOptions::Fingerprint().set_print_backend_config(true));
auto thunk = std::make_unique<CublasLtMatmulThunk>(
    std::move(thunk_info), std::move(canonical_hlo), std::move(gemm_config),
    blas_lt_epilogue, algorithm, config.autotune_workspace_size(), a, b, c, d,
    /*bias=*/std::nullopt, /*aux=*/std::nullopt, a_scale, b_scale,
    /*c_scale=*/std::nullopt, /*d_scale=*/std::nullopt,
    /*d_amax=*/std::nullopt, workspace_buffer);
return GetThunkSequence(std::move(thunk));
}

absl::StatusOr<ThunkSequence> ThunkEmitter::EmitConvolutionReorderThunk(
    const HloCustomCallInstruction* instr) {
  bool has_bias = instr->operand_count() > 1;
```
In EmitCublasLtMatmulThunkMx, the DoMatmul path needs the HIPBLASLT_MATMUL_DESC_A_SCALE_MODE / HIPBLASLT_MATMUL_DESC_B_SCALE_MODE attributes set on op_desc_ for block scaling to work. These attributes are only set in GetAlgorithms, not in DoMatmul. This works because op_desc_ is a persistent member of MatmulPlan whose state carries over -- but this is fragile. If the plan is created fresh (e.g., during thunk execution when not autotuning), GetAlgorithms may not be called first, and the scale mode attributes would never be set, leading to incorrect computation.
Please verify that GetAlgorithms is always called before DoMatmul for MX plans, or consider also setting the scale mode attributes in DoMatmul (or in GetMatmulPlan / MatmulDesc::Create).
```cpp
case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
  static int64_t dummy_pointer = 0xACEBALL;
  TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                             HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                             &dummy_pointer));
  TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
                             HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                             &dummy_pointer));
  hipblasLtMatmulMatrixScale_t mx_scale =
      HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
  TF_RETURN_IF_ERROR(SetAttr(
      op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, mx_scale));
  TF_RETURN_IF_ERROR(SetAttr(
      op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, mx_scale));
#else
  return absl::InternalError("Block scaling requires ROCm >= 7.0");
#endif
  break;
}
```
The block scaling SCALE_MODE attributes (HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_DESC_B_SCALE_MODE) are only set here in GetAlgorithms, not in DoMatmul. While op_desc_ is a persistent member whose state carries over, this creates an implicit dependency: GetAlgorithms must be called before DoMatmul for the scale mode to be correctly configured. If a plan is constructed and DoMatmul is called without first calling GetAlgorithms (possible in non-autotuning paths), the scale mode would default to none, producing silently incorrect results.
Consider setting these attributes in MatmulDesc::Create (alongside the constructor) or in the plan construction path in GetMatmulPlan, so the scale mode is always configured regardless of call order.
```cpp
}
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
    k / rhs_scale_k != 32) {
  VLOG(2) << "hipBLASLt MX: block size must be 32, got lhs="
          << (lhs_scale_k > 0 ? k / lhs_scale_k : 0)
          << " rhs=" << (rhs_scale_k > 0 ? k / rhs_scale_k : 0);
  return false;
}

return true;
}

bool IsScaledDotFusion(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) return false;
```
IsValidMxScaledDot performs integer division k / lhs_scale_k and k / rhs_scale_k to check that the block size is 32. If k is not evenly divisible by lhs_scale_k, the truncated quotient can still equal 32: k=64 with lhs_scale_k=2 gives 64/2=32 and correctly passes, but k=65 with lhs_scale_k=2 also gives 65/2=32 and incorrectly passes despite a fractional block size. Adding a check k % lhs_scale_k == 0 (and likewise for rhs) before the division closes this hole and also makes the check more robust against malformed shapes.
Summary

This PR adds ROCm/hipBLASLt support for block-scaled matrix multiplication using MX (Microscaling) formats (FP8 and FP4) on gfx950+ hardware. It introduces a new `__cublas$lt$matmul$mx` custom call target for the block-scaled path.

Findings are listed in the inline review comments above.
Generated with Claude Code
Re-review summary (2026-03-13): Re-examined the diff — no code changes since last review (only a BUILD file dependency fix). All prior inline findings remain open and unaddressed. No new issues found.
No description provided.