Skip to content

Unimplemented primitive in Pallas GPU lowering: gather #29143

Open
@kylestach

Description

@kylestach

Description

I'm seeing an issue with the following code (appears in a pallas kernel):

q_group = q_output_groups[:, None]
kv_group = kv_output_groups[None, :]
group_mask = output_group_mask[q_group, kv_group]

Here's a minimal reproducible example:

from jax.experimental import pallas as pl
import jax
import jax.numpy as jnp
from functools import partial

def compute_attn_mask_kernel(q_group_ref, k_group_ref, output_group_mask_ref, o_ref, block_q, block_k):
    """Per-tile kernel: write mask[q_group[i], k_group[j]] for the current block.

    NOTE(review): the 2-D advanced-indexing line below is what lowers to the
    `gather` primitive that the Pallas Triton (GPU) backend reports as
    unimplemented in the traceback further down this issue.
    """
    # 1-D vectors of group ids for the current q tile and k tile.
    q_group = pl.load(q_group_ref, (slice(None),))
    k_group = pl.load(k_group_ref, (slice(None),))
    # 2-D (num_groups, num_groups) boolean lookup table for this batch element.
    output_group_mask = pl.load(output_group_mask_ref, (slice(None), slice(None)))

    # Advanced indexing with two broadcast index arrays -> a `gather` op,
    # which has no Pallas GPU lowering (the bug being reported).
    group_mask = output_group_mask[q_group[:, None], k_group[None, :]]

    pl.store(o_ref, (slice(None), slice(None)), group_mask)

def compute_attn_mask(q_group, k_group, output_group_mask, block_q, block_k):
    """Run the mask kernel over a (batch, q-blocks, k-blocks) grid.

    Returns a list with one boolean array of shape
    (batch_size, q_seq_len, k_seq_len).
    """
    batch_size, q_seq_len = q_group.shape
    k_seq_len = k_group.shape[1]
    num_groups = output_group_mask.shape[0]

    # Bind the static block sizes into the kernel up front.
    kernel = partial(compute_attn_mask_kernel, block_q=block_q, block_k=block_k)

    launch = pl.pallas_call(
        kernel,
        grid=(batch_size, q_seq_len // block_q, k_seq_len // block_k),
        in_specs=[
            # One q-block / k-block of group ids per grid step (batch dim squeezed).
            pl.BlockSpec((None, block_q), lambda i, j, k: (i, j)),
            pl.BlockSpec((None, block_k), lambda i, j, k: (i, k)),
            # Full lookup table for the current batch element.
            pl.BlockSpec((None, num_groups, num_groups), lambda i, j, k: (i, 0, 0)),
        ],
        out_specs=[pl.BlockSpec((None, block_q, block_k), lambda i, j, k: (i, j, k))],
        out_shape=[jax.ShapeDtypeStruct(shape=(batch_size, q_seq_len, k_seq_len), dtype=jnp.bool_)],
    )
    return launch(q_group, k_group, output_group_mask)

# Repro configuration: 128-long q/k sequences tiled into 16-wide blocks,
# so each sequence contributes 128 // 16 = 8 group ids per batch element.
batch_size = 8
q_seq_len = 128
k_seq_len = 128
block_q = 16
block_k = 16
num_groups = 16
# One random group id per (batch, block) position, in [0, num_groups).
q_group = jnp.array(jax.random.randint(jax.random.PRNGKey(0), (batch_size, q_seq_len // block_q), 0, num_groups))
k_group = jnp.array(jax.random.randint(jax.random.PRNGKey(1), (batch_size, k_seq_len // block_k), 0, num_groups))
# Random boolean (num_groups x num_groups) visibility table per batch element.
output_group_mask = jnp.array(
    jax.random.randint(jax.random.PRNGKey(2), (batch_size, num_groups, num_groups), 0, 2), dtype=jnp.bool_
)

# Triggers the NotImplementedError below when lowered for GPU.
attn_mask = compute_attn_mask(q_group, k_group, output_group_mask, block_q, block_k)
print(attn_mask)

Error message:

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<PATH>/test.py", line 97, in <module>
    attn_mask = compute_attn_mask(q_group, k_group, output_group_mask, block_q, block_k)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PATH>/test.py", line 69, in compute_attn_mask
    return pl.pallas_call(
           ^^^^^^^^^^^^^^^
  File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1317, in _pallas_call_lowering
    return mlir.lower_per_platform(ctx, "pallas_call",
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1313, in gpu_lowering
    return pallas_call_registration.pallas_call_lowering(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/triton/pallas_call_registration.py", line 92, in pallas_call_lowering
    lowering_result = lowering.lower_jaxpr_to_triton_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 370, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<PATH>/.venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 405, in lower_jaxpr_to_triton_ir
    raise NotImplementedError(
NotImplementedError: Unimplemented primitive in Pallas GPU lowering: gather. Please file an issue on https://github.com/jax-ml/jax/issues.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

System info (python version, jaxlib version, accelerator, etc.)

jax/jaxlib 0.6.1, running on an H100, CUDA version 12.2

Metadata

Metadata

Assignees

Labels

bug: Something isn't working. pallas: Issues pertaining to Pallas (GPU or TPU).

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions