
Fix issue with missing upcast/downcast for bf16 libdevice calls.#661

Open
zoranjovanovic-ns wants to merge 1 commit into rocm-jaxlib-v0.8.2 from rocm-jaxlib-v0.8.2-math_to_libdevice

Conversation

@zoranjovanovic-ns

Motivation

Adds the missing upcast/downcast for the bf16 type in libdevice calls.

Technical Details

The upcast/downcast is necessary because libdevice has no native bf16 implementation.
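The upcast/call/downcast pattern described above can be sketched in plain C++. This is only an illustration, not the code in this PR: bf16 is modeled as the upper 16 bits of an IEEE-754 float, and the names `bf16_bits`, `UpcastBf16`, `DowncastToBf16`, and `CallViaF32` are hypothetical helpers introduced for the example. The rounding mode used for the downcast (round-to-nearest-even) is also an assumption.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for a bf16 value: the upper 16 bits of a float.
using bf16_bits = std::uint16_t;

// Upcast bf16 -> f32. This is exact: the 16 low mantissa bits are zero.
inline float UpcastBf16(bf16_bits v) {
  std::uint32_t bits = static_cast<std::uint32_t>(v) << 16;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Downcast f32 -> bf16 with round-to-nearest-even (assumed rounding mode).
inline bf16_bits DowncastToBf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if (std::isnan(f)) {
    // Keep the value a NaN by forcing the quiet bit after truncation.
    return static_cast<bf16_bits>((bits >> 16) | 0x0040);
  }
  std::uint32_t rounding_bias = 0x7FFF + ((bits >> 16) & 1);
  return static_cast<bf16_bits>((bits + rounding_bias) >> 16);
}

// Wrap a math routine that only exists for f32 so it can serve bf16 input:
// upcast the operand, call the f32 routine, downcast the result.
template <typename F32Fn>
bf16_bits CallViaF32(F32Fn fn, bf16_bits x) {
  return DowncastToBf16(fn(UpcastBf16(x)));
}

// Small self-check: 1.0 is exactly representable in bf16 (bits 0x3F80).
inline void SelfCheck() {
  assert(UpcastBf16(0x3F80) == 1.0f);
  assert(DowncastToBf16(1.0f) == 0x3F80);
}
```

For example, `CallViaF32([](float v) { return std::sqrt(v); }, 0x4080)` takes the bf16 encoding of 4.0, evaluates the square root in f32, and returns the bf16 encoding of 2.0.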

Test Plan

triton_xla_math_to_libdevice.mlir

Test Result

Test passes.

Submission Checklist

@zoranjovanovic-ns added the cherry-pick-candidate (Mark a PR to be cherry-picked into the next ROCm JAX. Remove iff the latest upstream contains the PR.), Upstream, and rocm-jaxlib-v0.8.2 labels on Mar 9, 2026
Comment on lines 216 to +219
  if (res.getType() != output_type ||
-     (output_type.isF16() &&
-      !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
+     (output_type.isBF16() ||
+      (output_type.isF16() &&
+       !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {

nit: fragile downcast condition — The output_type.isBF16() disjunction here can be true even when res.getType() == output_type. In practice this doesn't fire because bf16 always enters the upcast path above, so res will have type f32 and res.getType() != output_type is already sufficient. But the condition is fragile: if someone later changes the upcast logic without updating this downcast logic, the isBF16() arm could trigger a spurious bf16→bf16 cast.

Consider simplifying the downcast condition to just res.getType() != output_type, which is correct in all cases and doesn't duplicate the logic of the upcast block.
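The reviewer's point can be checked with a small standalone model. This is a sketch under stated assumptions rather than the actual XLA pass: the `Ty` enum and the functions `NeedsUpcast`, `ResultType`, `DowncastAsWritten`, and `DowncastSimplified` are hypothetical, and the model assumes the libdevice call returns f32 exactly when the upcast path was taken and the output type otherwise.

```cpp
#include <cassert>
#include <initializer_list>

// Hypothetical model of the element types involved.
enum class Ty { kF16, kBF16, kF32 };

// Assumed upcast rule: bf16 always, f16 only without a native f16 routine.
bool NeedsUpcast(Ty t, bool has_f16_impl) {
  return t == Ty::kBF16 || (t == Ty::kF16 && !has_f16_impl);
}

// Assumed type of the call result: f32 if upcast happened, else unchanged.
Ty ResultType(Ty output_type, bool has_f16_impl) {
  return NeedsUpcast(output_type, has_f16_impl) ? Ty::kF32 : output_type;
}

// Downcast condition as written in the PR.
bool DowncastAsWritten(Ty res, Ty out, bool has_f16_impl) {
  return res != out ||
         (out == Ty::kBF16 || (out == Ty::kF16 && !has_f16_impl));
}

// Downcast condition as suggested in the review.
bool DowncastSimplified(Ty res, Ty out) { return res != out; }

// Enumerate every output type and f16-support combination and compare the
// two conditions on the result type produced by the assumed upcast rule.
bool ConditionsAgree() {
  for (Ty out : {Ty::kF16, Ty::kBF16, Ty::kF32}) {
    for (bool has_f16_impl : {false, true}) {
      Ty res = ResultType(out, has_f16_impl);
      if (DowncastAsWritten(res, out, has_f16_impl) !=
          DowncastSimplified(res, out)) {
        return false;
      }
    }
  }
  return true;
}
```

Under these assumptions the two conditions agree on every combination, so the extra arms are redundant. They differ only if the upcast invariant is broken: `DowncastAsWritten(Ty::kBF16, Ty::kBF16, true)` is true while `DowncastSimplified(Ty::kBF16, Ty::kBF16)` is false, which is the spurious bf16→bf16 cast the comment warns about.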

@claude

claude bot commented Mar 9, 2026

Claude Code Review for PR 661: This PR correctly adds bf16 upcast/downcast handling alongside the existing f16 path for libdevice calls. The core logic is sound. One inline comment posted on the downcast condition (lines 216-219) suggesting simplification. Additionally, since no test file changes are included and this targets ROCm, consider adding AMDGCN bf16 test cases to triton_xla_math_to_libdevice.mlir for stronger coverage.

Collaborator

@i-chaochen i-chaochen left a comment


Thanks! Don't we run this UT (//xla/backends/gpu/codegen/triton/transforms/tests:triton_xla_math_to_libdevice.mlir.test) on our CI?
