Large number of graph breaks with flash_attention on dynamo openxla backend #8913


Description

🐛 Bug

I have a simple script that trains using the flash attention kernel in torch_xla. I see 9 graph breaks while running this script.

Using regular SDPA attention, I see only 4 graph breaks.
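
A minimal sketch of that SDPA baseline (my code for it is not included here; this assumes torch.nn.functional.scaled_dot_product_attention and a placeholder apply_sdpa_attention helper, matching the commented-out manual attention math in the repro below):

import torch.nn.functional as F

def apply_sdpa_attention(query_states, key_states, value_states):
  # q, k, v have shape [B, n_head, S, head_dim]; scaled_dot_product_attention
  # applies the 1/sqrt(head_dim) scaling internally, so no manual scaling here.
  return F.scaled_dot_product_attention(
      query_states, key_states, value_states, is_causal=False)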

Every time we call torch_xla._XLAC._init_computation_client(), we hit a graph break.
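
To see where the breaks come from, the step function can also be run through torch._dynamo.explain; a sketch, assuming the ExplainOutput return value of recent PyTorch releases and reusing step_fn, input, and label from the repro below:

import torch._dynamo

# Returns an ExplainOutput with the number of captured graphs, the number of
# graph breaks, and the reason recorded for each break.
explanation = torch._dynamo.explain(step_fn)(input, label)
print(explanation.graph_count, explanation.graph_break_count)
for reason in explanation.break_reasons:
  print(reason)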

To Reproduce

import torch.nn.functional as F
import torch_xla
import torch.nn as nn
import torch
import math
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.debug.profiler as xp
import logging
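# Log every dynamo graph break as it happens; run the model under SPMD.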
torch._logging.set_logs(graph_breaks=True)
xr.use_spmd()

def apply_xla_flash_attention(query_states, key_states, value_states, partition_spec):
  # q, k, v should all have the shape [B, n_head, S, head_dim]
  head_dim = query_states.size()[-1]
  query_states = query_states / math.sqrt(head_dim)
  attn_output = flash_attention(
      query_states, key_states, value_states, causal=False, partition_spec=partition_spec)
  # key_states = key_states.transpose(2, 3)
  # t = query_states @ key_states
  # attn_weight = torch.softmax(t, dim=-1)
  # attn_output = attn_weight @ value_states
  return attn_output

NUM_HEADS=8
HEAD_DIM=128
class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.wq = nn.Linear(1024, 1024)
    self.wk = nn.Linear(1024, 1024)
    self.wv = nn.Linear(1024, 1024)

  
  def forward(self, x):
    q = self.wq(x)
    k = self.wk(x)
    v = self.wv(x)
    # b, num_heads, seq_len, head_dim
    b, s, _ = q.size()
    q = q.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
    k = k.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
    v = v.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
    z = apply_xla_flash_attention(q, k, v, ('data', None, None, None))
    z = z.transpose(1, 2).contiguous().view(b, s, NUM_HEADS * HEAD_DIM)
    return z

model = Model()
model.to('xla')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
def step_fn(input, label):
  optimizer.zero_grad()
  out = model(input)
  loss = F.mse_loss(out, label)
  loss.backward()
  optimizer.step()
  return loss
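# Compile the full training step with the dynamo openxla backend.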
step_fn_compiled = torch.compile(step_fn, backend='openxla')
input = torch.randn(8, 256, 1024).to('xla')
label = torch.randn(8, 256, 1024).to('xla')
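# Shard the batch dimension of the inputs across a 1D 'data' mesh.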
mesh = xs.get_1d_mesh('data')
xs.set_global_mesh(mesh)
xs.mark_sharding(input, mesh, ('data', None, None))
xs.mark_sharding(label, mesh, ('data', None, None))
loss=None
loss = step_fn_compiled(input, label)
print(loss)

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 03/21 nightly
