
[BUG] Triton HSTU attention test encounters numerical error when is_causal=False #173

Open
shijieliu opened this issue Jan 6, 2025 · 3 comments

@shijieliu

Hi

I noticed that the Triton implementation of HSTU attention supports a causal flag, but when I tried to test its correctness I found numerical errors when is_causal=False. I also noticed that performance with is_causal=False is better than with is_causal=True, which does not make sense to me. So I suspect there may be a bug in the Triton implementation when is_causal=False.
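For reference, the comparison the test performs boils down to something like this (a minimal sketch, not code from the repo: the hstu_mha arguments are copied from the test shown in the log below, but the standalone setup and the HammerKernel import path are assumptions on my part; the test additionally calls set_dev_mode(True) and enables TF32 first):

```python
# Minimal sketch of the comparison. The hstu_mha arguments mirror the test
# in the log below; shapes and values are illustrative, and the HammerKernel
# import path is an assumption (adjust to match the repo layout).
import torch

from generative_recommenders.common import HammerKernel  # assumed path
from generative_recommenders.ops.hstu_attention import hstu_mha

heads, attn_dim, hidden_dim = 1, 16, 16
max_seq_len = 120
lengths = torch.tensor([120], device="cuda")
seq_offsets = torch.zeros((2,), dtype=torch.int64, device="cuda")
seq_offsets[1:] = torch.cumsum(lengths, dim=0)
L = int(seq_offsets[-1].item())

q = torch.empty((L, heads, attn_dim), dtype=torch.bfloat16, device="cuda").uniform_(-0.1, 0.1)
k = torch.empty((L, heads, attn_dim), dtype=torch.bfloat16, device="cuda").uniform_(-0.1, 0.1)
v = torch.empty((L, heads, hidden_dim), dtype=torch.bfloat16, device="cuda").uniform_(-0.1, 0.1)

outs = {}
for kernel in (HammerKernel.PYTORCH, HammerKernel.TRITON):
    outs[kernel] = hstu_mha(
        max_seq_len=max_seq_len,
        alpha=1.0 / (attn_dim**0.5),
        q=q,
        k=k,
        v=v,
        seq_offsets=seq_offsets,
        causal=False,  # the path in question; may require bypassing an assertion first
        num_targets=None,
        dropout_pr=0.0,
        max_attn_len=None,
        contextual_seq_len=0,
        kernel=kernel,
    )

torch.testing.assert_close(outs[HammerKernel.PYTORCH], outs[HammerKernel.TRITON])
```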

Log

self = <hstu_attention_test.HSTUAttentionTest testMethod=test_attn_triton>, batch_size = 4, heads = 1
max_uih_len = 100, max_targets = 20, attn_dim = 16, hidden_dim = 16, causal = False
has_multiple_targets = True, has_max_attn_len = False, dtype = torch.bfloat16, test_backward = True
ref_kernel = <HammerKernel.PYTORCH: 'PYTORCH'>, real_kernel = <HammerKernel.TRITON: 'TRITON'>
skip_comparisons = False, sparsity = -1.0, contextual_seq_len = 0, atol = None, rtol = None

    def _test_attn(
        self,
        batch_size: int,
        heads: int,
        max_uih_len: int,
        max_targets: int,
        attn_dim: int,
        hidden_dim: int,
        causal: bool,
        has_multiple_targets: bool,
        has_max_attn_len: bool,
        dtype: torch.dtype,
        test_backward: bool,
        ref_kernel: HammerKernel,
        real_kernel: HammerKernel,
        skip_comparisons: bool = False,
        sparsity: float = -1.0,
        contextual_seq_len: int = 0,
        atol: Optional[float] = None,
        rtol: Optional[float] = None,
    ) -> None:
        set_dev_mode(True)
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        from generative_recommenders.ops.hstu_attention import hstu_mha

        alpha = 1.0 / (attn_dim**0.5)
        if sparsity > 0.0:
            lengths = generate_sparse_seq_len(
                size=batch_size,
                max_seq_len=max_uih_len,
                sparsity=sparsity,
                device=torch.device("cuda"),
            )
        else:
            lengths = torch.randint(
                max_uih_len + 1, size=(batch_size,), device=torch.device("cuda")
            )
        num_targets = torch.randint(
            max_targets + 1, size=(batch_size,), device=torch.device("cuda")
        )
        lengths = lengths + num_targets
        max_seq_len = max_uih_len + max_targets
        if has_max_attn_len:
            max_attn_len = random.randint(1, max_uih_len // 5)
        else:
            max_attn_len = None
        seq_offsets = torch.zeros(
            (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda")
        )
        seq_offsets[1:] = torch.cumsum(lengths, dim=0)

        L = int(seq_offsets[-1].item())
        q = (
            torch.empty((L, heads, attn_dim), dtype=dtype, device=torch.device("cuda"))
            .uniform_(-0.1, 0.1)
            .requires_grad_()
        )
        k = (
            torch.empty((L, heads, attn_dim), dtype=dtype, device=torch.device("cuda"))
            .uniform_(-0.1, 0.1)
            .requires_grad_()
        )
        v = (
            torch.empty(
                (L, heads, hidden_dim), dtype=dtype, device=torch.device("cuda")
            )
            .uniform_(-0.1, 0.1)
            .requires_grad_()
        )

        # ref implementation
        ref_out = hstu_mha(
            max_seq_len=max_seq_len,
            alpha=alpha,
            q=q,
            k=k,
            v=v,
            seq_offsets=seq_offsets,
            causal=causal,
            num_targets=num_targets if has_multiple_targets else None,
            dropout_pr=0.0,
            max_attn_len=max_attn_len,
            contextual_seq_len=contextual_seq_len,
            kernel=ref_kernel,
        )
        dout = torch.randn_like(ref_out) * 0.01
        ref_out.backward(dout)

        if skip_comparisons:
            return

        # pyre-ignore
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None

        # triton implementation
        q = q.detach().clone().requires_grad_()
        k = k.detach().clone().requires_grad_()
        v = v.detach().clone().requires_grad_()
        dout = dout.detach().clone()
        real_out = hstu_mha(
            max_seq_len=max_seq_len,
            alpha=alpha,
            q=q,
            k=k,
            v=v,
            seq_offsets=seq_offsets,
            causal=causal,
            num_targets=num_targets if has_multiple_targets else None,
            dropout_pr=0.0,
            max_attn_len=max_attn_len,
            contextual_seq_len=contextual_seq_len,
            kernel=real_kernel,
        )

>       torch.testing.assert_close(
            ref_out,
            real_out,
            atol=atol,
            rtol=rtol,
        )
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 491 / 3936 (12.5%)
E       Greatest absolute difference: 2.47955322265625e-05 at index (85, 0, 7) (up to 1e-05 allowed)
E       Greatest relative difference: 21376.0 at index (199, 0, 3) (up to 0.016 allowed)

hstu_attention_test.py:217: AssertionError
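For context, the 1e-05 and 0.016 thresholds above are torch.testing.assert_close's default atol/rtol for bfloat16 (the test passes atol=None, rtol=None), and very large relative differences like 21376.0 typically come from reference elements that are close to zero. A purely illustrative loosened comparison would be:

```python
# Illustrative only: an atol above the observed 2.5e-05 max absolute
# difference makes this comparison pass, but that by itself does not tell
# us whether the non-causal Triton path is correct.
torch.testing.assert_close(ref_out, real_out, atol=1e-4, rtol=1.6e-2)
```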

Reproduce steps

  1. Comment out this line to bypass the assertion on causal.
  2. Change causal to False in the tests by setting this line to False (see the sketch after this list).
  3. Run hstu_attention_test.py.
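For step 2, the change amounts to something like the following (a hypothetical extra test method inside HSTUAttentionTest; the parameter values are taken from the failing run above, and the real test is parameterized, so the actual edit will look different):

```python
# Hypothetical method added to HSTUAttentionTest in hstu_attention_test.py,
# relying on the module's existing imports (torch, HammerKernel). Parameter
# values match the failing run above, with causal flipped to False.
def test_attn_triton_non_causal(self) -> None:
    self._test_attn(
        batch_size=4,
        heads=1,
        max_uih_len=100,
        max_targets=20,
        attn_dim=16,
        hidden_dim=16,
        causal=False,  # the non-causal path under suspicion
        has_multiple_targets=True,
        has_max_attn_len=False,
        dtype=torch.bfloat16,
        test_backward=True,
        ref_kernel=HammerKernel.PYTORCH,
        real_kernel=HammerKernel.TRITON,
    )
```

Step 3 is then, e.g., pytest hstu_attention_test.py -k non_causal.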
@jiaqizhai
Contributor

Hi, thanks for looking into this and identifying this issue! These kernels are primarily tested/benchmarked/used with is_causal=True, and the other path is not tested as thoroughly. I would recommend sticking with is_causal=True (given equivalence) unless you have special scenarios in mind.

@jiaqizhai
Contributor

Note also that during inference (where we use M-FALCON, etc.), is_causal=False would make sequence construction significantly more complicated.

@shijieliu
Author

Thanks. It makes sense to me.
