Conversation

zhuzilin commented Oct 9, 2024

This PR implements a FlashDiffAttention class similar to the FlashSelfAttention in the original flash-attention repo (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py#L53), so that training frameworks can easily add Diff Transformer support, with or without varlen.

The main idea is to set num_heads during training to twice that of the original transformer, so that none of the code related to RoPE needs to change.
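
For context, the recombination that FlashDiffAttention performs on these doubled heads boils down to the following. This is only a minimal sketch, assuming plain MHA (no GQA) and the lambda_q1/lambda_k1/lambda_q2/lambda_k2, lambda_init and subln formulation of the reference implementation; it uses PyTorch SDPA in place of the flash kernels, and diff_attention_sketch is a hypothetical helper, not code from this PR:

import torch
import torch.nn.functional as F

def diff_attention_sketch(q, k, v, lambda_q1, lambda_k1, lambda_q2, lambda_k2,
                          lambda_init, subln):
    # q, k: (batch, seq, 2 * num_heads, head_dim), with RoPE already applied to all heads
    # v:    (batch, seq, num_heads, 2 * head_dim)
    q = q.transpose(1, 2)                    # (b, 2h, s, d)
    k = k.transpose(1, 2)                    # (b, 2h, s, d)
    v = v.transpose(1, 2)                    # (b, h, s, 2d)

    lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q)
    lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q)
    lambda_full = lambda_1 - lambda_2 + lambda_init

    # one causal attention over all 2h (q, k) heads; each value head is shared by the
    # pair (2i, 2i + 1); the default SDPA scale (head_dim ** -0.5) matches the reference
    out = F.scaled_dot_product_attention(
        q, k, v.repeat_interleave(2, dim=1), is_causal=True
    )                                        # (b, 2h, s, 2d)

    # subtract the two maps of each pair, weighted by lambda_full
    out = out[:, 0::2] - lambda_full * out[:, 1::2]        # (b, h, s, 2d)
    out = subln(out) * (1 - lambda_init)
    return out.transpose(1, 2)               # (b, s, h, 2d)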

A simple test script for the code is:

from dataclasses import dataclass

import torch
import torch.distributed as dist
from flash_attn.layers.rotary import RotaryEmbedding
from einops import rearrange

from multihead_flashdiff_2 import MultiheadFlashDiff2
from flashdiff import FlashDiffAttention
from kernel.rotary import apply_rotary_emb


@dataclass
class Args:
    model_parallel_size: int
    decoder_kv_attention_heads: int


def create_new_impl(origin_impl, head_dim, depth):
    diff_attn_func = FlashDiffAttention(
        head_dim=head_dim, depth=depth, causal=True
    ).to(device, dtype=dtype)
    # make the initialization the same
    diff_attn_func.lambda_q1.data.copy_(origin_impl.lambda_q1.data)
    diff_attn_func.lambda_k1.data.copy_(origin_impl.lambda_k1.data)
    diff_attn_func.lambda_q2.data.copy_(origin_impl.lambda_q2.data)
    diff_attn_func.lambda_k2.data.copy_(origin_impl.lambda_k2.data)
    #diff_attn_func.subln.weight.data.copy_(origin_impl.subln.weight.data)
    
    def new_impl(x, rel_pos):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len

        q = origin_impl.q_proj(x)
        k = origin_impl.k_proj(x)
        v = origin_impl.v_proj(x)

        # with doubled heads, we no longer need the "// 2" from the reference implementation
        num_heads = embed_dim // head_dim
        num_kv_heads = k.shape[-1] // head_dim

        q = q.view(bsz, tgt_len, num_heads, head_dim)
        k = k.view(bsz, src_len, num_kv_heads, head_dim)
        v = v.view(bsz, src_len, num_kv_heads, head_dim)

        q = apply_rotary_emb(q, *rel_pos, interleaved=True)
        k = apply_rotary_emb(k, *rel_pos, interleaved=True)

        output = diff_attn_func(q, k, v)
        output = rearrange(output, '... H D -> ... (H D)')

        output = origin_impl.out_proj(output)
        return output
    
    return new_impl


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda")
    dtype = torch.bfloat16
    args = Args(model_parallel_size=1, decoder_kv_attention_heads=4)
    batch_size = 2
    num_heads = 16
    seq_len = 512
    embed_dim = 2048
    depth = 12
    # in the new implementation, the num_heads should be twice the original num_heads
    num_new_heads = num_heads * 2
    head_dim = embed_dim // num_new_heads

    print("initializing modules")
    origin_impl = MultiheadFlashDiff2(args, embed_dim=embed_dim, depth=depth, num_heads=num_heads).to(device, dtype=dtype)
    new_impl = create_new_impl(origin_impl, head_dim, depth)

    print("creating test data")
    rotary_emb = RotaryEmbedding(
        head_dim,
        base=10000.0,
        interleaved=True,
        device=device,
    )
    rotary_emb._update_cos_sin_cache(seq_len, device=device, dtype=torch.bfloat16)
    rel_pos = (rotary_emb._cos_cached, rotary_emb._sin_cached)
    hidden_states = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype)

    print("run origin forward")
    origin_output = origin_impl(hidden_states, rel_pos)

    print("run new forward")
    new_output = new_impl(hidden_states, rel_pos)

    assert torch.allclose(origin_output, new_output, atol=1e-6)

Thank you for your time reviewing this PR.

zhuzilin force-pushed the feature/flash_diff_attn branch 2 times, most recently from 200479b to f9f35a8 on Oct 9, 2024 06:29
zhuzilin changed the title from "[WIP] Add a isolated implementation of FlashDiffAttention" to "Add an isolated implementation of FlashDiffAttention" on Oct 9, 2024
zhuzilin force-pushed the feature/flash_diff_attn branch from f9f35a8 to c6e6486 on Oct 9, 2024 06:36
MarktHart commented

You could go even closer to standard attention and use it as-is with a doubled interleave, e.g.:

def alternative_forward(
        self,
        x,
        rel_pos,
        attn_mask=None,
    ):
    bsz, tgt_len, embed_dim = x.size()
    src_len = tgt_len

    q = self.q_proj(x)
    k = self.k_proj(x)
    v = self.v_proj(x)

    q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
    k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
    v = v.view(bsz, src_len, self.num_kv_heads, 2 * self.head_dim)

    q = apply_rotary_emb(q, *rel_pos, interleaved=True)
    k = apply_rotary_emb(k, *rel_pos, interleaved=True)

    q = q.transpose(1, 2)
    
    k = torch.repeat_interleave(k.transpose(1, 2), dim=1, repeats=self.n_rep)
    v = torch.repeat_interleave(v.transpose(1, 2), dim=1, repeats=self.n_rep * 2)
    if attn_mask is None:
        attn_mask = torch.triu(
            torch.zeros([tgt_len, src_len])
            .float()
            .fill_(float("-inf"))
            .type_as(q),
            1 + src_len - tgt_len,
        )

    lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
    lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
    lambda_full = lambda_1 - lambda_2 + self.lambda_init

    # per-head attention outputs (despite the name, these are attended values, not weights)
    attn_weights = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_mask, scale=self.scaling)
    # heads 2i / 2i + 1 are the two components of original head i; keep the mask on the same device
    every_other_mask = torch.arange(attn_weights.size(1), device=attn_weights.device) % 2 == 0
    attn = attn_weights[:, every_other_mask, :, :] - lambda_full * attn_weights[:, ~every_other_mask, :, :]

    attn = self.subln(attn)
    attn = attn * (1 - self.lambda_init)
    attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)

    attn = self.out_proj(attn)
    return attn
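
(Note: in this layout, heads 2i and 2i + 1 of the doubled q/k are the two components of original head i and share the same value head, which is why v is repeat-interleaved with n_rep * 2; subtracting every other head's output then reproduces the differential attention map. The trade-off versus the PR's FlashDiffAttention is that this stays on plain SDPA rather than the flash/varlen kernels.)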
