
Question about Flux attention implementation #2164

@CHR-ray

Hi,

When I read this attention code in the SD3 implementation:

```python
def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
    """
    q, k, v: [B, L, D]
    """
    pre_attn_layout = MEMORY_LAYOUTS[mode][0]
    post_attn_layout = MEMORY_LAYOUTS[mode][1]
    q = pre_attn_layout(q, head_dim)
    k = pre_attn_layout(k, head_dim)
    v = pre_attn_layout(v, head_dim)
    # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale)
    if mode == "torch":
        assert scale is None
        scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask)  # , scale=scale)
    elif mode == "xformers":
        scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale)
    else:
        scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale)
    scores = post_attn_layout(scores)
    return scores
```

I see three different optimization approaches for the attention block in SD3.
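
If I read the call sites correctly, `MEMORY_LAYOUTS` just adapts the tensor layout to each backend (torch SDPA wants `[B, H, L, D]`, while xformers' `memory_efficient_attention` wants `[B, L, H, D]`). Here is a minimal sketch of what I assume those helpers look like; the `rearrange` patterns are my guess from the call sites, not copied from the repo:

```python
# Sketch (my assumption): per-backend layout adapters keyed by mode.
from einops import rearrange

MEMORY_LAYOUTS = {
    "torch": (
        # pre: [B, L, H*D] -> [B, H, L, D], the layout torch SDPA expects
        lambda x, head_dim: rearrange(x, "b l (h d) -> b h l d", d=head_dim),
        # post: back to [B, L, H*D]
        lambda x: rearrange(x, "b h l d -> b l (h d)"),
    ),
    "xformers": (
        # pre: [B, L, H*D] -> [B, L, H, D], the layout memory_efficient_attention expects
        lambda x, head_dim: rearrange(x, "b l (h d) -> b l h d", d=head_dim),
        # post: back to [B, L, H*D]
        lambda x: rearrange(x, "b l h d -> b l (h d)"),
    ),
    "vanilla": (
        # pre/post: assumed to match the torch layout for a plain matmul attention
        lambda x, head_dim: rearrange(x, "b l (h d) -> b h l d", d=head_dim),
        lambda x: rearrange(x, "b h l d -> b l (h d)"),
    ),
}
```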

But for the Flux model, which also has a similar MMDiT block, the attention implementation is:

```python
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
    q, k = apply_rope(q, k, pe)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    x = rearrange(x, "B H L D -> B L (H D)")
    return x
```

which only uses the SDPA path; there is no xformers path.

I searched for keywords like "xformers" and "flux", but it seems no one has discussed this difference.

So, may I ask the reason behind this? In my opinion, the same structure should be able to benefit from the same optimization approaches. Would it be possible to add an xformers path for the Flux attention?
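
For concreteness, this is roughly what I imagine an xformers path could look like. It is only a sketch under my assumptions: `apply_rope` is the existing Flux helper from the same module, xformers' `memory_efficient_attention` takes `[B, L, H, D]` inputs, and `attn_mask` (if given) is a dense bias tensor usable as `attn_bias`:

```python
from typing import Optional

from einops import rearrange
from torch import Tensor
from xformers.ops import memory_efficient_attention


def attention_xformers(q: Tensor, k: Tensor, v: Tensor, pe: Tensor,
                       attn_mask: Optional[Tensor] = None) -> Tensor:
    # Same RoPE step as the existing Flux attention; q, k, v arrive as [B, H, L, D].
    q, k = apply_rope(q, k, pe)  # apply_rope: the existing Flux helper (assumed in scope)
    # xformers expects [B, L, H, D], so move the head axis before the call.
    q = rearrange(q, "B H L D -> B L H D")
    k = rearrange(k, "B H L D -> B L H D")
    v = rearrange(v, "B H L D -> B L H D")
    x = memory_efficient_attention(q, k, v, attn_bias=attn_mask)
    # Merge heads back so the output shape matches the original function: [B, L, H*D].
    return rearrange(x, "B L H D -> B L (H D)")
```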
