Description
I'm seeing an issue with the following code (it appears inside a Pallas kernel):
q_group = q_output_groups[:, None]
kv_group = kv_output_groups[None, :]
group_mask = output_group_mask[q_group, kv_group]
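As a quick standalone check (plain JAX, not Pallas), the same advanced-indexing pattern traces to the gather primitive, which appears to be what the Pallas GPU lowering rejects below:

import jax
import jax.numpy as jnp

# Toy shapes, only to inspect the jaxpr of the two-index-array lookup.
mask = jnp.zeros((4, 4), dtype=jnp.bool_)
qg = jnp.zeros((8,), dtype=jnp.int32)
kg = jnp.zeros((8,), dtype=jnp.int32)
# The printed jaxpr contains a `gather` op for this indexing pattern.
print(jax.make_jaxpr(lambda m, q, k: m[q[:, None], k[None, :]])(mask, qg, kg))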
Here's a minimal reproducible example:
from jax.experimental import pallas as pl
import jax
import jax.numpy as jnp
from functools import partial
def compute_attn_mask_kernel(q_group_ref, k_group_ref, output_group_mask_ref, o_ref, block_q, block_k):
    q_group = pl.load(q_group_ref, (slice(None),))
    k_group = pl.load(k_group_ref, (slice(None),))
    output_group_mask = pl.load(output_group_mask_ref, (slice(None), slice(None)))
    # Advanced indexing with two index arrays: this is the line that lowers to `gather`.
    group_mask = output_group_mask[q_group[:, None], k_group[None, :]]
    pl.store(o_ref, (slice(None), slice(None)), group_mask)
def compute_attn_mask(q_group, k_group, output_group_mask, block_q, block_k):
    batch_size, q_seq_len = q_group.shape
    k_seq_len = k_group.shape[1]
    grid = (batch_size, q_seq_len // block_q, k_seq_len // block_k)
    num_groups = output_group_mask.shape[-1]
    return pl.pallas_call(
        partial(
            compute_attn_mask_kernel,
            block_q=block_q,
            block_k=block_k,
        ),
        grid=grid,
        in_specs=[
            pl.BlockSpec((None, block_q), lambda i, j, k: (i, j)),
            pl.BlockSpec((None, block_k), lambda i, j, k: (i, k)),
            pl.BlockSpec((None, num_groups, num_groups), lambda i, j, k: (i, 0, 0)),
        ],
        out_specs=[pl.BlockSpec((None, block_q, block_k), lambda i, j, k: (i, j, k))],
        out_shape=[jax.ShapeDtypeStruct(shape=(batch_size, q_seq_len, k_seq_len), dtype=jnp.bool_)],
    )(q_group, k_group, output_group_mask)
batch_size = 8
q_seq_len = 128
k_seq_len = 128
block_q = 16
block_k = 16
num_groups = 16
q_group = jnp.array(jax.random.randint(jax.random.PRNGKey(0), (batch_size, q_seq_len // block_q), 0, num_groups))
k_group = jnp.array(jax.random.randint(jax.random.PRNGKey(1), (batch_size, k_seq_len // block_k), 0, num_groups))
output_group_mask = jnp.array(
    jax.random.randint(jax.random.PRNGKey(2), (batch_size, num_groups, num_groups), 0, 2), dtype=jnp.bool_
)
attn_mask = compute_attn_mask(q_group, k_group, output_group_mask, block_q, block_k)
print(attn_mask)
Error message:
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<PATH>/test.py", line 97, in <module>
attn_mask = compute_attn_mask(q_group, k_group, output_group_mask, block_q, block_k)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<PATH>/test.py", line 69, in compute_attn_mask
return pl.pallas_call(
^^^^^^^^^^^^^^^
File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1317, in _pallas_call_lowering
return mlir.lower_per_platform(ctx, "pallas_call",
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1313, in gpu_lowering
return pallas_call_registration.pallas_call_lowering(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/triton/pallas_call_registration.py", line 92, in pallas_call_lowering
lowering_result = lowering.lower_jaxpr_to_triton_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 370, in lower_jaxpr_to_triton_module
() = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 405, in lower_jaxpr_to_triton_ir
raise NotImplementedError(
NotImplementedError: Unimplemented primitive in Pallas GPU lowering: gather. Please file an issue on https://github.com/jax-ml/jax/issues.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
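A possible workaround, sketched in plain JAX for a single batch element and reusing the arrays from the repro above (the function name is illustrative; I have not verified this inside a Pallas kernel): the lookup can be rewritten without gather via one-hot comparisons and matmuls.

import jax.numpy as jnp

def lookup_without_gather(q_group, k_group, output_group_mask):
    # q_group: (q_blocks,), k_group: (k_blocks,), output_group_mask: (num_groups, num_groups)
    num_groups = output_group_mask.shape[-1]
    groups = jnp.arange(num_groups)
    # One-hot encodings via broadcasted comparisons instead of advanced indexing.
    q_onehot = (q_group[:, None] == groups[None, :]).astype(jnp.float32)
    k_onehot = (k_group[:, None] == groups[None, :]).astype(jnp.float32)
    table = output_group_mask.astype(jnp.float32)
    # result[i, j] == output_group_mask[q_group[i], k_group[j]]
    return (q_onehot @ table @ k_onehot.T) > 0

# Matches the advanced-indexing result for the first batch element of the repro.
assert jnp.array_equal(
    lookup_without_gather(q_group[0], k_group[0], output_group_mask[0]),
    output_group_mask[0][q_group[0][:, None], k_group[0][None, :]],
)

The same comparisons and contractions should be expressible with ops the Triton lowering already supports, though I haven't confirmed that in the kernel itself.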
System info (python version, jaxlib version, accelerator, etc.)
jax/jaxlib: 0.6.1, running on an H100, CUDA 12.2