
Commit 9bab148

Add is_causal mask argument and tests
1 parent bee81ed commit 9bab148

File tree

  flax/nnx/nn/attention.py
  tests/nnx/nn/attention_test.py

2 files changed: +98, -3 lines changed

flax/nnx/nn/attention.py

Lines changed: 25 additions & 3 deletions
@@ -62,6 +62,7 @@ def dot_product_attention_weights(
   precision: PrecisionLike = None,
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+  is_causal: bool = False,
 ):
   """Computes dot-product attention weights given query and key.
 
@@ -94,6 +95,11 @@ def dot_product_attention_weights(
     promote_dtype: function to promote the dtype of the arrays to the desired
       dtype. The function should accept a tuple of ``(query, key)`` and a ``dtype``
       keyword argument, and return a tuple of arrays with the promoted dtype.
+    is_causal: If true, causal attention will be applied. Note, some
+      implementations like xla will generate a mask tensor and apply it to
+      the logits to mask out the non-causal parts of the attention matrix,
+      but other implementations like cudnn will avoid computing the
+      non-causal regions, providing speedups.
 
   Returns:
     Output of shape `[batch..., num_heads, q_length, kv_length]`.
@@ -118,9 +124,17 @@ def dot_product_attention_weights(
   if bias is not None:
     attn_weights = attn_weights + bias
   # apply attention mask
-  if mask is not None:
+  if mask is not None or is_causal:
     big_neg = jnp.finfo(dtype).min
-    attn_weights = jnp.where(mask, attn_weights, big_neg)
+    masks = [m for m in [mask] if m is not None]
+    if is_causal:
+      T, S = attn_weights.shape[-2:]
+      causal_mask = jnp.tril(jnp.ones((T, S), dtype=dtype))
+      target_shape = mask.shape if mask is not None else attn_weights.shape
+      masks.append(jnp.broadcast_to(causal_mask, target_shape))
+    combined_mask = combine_masks(*masks, dtype=dtype)
+    assert combined_mask is not None
+    attn_weights = jnp.where(combined_mask, attn_weights, big_neg)
 
   # normalize the attention weights
   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
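
The new branch above builds a lower-triangular mask on the fly, broadcasts it to the shape of the user-supplied mask (or of the attention weights when no mask is given), and merges the two with combine_masks before the usual big_neg masking. Below is a minimal sketch of that equivalence, not part of the commit; it assumes dot_product_attention_weights and combine_masks are importable from flax.nnx.nn.attention (the test later in this diff imports combine_masks from there), and the shapes are chosen only for illustration.

import jax
import jax.numpy as jnp
from flax.nnx.nn.attention import combine_masks, dot_product_attention_weights

B, H, T, D = 1, 2, 4, 8
q = jax.random.normal(jax.random.key(0), (B, T, H, D))
k = jax.random.normal(jax.random.key(1), (B, T, H, D))

# Padding mask broadcastable to the weight shape [batch, heads, q_len, kv_len];
# the last key position is padded out.
padding_mask = jnp.ones((B, 1, T, T), dtype=bool).at[..., -1].set(False)

# New path: the function builds the causal mask and combines it internally.
w_auto = dot_product_attention_weights(q, k, mask=padding_mask, is_causal=True)

# Manual path: combine the padding and causal masks by hand, as before.
causal = jnp.broadcast_to(jnp.tril(jnp.ones((T, T))), padding_mask.shape)
w_manual = dot_product_attention_weights(
  q, k, mask=combine_masks(padding_mask, causal, dtype=q.dtype))

assert jnp.allclose(w_auto, w_manual)

When no mask is passed at all, the causal mask is broadcast to the full attn_weights shape instead, so is_causal also works on its own.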
@@ -157,6 +171,7 @@ def dot_product_attention(
   precision: PrecisionLike = None,
   module: Module | None = None,
   promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+  is_causal: bool = False,
 ):
   """Computes dot-product attention given query, key, and value.
 
@@ -198,6 +213,11 @@ def dot_product_attention(
       dtype. The function should accept a tuple of ``(query, key, value)`` and a
       ``dtype`` keyword argument, and return a tuple of arrays with the promoted
       dtype.
+    is_causal: If true, causal attention will be applied. Note, some
+      implementations like xla will generate a mask tensor and apply it to
+      the logits to mask out the non-causal parts of the attention matrix,
+      but other implementations like cudnn will avoid computing the
+      non-causal regions, providing speedups.
 
   Returns:
     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
@@ -224,7 +244,7 @@ def reshape_4d(x):
       reshape_4d, (query, key, value, bias, mask))
     if mask is not None:
       mask = mask.astype(jnp.bool)
-    out = jax.nn.dot_product_attention(query, key, value, bias, mask)
+    out = jax.nn.dot_product_attention(query, key, value, bias, mask, is_causal=is_causal)
     if len(query_shape) > 4:
       out = jnp.reshape(out, query_shape)
     return out
@@ -242,6 +262,8 @@ def reshape_4d(x):
     dtype,
     precision,
     module,
+    promote_dtype,
+    is_causal,
   )
 
   # return weighted sum over values for each query position
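
On the jax.nn.dot_product_attention branch of dot_product_attention, is_causal is now forwarded straight to JAX, so the docstring's cudnn remark applies: the backend can skip the non-causal region instead of materializing a mask. A rough, illustrative sketch of that underlying call follows; the shapes are chosen for the example and are not taken from the commit.

import jax
import jax.numpy as jnp

B, T, H, D = 2, 8, 4, 16
q = jax.random.normal(jax.random.key(0), (B, T, H, D))
k = jax.random.normal(jax.random.key(1), (B, T, H, D))
v = jax.random.normal(jax.random.key(2), (B, T, H, D))

# Let the backend handle causality itself.
out_causal = jax.nn.dot_product_attention(q, k, v, is_causal=True)

# Reference result with an explicit lower-triangular boolean mask.
mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, None]  # [1, 1, T, S]
out_masked = jax.nn.dot_product_attention(q, k, v, mask=mask)

assert jnp.allclose(out_causal, out_masked, atol=1e-6)

On the default xla implementation the two calls agree numerically; the speedup mentioned in the docstring comes from backends such as cudnn that avoid computing the masked-out region altogether.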

tests/nnx/nn/attention_test.py

Lines changed: 73 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 from flax import linen
 from flax import nnx
+from flax.nnx.nn.attention import combine_masks
 from flax.typing import Dtype, PrecisionLike
 
 import numpy as np
@@ -132,6 +133,78 @@ def test_keep_rngs(self, keep_rngs):
     else:
       nnx.split(module, nnx.Param)
 
+  @parameterized.product(use_padding=[True, False], is_cross_attention=[True, False])
+  def test_causal_mask_equivalence(
+    self,
+    use_padding: bool,
+    is_cross_attention: bool,
+  ):
+    batch_size = 1
+    num_heads = 2
+    q_len = 2
+    kv_len = 4 if is_cross_attention else q_len
+    head_dim = 4
+
+    q = jax.random.normal(
+      key=jax.random.key(0),
+      shape=(batch_size, 1, q_len, num_heads, head_dim),
+    )
+    k = jax.random.normal(
+      key=jax.random.key(1),
+      shape=(batch_size, 1, kv_len, num_heads, head_dim),
+    )
+    v = jax.random.normal(
+      key=jax.random.key(2),
+      shape=(batch_size, 1, kv_len, num_heads, head_dim),
+    )
+
+    causal_mask = jnp.tril(
+      jnp.ones(shape=(q_len, kv_len), dtype=jnp.bool_)
+    )
+    causal_mask = jnp.broadcast_to(
+      array=causal_mask,
+      shape=(batch_size, 1, num_heads, q_len, kv_len),
+    )
+
+    padding_mask = None
+    if use_padding:
+      padding_mask = jnp.ones(
+        shape=(batch_size, 1, 1, q_len, kv_len),
+        dtype=jnp.bool_,
+      )
+      padding_mask = padding_mask.at[..., -2:].set(False)
+
+    manual_mask = combine_masks(padding_mask, causal_mask, dtype=q.dtype)
+
+    # jax.nn path with precombined mask and is_causal=False
+    attn_jax = nnx.dot_product_attention(
+      query=q,
+      key=k,
+      value=v,
+      mask=manual_mask,
+      is_causal=False,
+      deterministic=True,
+      module=None,
+    )
+
+    class DummyModule(nnx.Module):
+      pass
+
+    # nnx path with padding mask and is_causal=True (internally combines them)
+    attn_manual = nnx.dot_product_attention(
+      query=q,
+      key=k,
+      value=v,
+      mask=padding_mask,
+      is_causal=True,
+      deterministic=True,
+      module=DummyModule(),
+    )
+
+    np.testing.assert_allclose(attn_jax, attn_manual, atol=1e-6)
 
 # TODO: add all possible constructor argument values to parameterized.product
 class TestLinenConsistency(parameterized.TestCase):
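
As the in-test comments indicate, the new test_causal_mask_equivalence case exercises both code paths: the jax.nn.dot_product_attention branch (reached here with module=None) given a mask combined by hand, and the dot_product_attention_weights branch (forced by passing a module) given only the padding mask plus is_causal=True, asserting the outputs match to atol=1e-6. Assuming the usual pytest-based workflow for the Flax test suite, the case can be run in isolation with something like pytest tests/nnx/nn/attention_test.py -k test_causal_mask_equivalence.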
