HLO verifier added between the pre-scheduling and post-scheduling pipelines breaks compilation with mixed precision #32222

@shenlongtang

Description

The newly added HLO verifier (commit) now breaks compilation of a collective-permute whose operands mix floating-point precisions.

Is this intended?
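For reference, the condition the verifier rejects can be approximated in plain Python. In the traceback further down, a single collective-permute-start carries both a bf16 and an f32 buffer in its tuple shape, and the message says mixed precision is disallowed. The sketch below is a hypothetical illustration of that condition as the error message describes it. It is not XLA's implementation, and `FLOAT_TYPES`, `float_types_in_shape` and `is_mixed_precision` are made-up names.

```python
import re

# Hypothetical subset of HLO floating-point element type names, for illustration only.
FLOAT_TYPES = {"f8e4m3fn", "f8e5m2", "bf16", "f16", "f32", "f64"}

def float_types_in_shape(shape: str) -> set:
    """Return the floating-point element types named in an HLO shape string."""
    # Each array shape is written as <element type>[<dims>], e.g. "bf16[4]{0}".
    return {t for t in re.findall(r"([a-z0-9]+)\[", shape) if t in FLOAT_TYPES}

def is_mixed_precision(shape: str) -> bool:
    """True when more than one floating-point precision appears in the shape."""
    return len(float_types_in_shape(shape)) > 1

# Shape of the collective-permute-start taken from the error message.
failing = "((bf16[4]{0}, f32[4]{0}), (bf16[4]{0}, f32[4]{0}))"
print(sorted(float_types_in_shape(failing)), is_mixed_precision(failing))  # ['bf16', 'f32'] True
```

A tuple holding only f32 buffers would pass this check, which is why casting everything to one precision (or keeping the buffers in separate collectives) sidesteps the error.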

Steps to reproduce with JAX

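```python
import jax
import jax.numpy as jnp

# Ring permutation across 4 devices: 0 -> 1 -> 2 -> 3 -> 0.
perm = [(0, 1), (1, 2), (2, 3), (3, 0)]

def permute_ring_fn(x):
  # x is a (bf16, f32) tuple; ppermute applies the permutation to every leaf.
  return jax.lax.ppermute(x, axis_name='i', perm=perm)

# Two inputs with different floating-point precisions (needs at least 4 devices).
input_a = jnp.arange(16, dtype=jnp.bfloat16).reshape(4, 4)
input_b = jnp.arange(16, 32, dtype=jnp.float32).reshape(4, 4)

permute_ring = jax.pmap(permute_ring_fn, axis_name='i')
permute_ring.lower((input_a, input_b)).compile()
```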
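To see what JAX hands to XLA before any XLA passes (and therefore before the HLO verifier) run, the lowered module can be printed without calling `.compile()`. If the bf16 and f32 arrays appear here as separate `collective_permute` ops, the single mixed bf16/f32 collective-permute-start in the error was presumably formed later, inside XLA. The `XLA_FLAGS` line is an assumption added only so the snippet can run on a CPU-only machine by emulating 4 host devices; it is not needed on the 8-GPU machine.

```python
import os

# Assumption for CPU-only machines: emulate 4 host devices.
# Must be set before jax is imported; harmless on the GPU machine.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp

perm = [(0, 1), (1, 2), (2, 3), (3, 0)]

def permute_ring_fn(x):
  return jax.lax.ppermute(x, axis_name='i', perm=perm)

input_a = jnp.arange(16, dtype=jnp.bfloat16).reshape(4, 4)
input_b = jnp.arange(16, 32, dtype=jnp.float32).reshape(4, 4)

# lower() stops before XLA's optimization pipeline, so the verifier does not run here.
lowered = jax.pmap(permute_ring_fn, axis_name='i').lower((input_a, input_b))
module_text = lowered.as_text()

# Print only the collective-permute lines of the StableHLO module.
for line in module_text.splitlines():
  if "collective_permute" in line:
    print(line.strip())
```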

Expected result (confirmed with JAX 0.5.3 and an older XLA commit): the function compiles successfully.
Actual result (with JAX 0.7.1 and a recent XLA commit):

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/stages.py", line 628, in compile
    self._lowering.compile(**kw),  # pytype: disable=wrong-keyword-args
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/interpreters/pxla.py", line 1000, in compile
    executable = UnloadedPmapExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/interpreters/pxla.py", line 1165, in from_hlo
    compiled = compiler.compile_or_get_cached(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 494, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 762, in _compile_and_write_cache
    executable = backend_compile_and_load(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 388, in backend_compile_and_load
    raise e
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 378, in backend_compile_and_load
    return backend.compile_and_load(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: during context [hlo verifier]: Seen floating point types of different precisions in %collective-permute-start = ((bf16[4]{0}, f32[4]{0}), (bf16[4]{0}, f32[4]{0})) collective-permute-start(%bitcast.4, %bitcast.2.0), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}, metadata={scheduling_name="collective-permute-start"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":false,"is_pipelined":false,"backend":"DEFAULT"},"force_earliest_schedule":false,"reification_cost":[]}, but mixed precision is disallowed.
```
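A possible stopgap, which has not been verified against the XLA commit that introduced the verifier, is to cast both operands to one common precision before the permute and cast back afterwards, so that no collective-permute mixes bf16 and f32. The cost is that the bf16 payload travels as f32, doubling its bytes on the wire. Later XLA passes could in principle move the converts around, so this should be checked on the affected build. `permute_ring_common_dtype` is a hypothetical name, and the `XLA_FLAGS` line only emulates 4 devices for CPU-only testing.

```python
import os

# Assumption for CPU-only machines: emulate 4 host devices (set before importing jax).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp

perm = [(0, 1), (1, 2), (2, 3), (3, 0)]

def permute_ring_common_dtype(x):
  a, b = x
  # Promote both operands to one precision (bf16 + f32 -> f32) so the
  # collective-permute no longer mixes floating-point types.
  common = jnp.promote_types(a.dtype, b.dtype)
  a_out, b_out = jax.lax.ppermute(
      (a.astype(common), b.astype(common)), axis_name='i', perm=perm)
  # Cast back so callers still see the original dtypes.
  return a_out.astype(a.dtype), b_out.astype(b.dtype)

input_a = jnp.arange(16, dtype=jnp.bfloat16).reshape(4, 4)
input_b = jnp.arange(16, 32, dtype=jnp.float32).reshape(4, 4)

permute_ring = jax.pmap(permute_ring_common_dtype, axis_name='i')
out_a, out_b = permute_ring((input_a, input_b))
```

With this ring permutation, device j receives the row from device j - 1, so each output equals its input rolled by one row along axis 0.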

System Info

```
$ nvidia-smi
Fri Oct  3 20:41:17 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.163.01             Driver Version: 550.163.01     CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:10:1C.0 Off |                    0 |
| N/A   65C    P0            108W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00000000:10:1D.0 Off |                    0 |
| N/A   58C    P0            118W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          On  |   00000000:20:1C.0 Off |                    0 |
| N/A   66C    P0            116W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          On  |   00000000:20:1D.0 Off |                    0 |
| N/A   56C    P0            103W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100-SXM4-40GB          On  |   00000000:90:1C.0 Off |                    0 |
| N/A   65C    P0            111W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A100-SXM4-40GB          On  |   00000000:90:1D.0 Off |                    0 |
| N/A   55C    P0            105W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A100-SXM4-40GB          On  |   00000000:A0:1C.0 Off |                    0 |
| N/A   67C    P0            116W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A100-SXM4-40GB          On  |   00000000:A0:1D.0 Off |                    0 |
| N/A   59C    P0            113W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
```

Labels: err:Build (Build or compilation failed)