In my local wheel, I moved requires_jax inside the flash attention custom op, which fixed the additional graph breaks. Now the only graph breaks I see are the forced ones, such as optimizer.zero_grad(), loss.backward(), and optimizer.step(), which generally lead to 4 graphs in a training loop. That is not ideal: each graph triggers a compilation in torch_xla and needs its tensors materialized before the next graph can execute, which slows down performance considerably.
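For context, below is a minimal sketch of the kind of compiled training step this refers to (the placeholder model, loss, data, and the openxla dynamo backend are assumptions here, not the actual script). The calls marked in comments are the forced break points, which is why a single step ends up as roughly four graphs that torch_xla compiles and executes one after another.

```python
# Minimal sketch of a torch.compile + torch_xla training step. The model,
# optimizer, and data below are placeholders; only the structure matters.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 128).to(device)              # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

compiled_model = torch.compile(model, backend="openxla")  # dynamo -> XLA backend

x = torch.randn(8, 128, device=device)
y = torch.randn(8, 128, device=device)

for _ in range(3):
    optimizer.zero_grad()        # forced graph break
    out = compiled_model(x)      # compiled forward graph
    loss = loss_fn(out, y)
    loss.backward()              # forced graph break
    optimizer.step()             # forced graph break
    xm.mark_step()               # materialize tensors before the next iteration
```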
🐛 Bug
I have a simple script that trains with the flash attention kernel in torch_xla. Running it, I see 9 graph breaks.
With regular SDPA attention I see only 4 graph breaks.
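One way to get such counts (an assumption about tooling, not necessarily how the numbers above were measured) is torch._dynamo.explain on recent PyTorch versions, or running with TORCH_LOGS=graph_breaks:

```python
# Sketch: counting graph breaks with torch._dynamo.explain. `step` is a
# stand-in for the real training/attention code being traced.
import torch

def step(x):
    y = x.sin()
    print("this print forces a graph break")  # deliberate break for illustration
    return y.cos()

explanation = torch._dynamo.explain(step)(torch.randn(4))
print(explanation.graph_count)        # number of compiled graphs
print(explanation.graph_break_count)  # number of graph breaks
print(explanation.break_reasons)      # reason recorded for each break
```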
Every time we call torch_xla._XLAC._init_computation_client() we hit a graph break.
To Reproduce
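The original script is not reproduced here; the snippet below is only a sketch of the flash-attention call path it exercises, assuming the kernel exposed by torch_xla.experimental.custom_kernel, placeholder tensor shapes, and a TPU/XLA device.

```python
# Sketch of the flash attention kernel usage (placeholder shapes, not the
# actual reproduction script). Inputs are [batch, heads, seq_len, head_dim]
# tensors on an XLA device.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
q = torch.randn(1, 8, 1024, 64, device=device)
k = torch.randn(1, 8, 1024, 64, device=device)
v = torch.randn(1, 8, 1024, 64, device=device)

out = flash_attention(q, k, v, causal=True)  # torch_xla flash attention kernel
xm.mark_step()                               # execute the pending graph
print(out.shape)
```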
Environment