Open
Labels: err:Build (Build or compilation failed)
Description
The HLO verifier check newly added here (commit) now breaks compilation of collective-permute ops whose operands mix floating-point precisions (here bf16 and f32).
Is this intended?
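The error message further down ("Seen floating point types of different precisions ... but mixed precision is disallowed") suggests the new check rejects a collective-permute whose operands carry more than one floating-point dtype. Below is a minimal Python sketch of that rule under this assumption. `has_mixed_float_precision` is a hypothetical helper, not XLA's implementation. It also treats non-float operands as exempt, which is another assumption drawn from the message's wording.

```python
import jax
import jax.numpy as jnp


def has_mixed_float_precision(operands):
    """Hypothetical model of the verifier rule, not XLA's actual code.

    Collects the distinct floating-point dtypes among the leaves of
    `operands` and reports whether there is more than one. Non-float
    leaves are ignored (an assumption based on the error's wording).
    """
    float_dtypes = {
        jnp.dtype(leaf.dtype)
        for leaf in jax.tree_util.tree_leaves(operands)
        if jnp.issubdtype(leaf.dtype, jnp.floating)
    }
    return len(float_dtypes) > 1


input_a = jnp.arange(16, dtype=jnp.bfloat16).reshape(4, 4)
input_b = jnp.arange(16, 32, dtype=jnp.float32).reshape(4, 4)

# The reproducer's operands mix bf16 and f32, so this prints True.
print(has_mixed_float_precision((input_a, input_b)))
```

Because it works on any pytree of arrays, a check like this could serve as a pre-flight test on the values passed to `ppermute`.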
Steps to reproduce with JAX:
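import jax
import jax.numpy as jnp

# Ring permutation over 4 devices: device i sends to device (i + 1) % 4.
perm = [(0, 1), (1, 2), (2, 3), (3, 0)]

def permute_ring_fn(x):
    # x is a (bf16, f32) tuple; ppermute sends each leaf around the ring.
    return jax.lax.ppermute(x, axis_name='i', perm=perm)

input_a = jnp.arange(16, dtype=jnp.bfloat16).reshape(4, 4)
input_b = jnp.arange(16, 32, dtype=jnp.float32).reshape(4, 4)
permute_ring = jax.pmap(permute_ring_fn, axis_name='i')
# With a recent XLA this fails in the HLO verifier (see traceback below).
permute_ring.lower((input_a, input_b)).compile()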
Expected result (confirmed with JAX 0.5.3 and an older XLA commit): the function compiles successfully.
Actual result (JAX 0.7.1 with a recent XLA commit):
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/stages.py", line 628, in compile
    self._lowering.compile(**kw), # pytype: disable=wrong-keyword-args
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/interpreters/pxla.py", line 1000, in compile
    executable = UnloadedPmapExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/interpreters/pxla.py", line 1165, in from_hlo
    compiled = compiler.compile_or_get_cached(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 494, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 762, in _compile_and_write_cache
    executable = backend_compile_and_load(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 388, in backend_compile_and_load
    raise e
  File "/scratch/shentang-sandbox/code/jax_071/jax/_src/compiler.py", line 378, in backend_compile_and_load
    return backend.compile_and_load(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: during context [hlo verifier]: Seen floating point types of different precisions in %collective-permute-start = ((bf16[4]{0}, f32[4]{0}), (bf16[4]{0}, f32[4]{0})) collective-permute-start(%bitcast.4, %bitcast.2.0), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}, metadata={scheduling_name="collective-permute-start"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":false,"is_pipelined":false,"backend":"DEFAULT"},"force_earliest_schedule":false,"reification_cost":[]}, but mixed precision is disallowed.
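One possible workaround follows from the assumption that the verifier objects only to differing floating-point precisions within a single collective-permute, as the message suggests: upcast the bf16 operand to f32 before the permute and cast it back afterwards. This is an untested sketch, not a confirmed fix. `permute_ring_common_precision` is a hypothetical name, and the `jax.vmap` fallback (vmap also supports named-axis collectives such as `ppermute`) exists only so the snippet runs on machines with fewer than 4 devices.

```python
import jax
import jax.numpy as jnp

# Same ring permutation as in the reproducer.
perm = [(0, 1), (1, 2), (2, 3), (3, 0)]


def permute_ring_common_precision(x):
    a, b = x
    # Assumed workaround: give both operands the same float precision so that
    # no collective-permute (even one XLA combines from both leaves) mixes
    # bf16 and f32.
    a_up = a.astype(jnp.float32)
    a_out, b_out = jax.lax.ppermute((a_up, b), axis_name='i', perm=perm)
    # Cast back so callers still see the original bf16 dtype.
    return a_out.astype(a.dtype), b_out


input_a = jnp.arange(16, dtype=jnp.bfloat16).reshape(4, 4)
input_b = jnp.arange(16, 32, dtype=jnp.float32).reshape(4, 4)

# pmap needs at least 4 devices for a leading axis of size 4; otherwise fall
# back to vmap with the same axis name so the sketch still runs.
if jax.device_count() >= 4:
    permute_ring = jax.pmap(permute_ring_common_precision, axis_name='i')
else:
    permute_ring = jax.vmap(permute_ring_common_precision, axis_name='i')

# Row i of each output holds row (i - 1) % 4 of the corresponding input.
out_a, out_b = permute_ring((input_a, input_b))
```

The cost is that the bf16 payload travels as f32, doubling its communication volume. This only sidesteps the check and does not answer whether the new restriction is intended.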
System Info
$ nvidia-smi
Fri Oct 3 20:41:17 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100-SXM4-40GB On | 00000000:10:1C.0 Off | 0 |
| N/A 65C P0 108W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A100-SXM4-40GB On | 00000000:10:1D.0 Off | 0 |
| N/A 58C P0 118W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA A100-SXM4-40GB On | 00000000:20:1C.0 Off | 0 |
| N/A 66C P0 116W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA A100-SXM4-40GB On | 00000000:20:1D.0 Off | 0 |
| N/A 56C P0 103W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA A100-SXM4-40GB On | 00000000:90:1C.0 Off | 0 |
| N/A 65C P0 111W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA A100-SXM4-40GB On | 00000000:90:1D.0 Off | 0 |
| N/A 55C P0 105W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA A100-SXM4-40GB On | 00000000:A0:1C.0 Off | 0 |
| N/A 67C P0 116W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA A100-SXM4-40GB On | 00000000:A0:1D.0 Off | 0 |
| N/A 59C P0 113W / 400W | 1MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+