
[BUG] Using qml.cond with qml.qjit causes KeyError #8627

@comp-phys-marc

Description


Expected behavior

We expect to be able to decorate, with qml.qjit, a QNode that contains an Operator whose decomposition uses qml.cond.

Actual behavior

Even in very simple cases, a KeyError is raised when qml.cond is used inside an operator decomposition in a qml.qjit context.

Additional information

This error surfaced when a Catalyst compilation failed because of a tracing bug. The breaking change was identified in PennyLane dev24, and the error has been isolated. The following code minimally reproduces the KeyError, which appears to occur when qml.cond is used with qml.qjit. A similar issue may exist for jax.lax.cond (see #2768).

Source code

import pennylane as qml
from pennylane.operation import Operation
from pennylane.ops import cond, RX

class Debug(Operation):

    @staticmethod
    def compute_decomposition(wires):
        angle = cond(True, lambda: 0.0, lambda: 1.0)()
        return [RX(angle, wires=wires)]

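# circuit_sv takes no arguments, so Catalyst compiles it ahead of time and the
# KeyError is raised when the decorator is applied (see QJIT.aot_compile below).
@qml.qjit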
@qml.qnode(qml.device("null.qubit", wires=2))
def circuit_sv():
    Debug(wires=[0])
    return qml.expval(qml.PauliX(0))
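For comparison, the sketch below (not part of the original report) runs the same predicate and branches through jax.lax.cond under jax.jit, with no PennyLane operator, QNode, or Catalyst involved. Plain JAX is expected to trace this without error; if it does, that suggests the problem lies in how qml.cond, operator decompositions, and qml.qjit interact rather than in the conditional itself. The function name cond_angle is illustrative only.

```python
import jax


@jax.jit
def cond_angle():
    # Same predicate and branches as the cond in Debug.compute_decomposition,
    # but using jax.lax.cond directly (zero operands), outside any QNode.
    return jax.lax.cond(True, lambda: 0.0, lambda: 1.0)


angle = cond_angle()
print(float(angle))  # 0.0, the true branch is selected
```

Note that this baseline does not enable 64-bit mode, so the result is a float32 scalar rather than the float64[] value named in the KeyError.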

Tracebacks

Traceback (most recent call last):
  File "/Users/marcus.edwards/Documents/pennylane/tmp/temp.py", line 242, in <module>
    @qml.qjit
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/compiler/qjit_api.py", line 297, in qjit
    return qjit_loader(fn=fn, *args, **kwargs)
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 502, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 564, in __init__
    self.aot_compile()
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 622, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 770, in capture
    jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 652, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py", line 484, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 760, in closure
    return QFunc.__call__(
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 338, in __call__
    res_flat = quantum_kernel_p.bind(
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 307, in _eval_quantum
    trace_result = trace_quantum_function(
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 1696, in trace_quantum_function
    transformed_results, classical_return_indices, num_mcm = _trace_quantum_step(
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 1624, in _trace_quantum_step
    qrp_out = trace_quantum_operations(
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 881, in trace_quantum_operations
    qrp2 = bind_native_operation(qrp, op, [], [])
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 848, in bind_native_operation
    qubits2 = qinst_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: KeyError: Var(id=13385235584):float64[]

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/marcus.edwards/Documents/pennylane/tmp/temp.py", line 242, in <module>
    @qml.qjit
     ^^^^^^^^
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/compiler/qjit_api.py", line 297, in qjit
    return qjit_loader(fn=fn, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 502, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 564, in __init__
    self.aot_compile()
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 622, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
                                                              ^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 770, in capture
    jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
                                        ^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 652, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py", line 484, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 760, in closure
    return QFunc.__call__(
           ^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 338, in __call__
    res_flat = quantum_kernel_p.bind(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 2690, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 552, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 2695, in bind_with_trace
    return trace.process_call(self, fun, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2050, in process_call
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 328, in _eval_quantum
    res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 630, in eval_jaxpr
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 613, in read
    return v.val if isinstance(v, Literal) else env[v]
                                                ~~~^^^
KeyError: Var(id=13385235584):float64[]
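The innermost frames show where the error is raised: JAX's eval_jaxpr looks each equation input up in an environment dict through its `read` helper, and a variable that was never written to that dict raises KeyError. One plausible but unconfirmed reading is that the float64[] Var produced by the cond inside Debug.compute_decomposition is used by the RX operation but was never bound in the quantum kernel's jaxpr that Catalyst's _eval_quantum re-evaluates. The toy interpreter below is a simplified stand-in written for illustration only (Var, Literal, and eval_program are hypothetical names, not JAX's classes); it just mimics the lookup pattern shown in the traceback.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a named intermediate value in a jaxpr."""
    name: str


@dataclass(frozen=True)
class Literal:
    """Hypothetical stand-in for a constant embedded in an equation."""
    val: float


def eval_program(invars, eqns, outvars, *args):
    """Evaluate a straight-line list of (fn, inputs, output) equations."""
    env = {}

    def read(v):
        # Same pattern as the `read` frame above: literals carry their value,
        # variables must already have been written to env.
        return v.val if isinstance(v, Literal) else env[v]

    def write(v, val):
        env[v] = val

    for v, a in zip(invars, args):
        write(v, a)
    for fn, ins, out in eqns:
        write(out, fn(*map(read, ins)))
    return [read(v) for v in outvars]


x, y = Var("x"), Var("y")
ok = eval_program([x], [(lambda a, b: a + b, [x, Literal(1.0)], y)], [y], 2.0)
print(ok)  # [3.0]

# `angle` plays the role of a value traced somewhere else (hypothetically, the
# cond output) that was never bound in this program's environment.
angle = Var("angle")
try:
    eval_program([x], [(lambda a: 2 * a, [angle], y)], [y], 2.0)
except KeyError as err:
    print("KeyError:", err)
```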

System information

Name: pennylane
Version: 0.44.0.dev20
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: 
Author: 
Author-email: 
License-Expression: Apache-2.0
Location: /Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing_extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-qiskit, pennylane_catalyst, pennylane_lightning, pennylane_lightning_kokkos

Platform info:           macOS-26.1-arm64-arm-64bit
Python version:          3.12.12
Numpy version:           1.26.4
Scipy version:           1.15.2
JAX version:             0.6.2
Installed devices:
- default.clifford (pennylane-0.44.0.dev20)
- default.gaussian (pennylane-0.44.0.dev20)
- default.mixed (pennylane-0.44.0.dev20)
- default.qubit (pennylane-0.44.0.dev20)
- default.qutrit (pennylane-0.44.0.dev20)
- default.qutrit.mixed (pennylane-0.44.0.dev20)
- default.tensor (pennylane-0.44.0.dev20)
- null.qubit (pennylane-0.44.0.dev20)
- reference.qubit (pennylane-0.44.0.dev20)
- lightning.qubit (pennylane_lightning-0.44.0.dev11)
- nvidia.custatevec (pennylane_catalyst-0.14.0.dev24)
- nvidia.cutensornet (pennylane_catalyst-0.14.0.dev24)
- oqc.cloud (pennylane_catalyst-0.14.0.dev24)
- softwareq.qpp (pennylane_catalyst-0.14.0.dev24)
- braket.aws.ahs (amazon-braket-pennylane-plugin-1.33.5)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.33.5)
- braket.local.ahs (amazon-braket-pennylane-plugin-1.33.5)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.33.5)
- qiskit.aer (PennyLane-qiskit-0.43.0)
- qiskit.basicaer (PennyLane-qiskit-0.43.0)
- qiskit.basicsim (PennyLane-qiskit-0.43.0)
- qiskit.remote (PennyLane-qiskit-0.43.0)
- lightning.kokkos (pennylane_lightning_kokkos-0.44.0.dev11)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Metadata


Assignees

No one assigned

    Labels

    bug 🐛 Something isn't working

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
