Large number of graph breaks with flash_attention on dynamo openxla backend #8913


Open
bhavya01 opened this issue Mar 31, 2025 · 2 comments

@bhavya01 Collaborator

🐛 Bug

I have a simple script that trains using the flash attention kernel in torch_xla. I see 9 graph breaks while running this script.

Using regular SDPA attention, I only see 4 graph breaks.

Every time we call torch_xla._XLAC._init_computation_client(), we hit a graph break.

To Reproduce

import torch.nn.functional as F
import torch_xla
import torch.nn as nn
import torch
import math
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.debug.profiler as xp
import logging
torch._logging.set_logs(graph_breaks=True)
xr.use_spmd()

def apply_xla_flash_attention(query_states, key_states, value_states, partition_spec):
  # q, k, v should all have the shape [B, n_head, S, head_dim]
  head_dim = query_states.size()[-1]
  query_states = query_states / math.sqrt(head_dim)
  attn_output = flash_attention(
      query_states, key_states, value_states, causal=False, partition_spec=partition_spec)
  # key_states = key_states.transpose(2, 3)
  # t = query_states @ key_states
  # attn_weight = torch.softmax(t, dim=-1)
  # attn_output = attn_weight @ value_states
  return attn_output

NUM_HEADS=8
HEAD_DIM=128
class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.wq = nn.Linear(1024, 1024)
    self.wk = nn.Linear(1024, 1024)
    self.wv = nn.Linear(1024, 1024)

  
  def forward(self, x):
    q = self.wq(x)
    k = self.wk(x)
    v = self.wv(x)
    # b, num_heads, seq_len, head_dim
    b, s, _ = q.size()
    q = q.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
    k = k.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
    v = v.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
    z = apply_xla_flash_attention(q, k, v, ('data', None, None, None))
    z = z.transpose(1, 2).contiguous().view(b, s, NUM_HEADS * HEAD_DIM)
    return z

model = Model()
model.to('xla')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
def step_fn(input, label):
  optimizer.zero_grad()
  out = model(input)
  loss = F.mse_loss(out, label)
  loss.backward()
  optimizer.step()
  return loss
step_fn_compiled = torch.compile(step_fn, backend='openxla')
input = torch.randn(8, 256, 1024).to('xla')
label = torch.randn(8, 256, 1024).to('xla')
mesh = xs.get_1d_mesh('data')
xs.set_global_mesh(mesh)
xs.mark_sharding(input, mesh, ('data', None, None))
xs.mark_sharding(label, mesh, ('data', None, None))
loss = step_fn_compiled(input, label)
print(loss)
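
For reference, the break count can also be inspected programmatically instead of parsing the graph_breaks log output. A minimal sketch, assuming the same step_fn, input, and label as in the repro above (torch._dynamo.explain runs dynamo's own analysis rather than the openxla backend, so its counts may not match the compiled run exactly):

import torch._dynamo

explanation = torch._dynamo.explain(step_fn)(input, label)
print("graphs:", explanation.graph_count)
print("graph breaks:", explanation.graph_break_count)
for reason in explanation.break_reasons:
  print(reason)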

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 03/21 nightly
@bhavya01 bhavya01 self-assigned this Mar 31, 2025
@bhavya01 Collaborator Author

bhavya01 commented Apr 1, 2025

The graph breaks are due to the requires_jax decorator, which needs to be moved to the innermost parts of flash attention where jax is actually imported.
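
A rough, hypothetical sketch of the pattern being described (not torch_xla's actual implementation): when a decorator's wrapper does work that dynamo cannot trace, every traced call through that wrapper splits the graph, so the decorator should only wrap the innermost call. Here torch._dynamo.graph_break() stands in for the untraceable runtime initialization (the torch_xla._XLAC._init_computation_client() call mentioned above):

import functools
import torch

def requires_runtime_init(fn):
  # Stand-in for a requires_jax-style decorator: the wrapper performs work
  # that dynamo cannot trace, forcing a break in every frame that calls it.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    torch._dynamo.graph_break()  # placeholder for the real, untraceable init
    return fn(*args, **kwargs)
  return wrapper

@requires_runtime_init
def attention_like(q, k, v):
  return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

def model_fn(q, k, v):
  return attention_like(q, k, v) + 1.0

q = k = v = torch.randn(2, 4, 8)
ex = torch._dynamo.explain(model_fn)(q, k, v)
print(ex.graph_break_count)  # > 0: the decorator splits the compiled region

Moving the break-inducing work down to the innermost kernel launch (or behind a registered custom op that dynamo treats as opaque) keeps the surrounding model code in a single graph, which is the fix described here.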

@bhavya01 Collaborator Author

bhavya01 commented Apr 7, 2025

In my local wheel, I moved requires_jax inside the flash attention custom op, which fixed the additional graph breaks. Now the only graph breaks I see are the forced ones such as optimizer.zero_grad(), loss.backward(), and optimizer.step(), which generally leads to 4 graphs in a training loop. That is not ideal: each graph triggers a compilation in torch_xla and needs its tensors materialized before the next graph can execute, which slows down performance significantly.
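
One way to confirm that each of those breaks turns into a separate XLA compilation is to look at torch_xla's metrics after running the compiled step. A minimal sketch, assuming the repro script above has already executed on an XLA device:

import torch_xla.debug.metrics as met

# The CompileTime / ExecuteTime entries roughly track how many graphs were
# compiled and executed for the step.
print(met.short_metrics_report())
print(met.metric_data('CompileTime'))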
