nnx.make_causal_mask() usage #4505

Open
windmaple opened this issue Jan 25, 2025 · 3 comments

@windmaple

This is a follow-up to #4290 (@cgarciae). To build a causal LM, I need causal masking. Here is my attempt (adding a single line to the code from #4290):

import jax.numpy as jnp
from flax import nnx

batch_size = 2
seqlen = 40
emb_size = 256

x = jnp.ones((batch_size, seqlen, emb_size))

mha = nnx.MultiHeadAttention(
  in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
shape = x.shape
mha.init_cache(shape)  # initialize the KV cache for decode mode

for i in range(seqlen):  # iterate over all tokens, one at a time
  y = mha(inputs_q=x[:, i : i + 1],
          mask=nnx.make_causal_mask(x[:, i : i + 1]))  # newly added

The error I got is:

AssertionError: masks must have same rank: (5, 4)

I cannot make sense of this error :(

@cgarciae
Collaborator

cgarciae commented Jan 28, 2025

Hi @windmaple, in decode mode MultiHeadAttention is always causal, meaning you don't have to provide a mask in this case. See:

  # causal mask for cached decoder self-attention:
  # our single query position should only attend to those key
  # positions that have already been generated and cached,
  # not the remaining zero elements.
  mask = combine_masks(
    mask,
    jnp.broadcast_to(
      jnp.arange(max_length) <= cur_index,
      tuple(batch_dims) + (1, 1, max_length),
    ),
  )
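In other words, the explicit mask argument can simply be dropped in the decode loop. A minimal sketch of that usage, assuming the same toy shapes as above and that the cache is set up with init_cache:

import jax.numpy as jnp
from flax import nnx

batch_size, seqlen, emb_size = 2, 40, 256
x = jnp.ones((batch_size, seqlen, emb_size))

mha = nnx.MultiHeadAttention(
  in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
mha.init_cache(x.shape)  # allocate the autoregressive KV cache for decode mode

outputs = []
for i in range(seqlen):
  # feed one token at a time; the causal constraint is applied internally,
  # so no mask argument is passed
  y = mha(inputs_q=x[:, i : i + 1])
  outputs.append(y)

y_full = jnp.concatenate(outputs, axis=1)  # (batch_size, seqlen, emb_size)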

@windmaple
Author

Yeah, I realized that, since we are feeding tokens in one by one.

However, for some reason it's not working as expected. I'll try to provide a repro.

@windmaple
Author

Here is the notebook:
https://colab.research.google.com/drive/1kk7xcFSA7KzVQnekfqmdd1Gq_Z4qsLvU#scrollTo=NIOXoY1xgiww

Turning on the KV cache makes it so much slower, which doesn't make any sense to me :(
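For reference, a rough (hypothetical) way to time the two paths side by side, assuming the same toy shapes; note that outside of jit the decode loop dispatches one call per token, so some Python-level overhead is expected:

import time
import jax.numpy as jnp
from flax import nnx

batch_size, seqlen, emb_size = 2, 40, 256
x = jnp.ones((batch_size, seqlen, emb_size))

# full-sequence attention with an explicit causal mask (no cache)
mha_full = nnx.MultiHeadAttention(
  in_features=emb_size, num_heads=2, decode=False, rngs=nnx.Rngs(0)
)
mask = nnx.make_causal_mask(jnp.ones((batch_size, seqlen)))  # (2, 1, 40, 40)
t0 = time.perf_counter()
y = mha_full(inputs_q=x, mask=mask)
y.block_until_ready()
print('full sequence:', time.perf_counter() - t0)

# token-by-token decoding with the KV cache
mha_dec = nnx.MultiHeadAttention(
  in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
mha_dec.init_cache(x.shape)
t0 = time.perf_counter()
for i in range(seqlen):
  y = mha_dec(inputs_q=x[:, i : i + 1])
y.block_until_ready()
print('per-token decode:', time.perf_counter() - t0)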
