[Q][GPU][BF16] torch.mul is lowered to HLO as an f32 multiply #8545
Comments
I tried the same thing using autocast, and it seems to work as you expect: with autocast enabled, the multiply in the resulting HLO stays in bf16. Below is the code to replicate. Can you try replicating and confirm?
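The commenter's exact script isn't shown above; the following is a minimal sketch of that replication, assuming `torch.autocast` with device_type `"xla"` (as in the torch_xla AMP docs) and the internal `torch_xla._XLAC._get_xla_tensors_hlo` helper to dump the lowered HLO:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Same setup as the original report: both operands are bf16.
a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
b = torch.randn(4, 4, dtype=torch.bfloat16, device=device)

# Wrap the op in autocast targeting the XLA device.
with torch.autocast("xla", dtype=torch.bfloat16):
    c = torch.mul(a, b)

# Dump the lowered HLO and check the dtype of the multiply.
print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
```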
This is actually expected behavior. In fact, PyTorch CUDA also does the same thing. In summary, PyTorch converts each operand (stored as bf16) to f32, performs the multiplication in f32, and rounds the result back to bf16.
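A quick way to see that "compute in f32, store in bf16" pattern from eager PyTorch (a sketch for illustration; the internal f32 promotion is an implementation detail, so treat the equivalence check as indicative rather than guaranteed):

```python
import torch

a = torch.randn(1024, dtype=torch.bfloat16)
b = torch.randn(1024, dtype=torch.bfloat16)

eager = torch.mul(a, b)                              # bf16 elementwise mul
manual = (a.float() * b.float()).to(torch.bfloat16)  # explicit f32 multiply, rounded back to bf16

print(torch.equal(eager, manual))  # expected: True
```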
❓ Questions and Help
- torch 2.5.1
- torch_xla 2.5.1
- CUDA 12.4
- GPU: NVIDIA L4
The following example uses `torch.mul` where both operands are bf16, but in the HLO graph I see an f32 multiply operation.

HLO: module_0000.SyncTensorsGraph.16.before_optimizations.txt
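The reporter's script itself isn't included above; a minimal sketch matching the description (bf16 operands, HLO dumped via the internal `torch_xla._XLAC._get_xla_tensors_hlo` helper) might look like:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

a = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
b = torch.randn(4, 4, dtype=torch.bfloat16, device=device)

c = torch.mul(a, b)  # both operands bf16

# Per the report, the dumped HLO shows convert-to-f32, an f32
# multiply, then convert back to bf16.
print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
```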
I was able to get a bf16 multiplication by setting `export XLA_USE_BF16=1`, but I received a deprecation warning. I'm not sure how to enable bf16 multiplication in the HLO (High-Level Optimizer) the correct way, without using the deprecated flag.
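For reference, a sketch of that deprecated approach (the environment variable is read when torch_xla initializes, so it is typically set before the import; the exact warning text isn't reproduced here):

```python
import os

# Deprecated: maps f32 XLA tensors to bf16 globally.
os.environ["XLA_USE_BF16"] = "1"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, device=device)  # f32 to the user, bf16 on the XLA device
b = torch.randn(4, 4, device=device)
c = torch.mul(a, b)  # lowered as a bf16 multiply under this flag
```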