Skip to content

Dev rocm scaled dot#655

Open
ScXfjiang wants to merge 11 commits intomainfrom
dev_rocm_scaled_dot
Open

Dev rocm scaled dot#655
ScXfjiang wants to merge 11 commits intomainfrom
dev_rocm_scaled_dot

Conversation

@ScXfjiang
Copy link

No description provided.

Comment on lines +168 to +170
rhs_scale_k *= rhs_scale_shape.dimensions(dim);
}
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bug: Integer division may falsely accept non-aligned block sizes

The block-size check k / lhs_scale_k != 32 uses integer division. If k is not perfectly divisible by lhs_scale_k, the integer quotient could still equal 32 even though the actual block size is fractional.

For example: k=65, lhs_scale_k=265/2 = 32 (integer), but the real block size is 32.5 elements.

Consider adding an explicit divisibility check:

Suggested change
rhs_scale_k *= rhs_scale_shape.dimensions(dim);
}
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
if (lhs_scale_k == 0 || k % lhs_scale_k != 0 || k / lhs_scale_k != 32 ||
rhs_scale_k == 0 || k % rhs_scale_k != 0 || k / rhs_scale_k != 32) {

Comment on lines +370 to +408
} else if (IsScaledDotFusion(instr)) {
const HloInstruction* scaled_dot = hlo_query::GetFirstInstructionWithOpcode(
*instr.fused_instructions_computation(), HloOpcode::kScaledDot);
TF_RET_CHECK(scaled_dot != nullptr);
HloComputation* parent = instr.parent();

TF_RET_CHECK(instr.operand_count() == 4);
HloInstruction* lhs = instr.mutable_operand(0);
HloInstruction* rhs = instr.mutable_operand(1);
HloInstruction* lhs_scale = instr.mutable_operand(2);
HloInstruction* rhs_scale = instr.mutable_operand(3);

const Shape& result_shape = scaled_dot->shape();
int64_t workspace_size = gemm_key.autotune_workspace_size();
Shape workspace_shape = ShapeUtil::MakeShape(S8, {workspace_size});
Shape output_shape =
ShapeUtil::MakeTupleShape({result_shape, workspace_shape});

GpuBackendConfig gpu_backend_config;
GemmBackendConfig& gemm_config =
*gpu_backend_config.mutable_gemm_backend_config();
*gemm_config.mutable_dot_dimension_numbers() =
scaled_dot->dot_dimension_numbers();
gemm_config.set_alpha_real(1.0);
gemm_config.set_alpha_imag(0.0);
gemm_config.set_beta(0.0);
gemm_config.set_scale_mode(
static_cast<int32_t>(se::gpu::ScaleMode::kBlockScaling));
gemm_config.set_selected_algorithm(gemm_key.algorithm());
gemm_config.set_autotune_workspace_size(workspace_size);

HloInstruction* custom_call =
parent->AddInstruction(HloInstruction::CreateCustomCall(
output_shape, {lhs, rhs, lhs_scale, rhs_scale},
kCublasLtMatmulMxCallTarget));
TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
HloInstruction* gte = parent->AddInstruction(
HloInstruction::CreateGetTupleElement(result_shape, custom_call, 0));
return parent->ReplaceInstruction(&instr, gte);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

concern: Graph transformation in ApplyConfig is unusual

ApplyConfig typically only sets backend config on the existing instruction, but here it replaces the fusion instruction with a new kCublasLtMatmulMxCallTarget custom call via parent->ReplaceInstruction(&instr, gte). After this call, &instr becomes a dangling reference.

Two concerns:

  1. Verify the autotuner framework handles instruction replacement during ApplyConfig (not just config changes).
  2. The TF_RET_CHECK(instr.operand_count() == 4) at line 376 will crash with a fatal error if the count is wrong, but IsScaledDotFusion doesn't verify operand count. Consider moving this check into IsValidMxScaledDot or IsScaledDotFusion so that GetSupportedConfigs returns empty instead of crashing in ApplyConfig.

Comment on lines 114 to +116
int64 autotune_workspace_size = 20;

// Scale mode: 0=none, 1=tensor_scaling (fp8), 2=block_scaling (MX).
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Consider using a protobuf enum for type safety

Using a raw int32 with static_cast<int32_t>(se::gpu::ScaleMode::kBlockScaling) throughout the codebase means that if the C++ enum values change, the proto wire values silently change. A proper protobuf enum type would provide type safety, self-documenting wire format, and better backward compatibility.

Not blocking, but worth considering for a follow-up.

Comment on lines 1415 to +1417

gemm_backend_config.set_scale_mode(
static_cast<int32_t>(se::gpu::ScaleMode::kTensorScaling));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: Behavioral change for existing FP8 matmuls

This sets scale_mode=kTensorScaling for existing FP8 matmuls that previously had scale_mode unset (defaulting to kNone). Downstream in hip_blas_lt.cc:GetAlgorithms, the code now uses a switch(scale_mode) instead of the old IsFP8() lambda check.

If any existing FP8 GemmBackendConfig protos are read from cache (rather than re-generated via the gemm rewriter), they would have scale_mode=0 (kNone) and the switch would skip setting scale pointers, causing algorithm lookup to fail.

Can you confirm existing FP8 matmul configs are always regenerated (not cached) so the new field is always present?

Comment on lines +418 to 420
DeviceAddressBase a_scale = args.a_scale, b_scale = args.b_scale;
if (must_swap_operands_) {
std::swap(a, b);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Inconsistency with #else branch (ROCm < 6.0)

The swapped a_scale/b_scale locals are correctly used in the #if TF_ROCM_VERSION >= 60000 branch, but the #else branch (~line 488 in the original file) still refers to args.a_scale and args.b_scale. While logically equivalent for the null check, this inconsistency makes the code harder to reason about. Consider updating the #else branch to use the swapped locals too.

Comment on lines +690 to +691
ASSIGN_OR_RETURN(ShapedSlice c, GetShapedSliceForHlo(instr, output_index));
ASSIGN_OR_RETURN(ShapedSlice d, GetShapedSliceForHlo(instr, output_index));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: c and d alias the same buffer

Both c and d are set to the same output slice. This works because beta=0 (so the c term in alpha*A*B + beta*C vanishes), but it's fragile — if a future change enables non-zero beta for MX matmul, the output would be aliased with the accumulation input. Consider adding a comment clarifying this is intentional because beta is always 0 for MX.

Comment on lines 620 to +627

#if TF_ROCM_VERSION >= 70000
// MX FP4 (F4E2M1FN) type combinations
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: OCP vs FNUZ FP8 type mapping

The TYPED_MATMUL entries use HIP_R_8F_E4M3 and HIP_R_8F_E5M2 (OCP/non-FNUZ variants) as mixed-type partners with HIP_R_4F_E2M1. The validator in hipblaslt.cc accepts F8E4M3FN and F8E5M2 as valid input types.

Older ROCm versions used FNUZ variants (HIP_R_8F_E4M3_FNUZ / HIP_R_8F_E5M2_FNUZ). Can you confirm that HIP_R_8F_E4M3 / HIP_R_8F_E5M2 correctly map to XLA's F8E4M3FN / F8E5M2 types on the target ROCm 7.0+ platform? The AsHipblasDataType function in hip_blas_utils.cc should be consistent with these entries.

Comment on lines +122 to +128
for (int64_t dim : dot_dims.lhs_batch_dimensions()) {
batch_size *= lhs_shape.dimensions(dim);
}
if (batch_size != 1) {
VLOG(2) << "hipBLASLt MX: batch_size > 1 not supported, got " << batch_size;
return false;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Batch restriction not documented

The batch_size != 1 check rejects batched MX operations, but it only computes batch_size from lhs batch dimensions. Consider adding a brief comment explaining this is a hipBLASLt limitation (or a TODO to revisit when hipBLASLt gains batch support for MX). The test kMxFp8BatchedHlo passes [1,32,256] with batch_size=1, which validates the path but doesn't exercise real batching.

Comment on lines +254 to 287
switch (op_desc_.scale_mode()) {
case gpu::ScaleMode::kNone:
break;
case gpu::ScaleMode::kTensorScaling: {
static int64_t dummy_pointer = 0xACEBALL;
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&dummy_pointer));
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&dummy_pointer));
break;
}
case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
static int64_t dummy_pointer = 0xACEBALL;
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&dummy_pointer));
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&dummy_pointer));
hipblasLtMatmulMatrixScale_t mx_scale =
HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
TF_RETURN_IF_ERROR(SetAttr(
op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, mx_scale));
TF_RETURN_IF_ERROR(SetAttr(
op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, mx_scale));
#else
return absl::InternalError("Block scaling requires ROCm >= 7.0");
#endif
break;
}
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code auto-detected FP8 inputs via IsFP8(a_desc_) && IsFP8(b_desc_) and unconditionally set dummy scale pointers. Now this is gated on scale_mode == kTensorScaling.

Any pre-existing serialized/cached FP8 matmul configs (e.g., from an autotune cache on disk) that were created before this PR will have scale_mode = 0 (kNone) in their proto, because the scale_mode field didn't exist and proto defaults to 0. When those cached configs are deserialized and used, GetAlgorithms() will skip setting the dummy a/b scale pointers, and algorithm enumeration may fail or return wrong algorithms.

Consider either:

  1. Treating FP8 + kNone the same as kTensorScaling for backward compatibility (i.e., fall through from kNone to kTensorScaling when the input types are FP8), or
  2. Bumping the autotune cache version to invalidate old caches.

#include "xla/service/compiler.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/ir_emission_utils.h"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ir_emission_utils.h is included here but the only symbol used from it is kTritonGemmFusionKind. Consider whether a more targeted include (or forward declaration) would be preferable to pulling in the full ir_emission_utils.h header, which is a large header with many dependencies.

Comment on lines +87 to +93
bool IsValidMxScaledDot(const HloInstruction* scaled_dot) {
const Shape& lhs_shape = scaled_dot->operand(0)->shape();
const Shape& rhs_shape = scaled_dot->operand(1)->shape();
const Shape& lhs_scale_shape = scaled_dot->operand(2)->shape();
const Shape& rhs_scale_shape = scaled_dot->operand(3)->shape();
const Shape& output_shape = scaled_dot->shape();
const DotDimensionNumbers& dot_dims = scaled_dot->dot_dimension_numbers();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IsValidMxScaledDot function accesses scaled_dot->operand(2) and scaled_dot->operand(3) without first verifying that scaled_dot->operand_count() >= 4. If a malformed kScaledDot instruction is encountered with fewer operands, this would be an out-of-bounds access. Consider adding a guard:

Suggested change
bool IsValidMxScaledDot(const HloInstruction* scaled_dot) {
const Shape& lhs_shape = scaled_dot->operand(0)->shape();
const Shape& rhs_shape = scaled_dot->operand(1)->shape();
const Shape& lhs_scale_shape = scaled_dot->operand(2)->shape();
const Shape& rhs_scale_shape = scaled_dot->operand(3)->shape();
const Shape& output_shape = scaled_dot->shape();
const DotDimensionNumbers& dot_dims = scaled_dot->dot_dimension_numbers();
bool IsValidMxScaledDot(const HloInstruction* scaled_dot) {
if (scaled_dot->operand_count() < 4) {
VLOG(2) << "hipBLASLt MX: scaled-dot must have 4 operands, got "
<< scaled_dot->operand_count();
return false;
}
const Shape& lhs_shape = scaled_dot->operand(0)->shape();
const Shape& rhs_shape = scaled_dot->operand(1)->shape();
const Shape& lhs_scale_shape = scaled_dot->operand(2)->shape();
const Shape& rhs_scale_shape = scaled_dot->operand(3)->shape();
const Shape& output_shape = scaled_dot->shape();
const DotDimensionNumbers& dot_dims = scaled_dot->dot_dimension_numbers();

Comment on lines +170 to +176
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
k / rhs_scale_k != 32) {
VLOG(2) << "hipBLASLt MX: block size must be 32, got lhs="
<< (lhs_scale_k > 0 ? k / lhs_scale_k : 0)
<< " rhs=" << (rhs_scale_k > 0 ? k / rhs_scale_k : 0);
return false;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Division by zero risk: if lhs_scale_k or rhs_scale_k is zero, k / lhs_scale_k would trigger undefined behavior before the != 32 check. The lhs_scale_k == 0 guard short-circuits via ||, so k / lhs_scale_k is only evaluated when lhs_scale_k != 0 -- this is correct. However, the same is not true for rhs_scale_k: the expression rhs_scale_k == 0 || k / rhs_scale_k != 32 is also correct due to short-circuit evaluation.

On second look, the logic is fine due to || short-circuiting. Disregard.

Comment on lines +401 to +408
HloInstruction* custom_call =
parent->AddInstruction(HloInstruction::CreateCustomCall(
output_shape, {lhs, rhs, lhs_scale, rhs_scale},
kCublasLtMatmulMxCallTarget));
TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
HloInstruction* gte = parent->AddInstruction(
HloInstruction::CreateGetTupleElement(result_shape, custom_call, 0));
return parent->ReplaceInstruction(&instr, gte);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ApplyConfig replaces the fusion instruction instr with a GetTupleElement(custom_call, 0). The original fusion instruction is removed from the computation. If instr was the root instruction of the entry computation, ReplaceInstruction changes the root to the GTE. This means the module's result layout should be updated, similar to how the IsCublasLtMatmul path updates the result layout (lines 360-365). The MX path does not do this.

Additionally, the dead fusion computation (the one that contained the kScaledDot) is not explicitly cleaned up. While HLO DCE may eventually remove it, this could leave orphaned computations in the module during autotuning.

Comment on lines +115 to +117

// Scale mode: 0=none, 1=tensor_scaling (fp8), 2=block_scaling (MX).
int32 scale_mode = 21;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scale_mode field is stored as int32 in the proto rather than a proper proto enum. This means there is no validation at the proto level -- any int32 value can be deserialized into scale_mode. The code then does static_cast<ScaleMode>(config.scale_mode()) in multiple places without bounds checking, which could create invalid enum values if the proto contains unexpected data (e.g., from a newer version).

Consider using a proto enum type here, or adding a validation check at deserialization points.

Comment on lines +119 to +121
// A call to hipBLASLt for block-scaled matrix multiplication in MX formats.
inline constexpr absl::string_view kCublasLtMatmulMxCallTarget =
"__cublas$lt$matmul$mx";
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming convention kCublasLtMatmulMxCallTarget = "__cublas$lt$matmul$mx" uses the "cublas" prefix, but the comment says "Calls into hipBLASLt." This is a ROCm-only feature that will never call into cuBLAS/cuBLASLt. Using the cublas prefix here could be confusing. Consider whether a different naming convention would be clearer (though I acknowledge this follows the existing pattern of kCublasLtMatmulF8CallTarget which is also used for hipBLASLt on ROCm).

Comment on lines +690 to +691
ASSIGN_OR_RETURN(ShapedSlice c, GetShapedSliceForHlo(instr, output_index));
ASSIGN_OR_RETURN(ShapedSlice d, GetShapedSliceForHlo(instr, output_index));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MX thunk emitter uses ASSIGN_OR_RETURN (without TF_ prefix) for ShapedSlice assignments, which is consistent with nearby code. However, the mix of TF_ASSIGN_OR_RETURN (lines 675-676) and ASSIGN_OR_RETURN (lines 683-691) within the same function is worth noting for consistency.

More importantly: lines 690-691 use the same output_index for both c and d, making c and d point to the same buffer. This is intentional for MX (where beta=0, so C is unused), but it's worth a comment explaining why c aliases d here, similar to how the F8 path (line 607) handles this with a comment about has_matrix_bias.

Comment on lines +181 to +191
bool IsScaledDotFusion(const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) return false;
auto gpu_config = instr.backend_config<GpuBackendConfig>();
if (!gpu_config.ok()) return false;
if (gpu_config->fusion_backend_config().kind() != kTritonGemmFusionKind) {
return false;
}
return hlo_query::GetFirstInstructionWithOpcode(
*instr.fused_instructions_computation(), HloOpcode::kScaledDot) !=
nullptr;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IsScaledDotFusion function calls instr.backend_config<GpuBackendConfig>() and silently returns false if it fails. This is fine for a predicate. However, it also uses .value() semantics implicitly by accessing gpu_config->fusion_backend_config() without checking .ok() -- actually it does check .ok() on line 89, so this is correct.

One concern: this function is called from IsSupported, GetSupportedConfigs, ApplyConfig, and GetDefaultConfig. Each time it re-parses the backend config proto. Consider caching or restructuring to avoid redundant proto parsing.

Comment on lines 419 to +423
if (must_swap_operands_) {
std::swap(a, b);
if (a_scale != nullptr && b_scale != nullptr) {
std::swap(a_scale, b_scale);
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When must_swap_operands_ is true, the scale pointers are also swapped. However, this swap only happens when BOTH a_scale and b_scale are non-null. If only one is non-null (which shouldn't happen for MX, but could for other configurations), the scale would NOT be swapped while the matrices ARE swapped, leading to incorrect scale application. Consider swapping unconditionally when must_swap_operands_ is true, or at least asserting that both are null or both are non-null when swapping is required.

Comment on lines +719 to +788
absl::Status HandleRaggedDot(HloInstruction* instr) override {
if (!IsGpublasLtSupportedGroupedMatMul(*instr)) {
return absl::OkStatus();
}
HloRaggedDotInstruction* ragged_dot =
DynCast<HloRaggedDotInstruction>(instr);
if (ragged_dot == nullptr) {
return absl::OkStatus();
}
const auto& ragged_dims = ragged_dot->ragged_dot_dimension_numbers();
const auto& dot_dims = ragged_dims.dot_dimension_numbers();
if (ragged_dims.lhs_ragged_dimensions().size() != 1) {
return absl::UnimplementedError("lhs_ragged_dimensions must have size 1");
}
int lhs_ragged_dim = ragged_dims.lhs_ragged_dimensions(0);

auto isLhsRaggedDimInContractingDim = [](int lhs_ragged_dim,
const DotDimensionNumbers& dnums) {
return std::any_of(dnums.lhs_contracting_dimensions().begin(),
dnums.lhs_contracting_dimensions().end(),
[&](auto dim) { return dim == lhs_ragged_dim; });
};

auto isLhsRaggedDimInBatchDim = [](int lhs_ragged_dim,
const DotDimensionNumbers& dnums) {
return std::any_of(dnums.lhs_batch_dimensions().begin(),
dnums.lhs_batch_dimensions().end(),
[&](auto dim) { return dim == lhs_ragged_dim; });
};

if (!isLhsRaggedDimInContractingDim(lhs_ragged_dim, dot_dims) &&
!isLhsRaggedDimInBatchDim(lhs_ragged_dim, dot_dims) &&
ragged_dims.rhs_group_dimensions().size() != 1) {
return absl::UnimplementedError(
"rhs_group_dimensions must have size equal to 1 when lhs ragged "
"dimension is a non-contracting dimension");
}
HloInstruction* grouped_gemm_call =
instr->AddInstruction(HloInstruction::CreateCustomCall(
ragged_dot->shape(), ragged_dot->mutable_operands(),
gpu::kCublasLtGroupedMatmulCallTarget));

// Create a GroupedGemmBackendConfig based on the instruction.
TF_ASSIGN_OR_RETURN(
gpu::GpuBackendConfig gpu_backend_config,
grouped_gemm_call->backend_config<gpu::GpuBackendConfig>());
GroupedGemmBackendConfig& grouped_gemm_backend_config =
*gpu_backend_config.mutable_grouped_gemm_backend_config();
RaggedDotDimensionNumbers& ragged_dot_dimension_numbers =
*grouped_gemm_backend_config.mutable_ragged_dot_dimension_numbers();
ragged_dot_dimension_numbers = ragged_dot->ragged_dot_dimension_numbers();

// Create a GemmBackendConfig based on the instruction.
GemmBackendConfig& gemm_backend_config =
*grouped_gemm_backend_config.mutable_gemm_backend_config();
gemm_backend_config.set_alpha_real(1.0);
gemm_backend_config.set_alpha_imag(0.0);
gemm_backend_config.set_beta(0.0);
*gemm_backend_config.mutable_dot_dimension_numbers() = dot_dims;

auto attributes = instr->frontend_attributes().map();
gemm_backend_config.set_grad_x(attributes["grad_x"] == "true");
gemm_backend_config.set_grad_y(attributes["grad_y"] == "true");

TF_RETURN_IF_ERROR(
grouped_gemm_call->set_backend_config(gpu_backend_config));

TF_RETURN_IF_ERROR(ReplaceInstruction(instr, grouped_gemm_call));
return absl::OkStatus();
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bug: No thunk emitter for kCublasLtGroupedMatmulCallTarget

HandleRaggedDot rewrites ragged dots into a kCublasLtGroupedMatmulCallTarget custom call, but thunk_emitter.cc has no handler for this target — EmitCustomCallSwitch does not recognize it. Any ragged dot rewritten by this handler will fail at code generation with an "Unrecognized custom call target" error.

The workspace rewriter (GemmWorkspaceRewriteVisitor) does handle this target (line ~1820), so workspace allocation succeeds, but the resulting instruction cannot be emitted.

Is this intentionally gated behind a follow-up PR that adds the thunk emitter? If so, consider adding a comment or a TODO here noting the dependency.

Comment on lines +77 to +92
bool IsGpublasLtSupportedGroupedMatMul(const HloInstruction& instr) {
if (instr.opcode() == HloOpcode::kRaggedDot) {
switch (instr.shape().element_type()) {
// Only float 16 and bf 16 are supported by HipBlasLt GroupGemm
case F16:
case BF16:
return (((instr.operand(0)->shape().element_type() == F16) ||
(instr.operand(0)->shape().element_type() == BF16)) &&
((instr.operand(1)->shape().element_type() == F16) ||
(instr.operand(1)->shape().element_type() == BF16)));
default:
return false;
}
}
return false;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bug: Missing platform guard — could trigger on CUDA

IsGpublasLtSupportedGroupedMatMul does not check whether the target platform is ROCm. If this returns true on CUDA, HandleRaggedDot in gemm_rewriter.cc will convert the ragged dot to a kCublasLtGroupedMatmulCallTarget custom call on CUDA, where there is no emitter for it, causing a compilation failure.

This needs a platform guard (e.g., checking GpuComputeCapability is RocmComputeCapability) or documentation that it is ROCm-only and the caller must gate it.

Comment on lines +80 to +86
// Only float 16 and bf 16 are supported by HipBlasLt GroupGemm
case F16:
case BF16:
return (((instr.operand(0)->shape().element_type() == F16) ||
(instr.operand(0)->shape().element_type() == BF16)) &&
((instr.operand(1)->shape().element_type() == F16) ||
(instr.operand(1)->shape().element_type() == BF16)));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Mixed input types not validated

The comment says "Only float 16 and bf 16 are supported" but the check allows mixed input types (e.g., lhs=F16, rhs=BF16). Does hipBLASLt grouped GEMM actually support mixed F16/BF16 inputs? If not, this should also verify that lhs and rhs have matching element types.

@claude
Copy link

claude bot commented Mar 11, 2026

Re-review Summary\n\nThis is a follow-up review. 18 previous inline comments remain open (issues not yet addressed). 8 new findings have been posted inline, focused on the grouped GEMM (ragged dot) additions.\n\nKey new findings:\n- Missing thunk emitter for kCublasLtGroupedMatmulCallTarget -- ragged dots rewritten by HandleRaggedDot cannot be emitted, will fail at codegen\n- Missing platform guard in IsGpublasLtSupportedGroupedMatMul -- could incorrectly trigger on CUDA\n- Magic constant kUserArgsSizeBytes = 196 undocumented\n- Style nits: std::any_of to absl::c_any_of, int to int64_t, typos in comments, oneof with single variant\n- FP4 test tolerance question\n\nGenerated with Claude Code

Comment on lines +254 to 287
switch (op_desc_.scale_mode()) {
case gpu::ScaleMode::kNone:
break;
case gpu::ScaleMode::kTensorScaling: {
static int64_t dummy_pointer = 0xACEBALL;
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&dummy_pointer));
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&dummy_pointer));
break;
}
case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
static int64_t dummy_pointer = 0xACEBALL;
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&dummy_pointer));
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&dummy_pointer));
hipblasLtMatmulMatrixScale_t mx_scale =
HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
TF_RETURN_IF_ERROR(SetAttr(
op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, mx_scale));
TF_RETURN_IF_ERROR(SetAttr(
op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, mx_scale));
#else
return absl::InternalError("Block scaling requires ROCm >= 7.0");
#endif
break;
}
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code detected FP8 from the matrix layout data types (IsFP8(a_desc_) && IsFP8(b_desc_)), which worked regardless of whether scale_mode was set. The new code relies entirely on op_desc_.scale_mode() being set correctly.

For deserialized or cached HLO modules that predate this PR, the scale_mode proto field will default to 0 (kNone), even though the matmul is FP8 with tensor scaling. This means GetAlgorithms will skip setting the dummy scale pointers, and hipblasLtMatmulAlgoGetHeuristic may return zero algorithms -- breaking autotuning for existing FP8 matmuls.

Consider either:

  1. Falling back to the old IsFP8 check when scale_mode == kNone (e.g., treat kNone + FP8 layouts as kTensorScaling), or
  2. Adding a migration step that sets scale_mode on deserialized F8 custom calls before autotuning.

Comment on lines +116 to +117
// Scale mode: 0=none, 1=tensor_scaling (fp8), 2=block_scaling (MX).
int32 scale_mode = 21;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a raw int32 with a comment documenting the enum values (0=none, 1=tensor_scaling, 2=block_scaling) creates a tight coupling between this proto and the C++ ScaleMode enum that is not enforced at the proto schema level. If either side changes independently, the values will silently mismatch.

Consider defining a proper proto enum:

enum ScaleMode {
  SCALE_MODE_NONE = 0;
  SCALE_MODE_TENSOR_SCALING = 1;
  SCALE_MODE_BLOCK_SCALING = 2;
}
ScaleMode scale_mode = 21;

This would give you wire compatibility (the encoding is identical) while providing self-documenting proto definitions and validation.

Comment on lines +190 to +201
nullptr;
}

} // namespace

bool HipblasLtBackend::IsSupported(const HloInstruction& instr) {
return IsCublasLtMatmul(instr) || IsCublasLtMatmulF8(instr);
if (IsCublasLtMatmul(instr) || IsCublasLtMatmulF8(instr)) {
return true;
}
if (IsScaledDotFusion(instr)) {
const auto& gpu_cc =
target_config().device_description.gpu_compute_capability();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsScaledDotFusion checks whether a fusion's backend_config has kind == kTritonGemmFusionKind and contains a kScaledDot instruction. However, this function is called from IsSupported(), which does not guard against CUDA platforms. IsSupported() does subsequently check for rocm_compute_capability(), but IsScaledDotFusion is also used directly in GetSupportedConfigs and ApplyConfig, where the platform check is not repeated.

If HipblasLtBackend is only ever instantiated on ROCm (which appears to be the case today), this is safe. But the function name does not make this platform restriction obvious. Consider either:

  • Adding a comment documenting the platform assumption, or
  • Adding a DCHECK that the device is ROCm inside IsScaledDotFusion.

Comment on lines +388 to +399
GpuBackendConfig gpu_backend_config;
GemmBackendConfig& gemm_config =
*gpu_backend_config.mutable_gemm_backend_config();
*gemm_config.mutable_dot_dimension_numbers() =
scaled_dot->dot_dimension_numbers();
gemm_config.set_alpha_real(1.0);
gemm_config.set_alpha_imag(0.0);
gemm_config.set_beta(0.0);
gemm_config.set_scale_mode(
static_cast<int32_t>(se::gpu::ScaleMode::kBlockScaling));
gemm_config.set_selected_algorithm(gemm_key.algorithm());
gemm_config.set_autotune_workspace_size(workspace_size);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In ApplyConfig for the IsScaledDotFusion branch, the fusion instruction is replaced with a kCublasLtMatmulMx custom call + GTE. However, the precision_config from the original scaled_dot is not propagated to the new GemmBackendConfig. While the default PrecisionConfig (ALG_UNSET, DEFAULT precision) may be adequate here, silently dropping it could lead to subtle differences in compute behavior if the user's HLO specifies non-default precision settings.

Consider copying the precision config:

*gemm_config.mutable_precision_config() = scaled_dot->precision_config();

Comment on lines +118 to +127
return false;
}

int64_t batch_size = 1;
for (int64_t dim : dot_dims.lhs_batch_dimensions()) {
batch_size *= lhs_shape.dimensions(dim);
}
if (batch_size != 1) {
VLOG(2) << "hipBLASLt MX: batch_size > 1 not supported, got " << batch_size;
return false;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsValidMxScaledDot rejects batch_size > 1 but allows batch_size == 1 (with batch dimensions present). The test MxFp8BatchedCorrectness exercises this with shape [1,32,256] and lhs_batch_dims={0}.

If the intent is that hipBLASLt MX does not support batched operations at all, this check should reject any non-empty lhs_batch_dimensions(). If the intent is only to reject multi-batch cases, the current code is correct but the VLOG message "batch_size > 1 not supported" is misleading -- it should clarify that batch_size=1 (degenerate batch) is allowed.

Also, this only checks lhs_batch_dimensions. If LHS and RHS have different batch dimensions (which would be unusual but valid in the proto), the RHS batch size is not validated.

Comment on lines +418 to +421
DeviceAddressBase a_scale = args.a_scale, b_scale = args.b_scale;
if (must_swap_operands_) {
std::swap(a, b);
if (a_scale != nullptr && b_scale != nullptr) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scale pointers (a_scale, b_scale) are now swapped together with the operand pointers when must_swap_operands_ is true. This is correct for maintaining the logical correspondence between an operand and its scale.

However, the swap is guarded by a_scale != nullptr && b_scale != nullptr. If only one scale is non-null (which shouldn't happen for MX but could happen in edge cases), the swap would be silently skipped, leading to a scale being applied to the wrong operand. Consider using || instead of &&, or adding a TF_RET_CHECK that both are null or both are non-null.

Comment on lines +334 to 344
auto IsScaledType = [](xla::PrimitiveType dtype) {
return xla::primitive_util::IsF8Type(dtype) || dtype == xla::F4E2M1FN;
};
if (IsScaledType(lhs_layout.dtype) &&
lhs_layout.order == gpu::MatrixLayout::Order::kColumnMajor) {
return xla::Internal("The F8 LHS must be column-major");
return xla::Internal("The F8/MX LHS must be row-major");
}
if (xla::primitive_util::IsF8Type(rhs_layout.dtype) &&
if (IsScaledType(rhs_layout.dtype) &&
rhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) {
return xla::Internal("The F8 RHS must be row-major");
return xla::Internal("The F8/MX RHS must be column-major");
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IsScaledType lambda includes F4E2M1FN alongside F8 types. The corresponding error messages were updated from "The F8 LHS must be column-major" to "The F8/MX LHS must be row-major" (and similarly for RHS). The corrected phrasing ("row-major" / "column-major") is good and fixes what appears to have been an incorrect error message in the original code.

However, note that this also changes the validation behavior for F4E2M1FN types: previously they would pass through without any layout check (since IsF8Type didn't match them). Now they are subject to the same row-major LHS / column-major RHS constraint. Make sure this is intentional and that hipBLASLt actually requires this layout for MX types.


TEST_F(HipblasLtMxExecutionTest, MxFp4BatchedCorrectness) {
RunMxCorrectnessTest(kMxFp4BatchedHlo,
ErrorSpec(/*aabs=*/1e-4, /*arel=*/1e-5));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The FP4 tests use the same error tolerances as FP8 tests (aabs=1e-4, arel=1e-5). FP4 (F4E2M1FN) has only 2 exponent bits and 1 mantissa bit, so it is significantly less precise than FP8. These tight tolerances may cause flaky failures depending on matrix values.

If the reference path decomposes the scaled dot into equivalent FP32 arithmetic, the comparison should still be close because both paths effectively compute the same low-precision operation. But if there are any differences in intermediate precision or rounding, FP4 would amplify them. Consider using a slightly wider tolerance for FP4 tests, or document why the tight tolerance is expected to hold.

Comment on lines +119 to +121
// A call to hipBLASLt for block-scaled matrix multiplication in MX formats.
inline constexpr absl::string_view kCublasLtMatmulMxCallTarget =
"__cublas$lt$matmul$mx";
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming convention kCublasLtMatmulMxCallTarget = "__cublas$lt$matmul$mx" uses the "cublas" prefix, but this target is specific to hipBLASLt on ROCm. This is consistent with the existing pattern where kCublasLtMatmulCallTarget and kCublasLtMatmulF8CallTarget are also used on ROCm (the naming convention follows the CUDA origin). The comment correctly clarifies "Calls into hipBLASLt." Just noting this for awareness -- the naming may cause confusion if CUDA ever gets native MX support via cuBLASLt.

Comment on lines +2572 to +2574
if (IsCublasLtMatmulMx(*hlo)) {
return EmitCublasLtMatmulThunkMx(custom_call);
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new EmitCublasLtMatmulThunkMx is registered in EmitCustomCallSwitch, which is the correct integration point. However, IsCublasLtMatmulMx will match on CUDA platforms too, since it only checks the custom call target string. If a CUDA device somehow encounters a __cublas$lt$matmul$mx custom call (e.g., from a portable saved model), this will attempt to emit a CublasLtMatmulThunk on CUDA where ScaleMode::kBlockScaling is not handled in cuda_blas_lt.cc. This would likely fail at GetMatmulPlan time, but the error message would be confusing.

Consider adding a platform check or at minimum a clear error message for the CUDA case.

Comment on lines +621 to 675
#if TF_ROCM_VERSION >= 70000
// MX FP4 (F4E2M1FN) type combinations
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E4M3, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_4F_E2M1, HIP_R_8F_E5M2, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E4M3, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16F)

TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_32F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16F, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16BF)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_32F)
TYPED_MATMUL(float, HIP_R_8F_E5M2, HIP_R_4F_E2M1, HIP_R_16BF, HIP_R_16F)
#endif

// Other data types:
TYPED_MATMUL(float, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF, HIP_R_16BF)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MX TYPED_MATMUL entries cover F4E2M1FN crossed with {F4E2M1FN, F8E4M3, F8E5M2} and vice versa. But the IsValidMxScaledDot in hipblaslt.cc also accepts F8E4M3FN and F8E5M2 as valid MX input types (without F4). This means a scaled dot with (F8E4M3FN, F8E5M2) inputs and ScaleMode::kBlockScaling would pass validation, but there are no new TYPED_MATMUL entries for (HIP_R_8F_E4M3, HIP_R_8F_E5M2) with block scaling here.

The existing FP8-FP8 TYPED_MATMUL entries (e.g., HIP_R_8F_E4M3, HIP_R_8F_E5M2) from above should already handle this case. Verify that hipBLASLt correctly distinguishes tensor scaling vs block scaling for pure FP8 inputs based on the scale mode attributes set on the matmul descriptor, rather than on the data type alone. If it does, the existing entries are fine.

Comment on lines +163 to +173
for (int64_t dim : dot_dims.lhs_contracting_dimensions()) {
lhs_scale_k *= lhs_scale_shape.dimensions(dim);
}
int64_t rhs_scale_k = 1;
for (int64_t dim : dot_dims.rhs_contracting_dimensions()) {
rhs_scale_k *= rhs_scale_shape.dimensions(dim);
}
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
k / rhs_scale_k != 32) {
VLOG(2) << "hipBLASLt MX: block size must be 32, got lhs="
<< (lhs_scale_k > 0 ? k / lhs_scale_k : 0)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Division k / lhs_scale_k and k / rhs_scale_k are guarded by zero checks (lhs_scale_k == 0), which is good. However, this does integer division, so if k is not a multiple of lhs_scale_k, the result won't be exactly 32 and the check will correctly fail. But consider also checking k % lhs_scale_k == 0 explicitly for a clearer error message, rather than relying on truncation behavior.

Comment on lines +257 to +275
case gpu::ScaleMode::kTensorScaling: {
static int64_t dummy_pointer = 0xACEBALL;
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&dummy_pointer));
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&dummy_pointer));
break;
}
case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
static int64_t dummy_pointer = 0xACEBALL;
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&dummy_pointer));
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&dummy_pointer));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both the kTensorScaling and kBlockScaling branches declare static int64_t dummy_pointer = 0xACEBALL;. Because they are in separate case blocks with their own scopes, this is technically fine (two separate static variables). However, having two identical static variables with the same name at different points in the same function is confusing. Consider extracting the common dummy pointer setup into a helper, or moving the single static int64_t dummy_pointer declaration before the switch and sharing it across cases.

bool grad_x = 11;
bool grad_y = 12;
xla.BlasComputationTypeProto compute_type = 13;
int32 scale_mode = 14;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding scale_mode to GemmConfigProto means the autotuner cache key changes. If the autotuner cache is populated with entries that lack this field, they will deserialize with scale_mode = 0 (kNone). This is correct for non-scaled matmuls, but verify that:

  1. The cache version is still compatible (no version bump needed if the new field is additive with a sensible default), and
  2. Cache lookups include scale_mode in the key so that a kNone entry is not reused for a kBlockScaling matmul.

Comment on lines +256 to +281
}
return configs;
} else if (IsScaledDotFusion(instr)) {
const HloInstruction* scaled_dot = hlo_query::GetFirstInstructionWithOpcode(
*instr.fused_instructions_computation(), HloOpcode::kScaledDot);
TF_RET_CHECK(scaled_dot != nullptr);

TF_ASSIGN_OR_RETURN(
std::vector<BlasLt::MatmulAlgorithm> algorithms,
plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms,
workspace_size));
int num_algorithms = algorithms.size();
std::vector<std::unique_ptr<BackendConfig>> configs;
configs.reserve(num_algorithms);
for (int i = 0; i < num_algorithms; ++i) {
HipblasLtBackendConfig gemm_key;
gemm_key.set_algorithm(i);
gemm_key.set_autotune_workspace_size(workspace_size);
auto any = std::make_unique<google::protobuf::Any>();
any->PackFrom(gemm_key);
configs.push_back(std::move(any));
if (!IsValidMxScaledDot(scaled_dot)) {
return std::vector<std::unique_ptr<BackendConfig>>();
}

const Shape& lhs_shape = scaled_dot->operand(0)->shape();
const Shape& rhs_shape = scaled_dot->operand(1)->shape();
const DotDimensionNumbers& dot_dims = scaled_dot->dot_dimension_numbers();
const Shape& output_shape = scaled_dot->shape();

auto gemm_config_or = GemmConfig::For(
lhs_shape, dot_dims.lhs_batch_dimensions(),
dot_dims.lhs_contracting_dimensions(), rhs_shape,
dot_dims.rhs_batch_dimensions(), dot_dims.rhs_contracting_dimensions(),
output_shape,
/*alpha_real=*/1.0, /*alpha_imag=*/0.0, /*beta=*/0.0,
PrecisionConfig::ALG_UNSET, /*algorithm=*/std::nullopt,
se::blas::kDefaultComputePrecision, /*grad_x=*/false,
/*grad_y=*/false, /*scale_mode=*/se::gpu::ScaleMode::kBlockScaling,
target_config().device_description.gpu_compute_capability());
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GetSupportedConfigs path for scaled dot fusions creates a GemmConfig using the shapes from the fused scaled_dot instruction, but the ApplyConfig path replaces the outer fusion instruction with a custom call using the outer instruction's operands. This works correctly because the outer fusion's operands are forwarded directly to the fused computation's parameters (operand 0 of the fusion is parameter 0 inside, etc.).

However, there is an implicit assumption that the fusion has no intermediate operations between its parameters and the scaled_dot (i.e., the fused computation is just the scaled_dot with pass-through parameters). If a future optimization adds a transpose or reshape inside the fusion before the scaled_dot, the shapes from scaled_dot->operand(i)->shape() (which are the parameter shapes) would still match the outer operand shapes, but the correspondence could break if intermediate ops change the shape.

This is fine for now but worth documenting as a constraint.

@claude
Copy link

claude bot commented Mar 13, 2026

Re-review Summary (Claude)

Re-reviewed the latest diff. All previously raised findings remain applicable — no new issues found beyond existing inline comments. Key areas still worth attention:

  • FP8 backward compatibility: Scale mode gating in GetAlgorithms may break legacy FP8 matmuls that lack scale_mode metadata
  • Proto type safety: scale_mode as raw int32 instead of proto enum
  • Platform guard gap: IsCublasLtMatmulMx and IsScaledDotFusion lack CUDA platform guards
  • FP4 test tolerances: Same tight tolerances as FP8 despite much lower precision
  • Scale pointer swap logic: && guard could silently skip swap when only one scale is non-null

No code changes detected since last review that would resolve these items.

Comment on lines +421 to +422
if (a_scale != nullptr && b_scale != nullptr) {
std::swap(a_scale, b_scale);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scale pointer swap condition a_scale != nullptr && b_scale != nullptr means the swap is silently skipped when only one scale is non-null. For block-scaled MX matmuls, both scales should always be present, but if a bug or edge case results in only one being set, the operands a and b would be swapped while the single scale stays with the wrong operand, producing silently incorrect results.

Consider either always swapping (unconditionally, after a/b swap) or logging/asserting that both must be non-null when must_swap_operands_ is true and at least one scale is present:

Suggested change
if (a_scale != nullptr && b_scale != nullptr) {
std::swap(a_scale, b_scale);
if (a_scale != nullptr || b_scale != nullptr) {
std::swap(a_scale, b_scale);
}

Comment on lines +117 to 118
int32 scale_mode = 21;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a raw int32 with magic values (0=none, 1=tensor_scaling, 2=block_scaling) sacrifices proto type safety and makes the wire format fragile. A proto enum would provide self-documenting field values, validation against unknown values, and proper pretty-printing in debug logs and HLO dumps.

Similarly, the GemmConfigProto in gpu_blas_lt.proto at field 14 uses the same raw int32 pattern. Both should ideally share a single proto enum definition.

If this is a deliberate temporary choice (e.g., to avoid cross-file proto dependencies), documenting that intent would help.


TEST_F(HipblasLtMxExecutionTest, MxFp4Correctness) {
RunMxCorrectnessTest(kMxFp4Hlo, ErrorSpec(/*aabs=*/1e-4, /*arel=*/1e-5));
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP4 (F4E2M1FN) has only 4 bits of precision (~1.5 mantissa bits), so the tolerances aabs=1e-4, arel=1e-5 are identical to the FP8 tests. This seems overly tight for FP4 and may cause flaky failures depending on the input distribution. Typical FP4 relative errors are on the order of ~5-10%.

Consider using looser tolerances for the FP4 variants (e.g., aabs=0.5, arel=0.1) or documenting why these tight tolerances are appropriate (e.g., the reference also uses FP4 decomposition so the error cancels).

std::vector<BlasLt::MatmulAlgorithm> algorithms,
plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms,
workspace_size));
int num_algorithms = algorithms.size();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int num_algorithms is a narrowing conversion from size_t. XLA convention is to use int64_t for sizes and indices. The scaled-dot path below (line ~295) correctly uses int64_t with static_cast. This line should be consistent:

Suggested change
int num_algorithms = algorithms.size();
int64_t num_algorithms = algorithms.size();

TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config));
HloInstruction* gte = parent->AddInstruction(
HloInstruction::CreateGetTupleElement(result_shape, custom_call, 0));
return parent->ReplaceInstruction(&instr, gte);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ApplyConfig path for scaled dot fusions replaces the fusion instruction with a kCublasLtMatmulMxCallTarget custom call + GTE. This is correct per the ApplyConfig contract, but there is a subtle issue: the instr reference passed to ApplyConfig becomes invalid after ReplaceInstruction deletes it. If the autotuner framework accesses instr after ApplyConfig returns, this would be use-after-free.

Please verify that the autotuner framework does not access the instruction after ApplyConfig returns OkStatus(). The existing TritonBackend::ApplyConfig has the same pattern (it also replaces the instruction), so this may already be safe, but worth confirming.

Comment on lines +181 to +191
bool IsScaledDotFusion(const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) return false;
auto gpu_config = instr.backend_config<GpuBackendConfig>();
if (!gpu_config.ok()) return false;
if (gpu_config->fusion_backend_config().kind() != kTritonGemmFusionKind) {
return false;
}
return hlo_query::GetFirstInstructionWithOpcode(
*instr.fused_instructions_computation(), HloOpcode::kScaledDot) !=
nullptr;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IsScaledDotFusion helper function (defined above in anonymous namespace) checks for __triton_gemm fusions containing kScaledDot, but does not verify that the hardware is ROCm. The IsSupported method does add a ROCm guard, but IsScaledDotFusion is also called independently in GetSupportedConfigs and ApplyConfig without a platform guard. This is not currently harmful because all callers are gated by IsSupported, but it means the helper's name is somewhat misleading -- it identifies the HLO pattern rather than a truly supported fusion. A defensive check inside IsScaledDotFusion (or renaming it to HasScaledDotPattern) would clarify intent and prevent future misuse.

Comment on lines +696 to 716
GetShapedSliceForHlo(instr, {instr->shape().tuple_shapes_size() - 1}));
}

ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue,
gpublas_lt::AsBlasLtEpilogue(epilogue));
Thunk::ThunkInfo thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(
instr, ir_emitter_context_->GetNextThunkId());
std::string canonical_hlo = instr->ToString(
HloPrintOptions::Fingerprint().set_print_backend_config(true));
auto thunk = std::make_unique<CublasLtMatmulThunk>(
std::move(thunk_info), std::move(canonical_hlo), std::move(gemm_config),
blas_lt_epilogue, algorithm, config.autotune_workspace_size(), a, b, c, d,
/*bias=*/std::nullopt, /*aux=*/std::nullopt, a_scale, b_scale,
/*c_scale=*/std::nullopt, /*d_scale=*/std::nullopt,
/*d_amax=*/std::nullopt, workspace_buffer);
return GetThunkSequence(std::move(thunk));
}

absl::StatusOr<ThunkSequence> ThunkEmitter::EmitConvolutionReorderThunk(
const HloCustomCallInstruction* instr) {
bool has_bias = instr->operand_count() > 1;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In EmitCublasLtMatmulThunkMx, the DoMatmul path needs the HIPBLASLT_MATMUL_DESC_A_SCALE_MODE / HIPBLASLT_MATMUL_DESC_B_SCALE_MODE attributes set on op_desc_ for block scaling to work. These attributes are only set in GetAlgorithms, not in DoMatmul. This works because op_desc_ is a persistent member of MatmulPlan whose state carries over -- but this is fragile. If the plan is created fresh (e.g., during thunk execution when not autotuning), GetAlgorithms may not be called first, and the scale mode attributes would never be set, leading to incorrect computation.

Please verify that GetAlgorithms is always called before DoMatmul for MX plans, or consider also setting the scale mode attributes in DoMatmul (or in GetMatmulPlan / MatmulDesc::Create).

Comment on lines +267 to +286
case gpu::ScaleMode::kBlockScaling: {
#if TF_ROCM_VERSION >= 70000
static int64_t dummy_pointer = 0xACEBALL;
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&dummy_pointer));
TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(),
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&dummy_pointer));
hipblasLtMatmulMatrixScale_t mx_scale =
HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
TF_RETURN_IF_ERROR(SetAttr(
op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, mx_scale));
TF_RETURN_IF_ERROR(SetAttr(
op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, mx_scale));
#else
return absl::InternalError("Block scaling requires ROCm >= 7.0");
#endif
break;
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The block scaling SCALE_MODE attributes (HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_DESC_B_SCALE_MODE) are only set here in GetAlgorithms, not in DoMatmul. While op_desc_ is a persistent member whose state carries over, this creates an implicit dependency: GetAlgorithms must be called before DoMatmul for the scale mode to be correctly configured. If a plan is constructed and DoMatmul is called without first calling GetAlgorithms (possible in non-autotuning paths), the scale mode would default to none, producing silently incorrect results.

Consider setting these attributes in MatmulDesc::Create (alongside the constructor) or in the plan construction path in GetMatmulPlan, so the scale mode is always configured regardless of call order.

Comment on lines +169 to +182
}
if (lhs_scale_k == 0 || k / lhs_scale_k != 32 || rhs_scale_k == 0 ||
k / rhs_scale_k != 32) {
VLOG(2) << "hipBLASLt MX: block size must be 32, got lhs="
<< (lhs_scale_k > 0 ? k / lhs_scale_k : 0)
<< " rhs=" << (rhs_scale_k > 0 ? k / rhs_scale_k : 0);
return false;
}

return true;
}

bool IsScaledDotFusion(const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) return false;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsValidMxScaledDot performs integer division k / lhs_scale_k and k / rhs_scale_k to check that the block size is 32. If k is not evenly divisible by lhs_scale_k, this division truncates and could incorrectly match (or miss). For example, if k=63 and lhs_scale_k=2, then 63/2=31 which would fail, but k=64 and lhs_scale_k=2 gives 64/2=32 which passes. It works for the current "well-formed" inputs, but adding a check k % lhs_scale_k == 0 before the division would make it more robust against malformed shapes.

@claude
Copy link

claude bot commented Mar 13, 2026

Summary

This PR adds ROCm/hipBLASLt support for block-scaled matrix multiplication using MX (Microscaling) formats (FP8 and FP4) on gfx950+ hardware. It introduces a new __cublas$lt$matmul$mx custom call target, a ScaleMode enum for distinguishing none/tensor/block scaling, and routes __triton_gemm fusions containing kScaledDot through the hipBLASLt autotuner backend instead of the Triton backend on ROCm. The PR also adds end-to-end execution tests and unit tests for the new path.

Findings

  • [xla/stream_executor/rocm/hip_blas_lt.cc:421-422] Scale pointer swap guard too restrictive: The && condition a_scale != nullptr && b_scale != nullptr silently skips the swap when only one scale is non-null. If must_swap_operands_ is true and only one scale exists, operands get swapped but the scale does not, producing silently wrong results. Should use || instead of &&.

  • [xla/stream_executor/rocm/hip_blas_lt.cc:267-286] Scale mode attributes only set in GetAlgorithms, not DoMatmul: The HIPBLASLT_MATMUL_DESC_A_SCALE_MODE / HIPBLASLT_MATMUL_DESC_B_SCALE_MODE attributes are set on op_desc_ during GetAlgorithms only. If DoMatmul is ever called without a prior GetAlgorithms call (possible in non-autotuning paths), block scaling would silently not be configured. Consider setting these in MatmulDesc::Create instead.

  • [xla/service/gpu/backend_configs.proto:117-118] Raw int32 instead of proto enum for scale_mode: Both GemmBackendConfig.scale_mode and GemmConfigProto.scale_mode use int32 with magic values documented in comments. A proto enum would provide type safety, self-documentation, and proper debug printing.

  • [xla/backends/gpu/autotuner/hipblaslt_mx_execution_test.cc:152] FP4 test tolerances match FP8: All FP4 tests use ErrorSpec(1e-4, 1e-5) which is identical to the FP8 tests despite FP4 having dramatically lower precision (~1.5 mantissa bits vs ~3-4 for FP8). These may pass now but could be flaky.

  • [xla/backends/gpu/autotuner/hipblaslt.cc:246] int narrowing from size_t: int num_algorithms = algorithms.size() narrows from size_t to int. XLA convention uses int64_t. The scaled-dot path correctly uses int64_t with explicit cast.

  • [xla/backends/gpu/autotuner/hipblaslt.cc:169-182] Integer division without remainder check: k / lhs_scale_k and k / rhs_scale_k truncate silently if k is not evenly divisible by the scale dimensions. Adding k % lhs_scale_k == 0 guards would make the validation more robust.

  • [xla/backends/gpu/autotuner/hipblaslt.cc:181-191] IsScaledDotFusion lacks platform guard: The helper identifies an HLO pattern but does not verify ROCm hardware. While callers are currently gated by IsSupported, the name suggests broader applicability. Consider renaming to HasScaledDotPattern or adding a defensive platform check.

  • [xla/backends/gpu/autotuner/hipblaslt.cc:408] Use-after-free risk in ApplyConfig: ReplaceInstruction deletes instr, which was passed by reference. If the autotuner framework accesses instr after ApplyConfig returns, this is UB. The same pattern exists in TritonBackend::ApplyConfig, so this is likely safe by convention, but worth confirming.

Generated with Claude Code

@claude
Copy link

claude bot commented Mar 13, 2026

Re-review summary (2026-03-13): Re-examined the diff — no code changes since last review (only a BUILD file dependency fix). All prior inline findings remain open and unaddressed. No new issues found.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant