Open
Labels: bug 🐛 (Something isn't working)
Description
Expected behavior
We expect to be able to apply qml.qjit to a QNode that contains an Operator whose decomposition uses qml.cond.
Actual behavior
A KeyError is raised, even in very simple cases, when qml.cond is used inside an operator decomposition under qml.qjit.
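The final frames of the traceback show where the error surfaces: JAX's eval_jaxpr reads each equation's inputs through a read helper that returns env[v] for any non-literal variable. A KeyError on a Var therefore means an equation consumed a variable that no input or earlier equation of that jaxpr had bound. The stdlib-only sketch below models just that lookup; the Var, Literal, Eqn and eval_eqns names are hypothetical stand-ins, not JAX's own classes. It does not establish the root cause of this issue. One unconfirmed hypothesis is that the angle produced by tracing cond inside compute_decomposition is not bound in the environment of the jaxpr being evaluated.

```python
# Hypothetical model of a jaxpr-style interpreter; NOT JAX's implementation.
from dataclasses import dataclass, field
from itertools import count
from typing import Any, Callable

_ids = count()


@dataclass(frozen=True)
class Var:
    """A variable keyed by identity (its id), not by the value it holds."""
    id: int = field(default_factory=lambda: next(_ids))


@dataclass(frozen=True)
class Literal:
    """A constant embedded directly in an equation."""
    val: Any


@dataclass
class Eqn:
    fn: Callable[..., Any]  # the "primitive" applied to the inputs
    invars: list            # Vars or Literals consumed by this equation
    outvar: Var             # Var this equation defines


def eval_eqns(eqns, invars, args):
    env = {}  # maps Var -> concrete value, filled as equations run

    def read(v):
        # Same shape as the failing line in the traceback: literals carry
        # their own value, every other variable must already be in env.
        return v.val if isinstance(v, Literal) else env[v]

    for v, a in zip(invars, args):
        env[v] = a
    for eqn in eqns:
        env[eqn.outvar] = eqn.fn(*map(read, eqn.invars))
    return env


x, angle, out = Var(), Var(), Var()

# Well-formed program: `angle` is defined before it is consumed.
ok = [
    Eqn(lambda c: c, [Literal(0.0)], angle),
    Eqn(lambda a, b: a + b, [x, angle], out),
]
print(eval_eqns(ok, [x], [1.0])[out])  # 1.0

# Ill-formed program: `angle` is consumed, but nothing in this equation
# list (or its inputs) defines it, so read() raises KeyError(angle).
bad = [Eqn(lambda a, b: a + b, [x, angle], out)]
try:
    eval_eqns(bad, [x], [1.0])
except KeyError as err:
    print("KeyError:", err)
```

The point of the model is that the lookup fails on variable identity: a value that exists somewhere else (for example, in a separately traced scope) is still missing from the environment being read.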
Additional information
This error was noticed when a Catalyst compilation failed because of a tracing bug. The breaking change was traced to PennyLane dev24, and the error has been isolated. The following code minimally reproduces the KeyError, which seems to occur whenever qml.cond is used with qml.qjit. A similar issue may exist for jax.lax.cond (possibly related: #2768).
Source code
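```python
import pennylane as qml
from pennylane.operation import Operation
from pennylane.ops import cond, RX


class Debug(Operation):
    @staticmethod
    def compute_decomposition(wires):
        # The conditional inside the decomposition is what triggers the error.
        angle = cond(True, lambda: 0.0, lambda: 1.0)()
        return [RX(angle, wires=wires)]


# The QNode takes no arguments, so qjit compiles ahead of time and the
# KeyError is raised at decoration time.
@qml.qjit
@qml.qnode(qml.device("null.qubit", wires=2))
def circuit_sv():
    Debug(wires=[0])
    return qml.expval(qml.PauliX(0))
```
Tracebacks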
Traceback (most recent call last):
File "/Users/marcus.edwards/Documents/pennylane/tmp/temp.py", line 242, in <module>
@qml.qjit
File "/Users/marcus.edwards/Documents/pennylane/pennylane/compiler/qjit_api.py", line 297, in qjit
return qjit_loader(fn=fn, *args, **kwargs)
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 502, in qjit
return QJIT(fn, CompileOptions(**kwargs))
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 564, in __init__
self.aot_compile()
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 622, in aot_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 770, in capture
jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 652, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py", line 484, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 760, in closure
return QFunc.__call__(
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 338, in __call__
res_flat = quantum_kernel_p.bind(
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 307, in _eval_quantum
trace_result = trace_quantum_function(
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 1696, in trace_quantum_function
transformed_results, classical_return_indices, num_mcm = _trace_quantum_step(
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 1624, in _trace_quantum_step
qrp_out = trace_quantum_operations(
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 881, in trace_quantum_operations
qrp2 = bind_native_operation(qrp, op, [], [])
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 848, in bind_native_operation
qubits2 = qinst_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: KeyError: Var(id=13385235584):float64[]
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/marcus.edwards/Documents/pennylane/tmp/temp.py", line 242, in <module>
@qml.qjit
^^^^^^^^
File "/Users/marcus.edwards/Documents/pennylane/pennylane/compiler/qjit_api.py", line 297, in qjit
return qjit_loader(fn=fn, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 502, in qjit
return QJIT(fn, CompileOptions(**kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 564, in __init__
self.aot_compile()
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 622, in aot_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 770, in capture
jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_tracer.py", line 652, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py", line 484, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/profiler.py", line 354, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
return self.f_transformed(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/jit.py", line 760, in closure
return QFunc.__call__(
^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/Documents/pennylane/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 338, in __call__
res_flat = quantum_kernel_p.bind(
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 2690, in bind
return self._true_bind(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 552, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 2695, in bind_with_trace
return trace.process_call(self, fun, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2050, in process_call
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/profiler.py", line 354, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
return self.f_transformed(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/catalyst/qfunc.py", line 328, in _eval_quantum
res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 630, in eval_jaxpr
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages/jax/_src/core.py", line 613, in read
return v.val if isinstance(v, Literal) else env[v]
~~~^^^
KeyError: Var(id=13385235584):float64[]
System information
Name: pennylane
Version: 0.44.0.dev20
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page:
Author:
Author-email:
License-Expression: Apache-2.0
Location: /Users/marcus.edwards/miniforge3/envs/pennylane/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing_extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-qiskit, pennylane_catalyst, pennylane_lightning, pennylane_lightning_kokkos
Platform info: macOS-26.1-arm64-arm-64bit
Python version: 3.12.12
Numpy version: 1.26.4
Scipy version: 1.15.2
JAX version: 0.6.2
Installed devices:
- default.clifford (pennylane-0.44.0.dev20)
- default.gaussian (pennylane-0.44.0.dev20)
- default.mixed (pennylane-0.44.0.dev20)
- default.qubit (pennylane-0.44.0.dev20)
- default.qutrit (pennylane-0.44.0.dev20)
- default.qutrit.mixed (pennylane-0.44.0.dev20)
- default.tensor (pennylane-0.44.0.dev20)
- null.qubit (pennylane-0.44.0.dev20)
- reference.qubit (pennylane-0.44.0.dev20)
- lightning.qubit (pennylane_lightning-0.44.0.dev11)
- nvidia.custatevec (pennylane_catalyst-0.14.0.dev24)
- nvidia.cutensornet (pennylane_catalyst-0.14.0.dev24)
- oqc.cloud (pennylane_catalyst-0.14.0.dev24)
- softwareq.qpp (pennylane_catalyst-0.14.0.dev24)
- braket.aws.ahs (amazon-braket-pennylane-plugin-1.33.5)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.33.5)
- braket.local.ahs (amazon-braket-pennylane-plugin-1.33.5)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.33.5)
- qiskit.aer (PennyLane-qiskit-0.43.0)
- qiskit.basicaer (PennyLane-qiskit-0.43.0)
- qiskit.basicsim (PennyLane-qiskit-0.43.0)
- qiskit.remote (PennyLane-qiskit-0.43.0)
- lightning.kokkos (pennylane_lightning_kokkos-0.44.0.dev11)
Existing GitHub issues
- I have searched existing GitHub issues to make sure the issue does not already exist.