Description
🐛 Bug
I have a simple script that trains using the flash attention kernel in torch_xla. I see 9 graph breaks while running this script, whereas with regular SDPA attention I only see 4 graph breaks. Every time we call torch_xla._XLAC._init_computation_client() we hit a graph break.
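For reference, a minimal sketch of the SDPA baseline I compared against might look like the following. This is a reconstruction, assuming the flash_attention call is swapped for F.scaled_dot_product_attention; the helper name apply_sdpa_attention is just for illustration and is not part of the repro script below.

import torch.nn.functional as F

def apply_sdpa_attention(query_states, key_states, value_states):
    # q, k, v have shape [B, n_head, S, head_dim]; SDPA applies the
    # 1/sqrt(head_dim) scaling internally, so no manual scaling here.
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states, is_causal=False)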
To Reproduce
import torch.nn.functional as F
import torch_xla
import torch.nn as nn
import torch
import math
from torch_xla.experimental.custom_kernel import flash_attention
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.debug.profiler as xp
import logging

torch._logging.set_logs(graph_breaks=True)
xr.use_spmd()


def apply_xla_flash_attention(query_states, key_states, value_states, partition_spec):
    # q, k, v should all have the shape [B, n_head, S, head_dim]
    head_dim = query_states.size()[-1]
    query_states = query_states / math.sqrt(head_dim)
    attn_output = flash_attention(
        query_states, key_states, value_states, causal=False, partition_spec=partition_spec)
    # key_states = key_states.transpose(2, 3)
    # t = query_states @ key_states
    # attn_weight = torch.softmax(t, dim=-1)
    # attn_output = attn_weight @ value_states
    return attn_output


NUM_HEADS = 8
HEAD_DIM = 128


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(1024, 1024)
        self.wk = nn.Linear(1024, 1024)
        self.wv = nn.Linear(1024, 1024)

    def forward(self, x):
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        # b, num_heads, seq_len, head_dim
        b, s, _ = q.size()
        q = q.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        k = k.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        v = v.view(b, s, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        z = apply_xla_flash_attention(q, k, v, ('data', None, None, None))
        z = z.transpose(1, 2).contiguous().view(b, s, NUM_HEADS * HEAD_DIM)
        return z


model = Model()
model.to('xla')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def step_fn(input, label):
    optimizer.zero_grad()
    out = model(input)
    loss = F.mse_loss(out, label)
    loss.backward()
    optimizer.step()
    return loss


step_fn_compiled = torch.compile(step_fn, backend='openxla')

input = torch.randn(8, 256, 1024).to('xla')
label = torch.randn(8, 256, 1024).to('xla')

mesh = xs.get_1d_mesh('data')
xs.set_global_mesh(mesh)
xs.mark_sharding(input, mesh, ('data', None, None))
xs.mark_sharding(label, mesh, ('data', None, None))

loss = None
loss = step_fn_compiled(input, label)
print(loss)
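In addition to the graph_breaks logging enabled above, the breaks and their reasons can be enumerated with torch._dynamo.explain on the uncompiled step function. A sketch, assuming the same inputs as in the repro:

# Optional: list graph breaks and their reasons for the eager step_fn.
explanation = torch._dynamo.explain(step_fn)(input, label)
print(explanation)  # output includes graph count, graph break count, and break reasons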
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: 03/21 nightly