Fix issue with missing upcast/downcast for bf16 libdevice calls. #661
zoranjovanovic-ns wants to merge 1 commit into rocm-jaxlib-v0.8.2
  if (res.getType() != output_type ||
-     (output_type.isF16() &&
-      !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
+     (output_type.isBF16() ||
+      (output_type.isF16() &&
+       !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {
nit: fragile downcast condition — The output_type.isBF16() disjunction here can be true even when res.getType() == output_type. In practice this doesn't fire because bf16 always enters the upcast path above, so res will have type f32 and res.getType() != output_type is already sufficient. But the condition is fragile: if someone later changes the upcast logic without updating this downcast logic, the isBF16() arm could trigger a spurious bf16→bf16 cast.
Consider simplifying the downcast condition to just res.getType() != output_type, which is correct in all cases and doesn't duplicate the logic of the upcast block.
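The difference between the two conditions can be modelled outside MLIR. The sketch below is only an illustration: `ElemType`, `NeedsDowncastAsWritten`, and `NeedsDowncastSimplified` are hypothetical names, and plain enums stand in for the real `mlir::Type` checks. It shows the one case the review flags, where the result is already bf16 and the condition as written still requests a cast.

```cpp
// Hypothetical stand-in for the MLIR element types involved.
enum class ElemType { kF16, kBF16, kF32 };

// Downcast condition as it appears in the diff, modelled on enums.
// has_f16_impl stands in for HasF16Implementation(...).
constexpr bool NeedsDowncastAsWritten(ElemType res, ElemType out,
                                      bool has_f16_impl) {
  return res != out ||
         (out == ElemType::kBF16 ||
          (out == ElemType::kF16 && !has_f16_impl));
}

// Simplified condition suggested in the review.
constexpr bool NeedsDowncastSimplified(ElemType res, ElemType out) {
  return res != out;
}

// Normal path: bf16 was upcast, so the result is f32 and both agree.
static_assert(NeedsDowncastAsWritten(ElemType::kF32, ElemType::kBF16, false));
static_assert(NeedsDowncastSimplified(ElemType::kF32, ElemType::kBF16));

// Fragile path: the result is already bf16, yet the written condition
// still asks for a (bf16 to bf16) cast. The simplified one does not.
static_assert(NeedsDowncastAsWritten(ElemType::kBF16, ElemType::kBF16, false));
static_assert(!NeedsDowncastSimplified(ElemType::kBF16, ElemType::kBF16));
```

The second pair of assertions is the hypothetical scenario from the comment: it cannot occur while the upcast block always converts bf16 to f32, but the written condition does not enforce that on its own.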
Claude Code Review for PR 661: This PR correctly adds bf16 upcast/downcast handling alongside the existing f16 path for libdevice calls. The core logic is sound. One inline comment posted on the downcast condition (lines 216-219) suggesting simplification. Additionally, since no test file changes are included and this targets ROCm, consider adding AMDGCN bf16 test cases to triton_xla_math_to_libdevice.mlir for stronger coverage.
i-chaochen left a comment
Thanks! Do we not run this UT (//xla/backends/gpu/codegen/triton/transforms/tests:triton_xla_math_to_libdevice.mlir.test) on our CI?
Motivation
Adds the missing upcast/downcast for the bf16 type in libdevice calls.
Technical Details
The upcast and downcast are necessary because libdevice has no native bf16 implementation: bf16 operands are upcast to f32 before the call, and the f32 result is downcast back to bf16.
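As a rough host-side illustration of the upcast, call, downcast pattern, the sketch below emulates it with bit manipulation. It is not the code in this PR, which inserts casts on MLIR values inside the compiler pass. The `BF16` struct, `SqrtF32` (standing in for an f32-only libdevice routine), `CallViaF32`, and the round-to-nearest-even downcast are all assumptions made for the example.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical bf16 storage: the upper 16 bits of an IEEE-754 binary32.
struct BF16 {
  uint16_t bits;
};

// Upcast is exact: the 16 stored bits become the high half of a float.
inline float UpcastToF32(BF16 v) {
  uint32_t u = static_cast<uint32_t>(v.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Downcast with round-to-nearest-even (an assumed rounding mode).
inline BF16 DowncastToBF16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  if (std::isnan(f)) {
    // Keep the value a NaN by forcing a mantissa bit that survives truncation.
    return BF16{static_cast<uint16_t>((u >> 16) | 0x0040u)};
  }
  uint32_t rounding_bias = 0x7FFFu + ((u >> 16) & 1u);
  return BF16{static_cast<uint16_t>((u + rounding_bias) >> 16)};
}

// Stand-in for a math routine that only exists for f32.
inline float SqrtF32(float v) { return std::sqrt(v); }

// Apply an f32-only routine to a bf16 value: upcast, call, downcast.
template <typename F>
BF16 CallViaF32(F f32_fn, BF16 x) {
  return DowncastToBF16(f32_fn(UpcastToF32(x)));
}
```

For example, `CallViaF32(SqrtF32, BF16{0x4080})` takes the bf16 encoding of 4.0, computes the square root in f32, and returns the bf16 encoding of 2.0 (`0x4000`).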
Test Plan
triton_xla_math_to_libdevice.mlir
Test Result
Tests pass.
Submission Checklist