Skip to content

[BUG] Confusing error message when sampling an ExpectationMP with Catalyst #8378

@isaacdevlugt

Description

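@isaacdevlugt

The circuit below passes a measurement process to `qml.sample` (`qml.sample(qml.expval(qml.Y(0)))`). It runs inside a `qml.qjit`-compiled QNode with program capture enabled, and raises `NotImplementedError: Measurements of mcms are not yet supported.` This message is misleading. The circuit does not sample a mid-circuit measurement (MCM); the actual problem is that an `ExpectationMP` is passed to `qml.sample`.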

Expected behavior

I expect the code below to raise an error that describes the actual problem: `qml.sample` was given an `ExpectationMP` (the result of `qml.expval`) instead of an observable.

Actual behavior

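Raises `NotImplementedError: Measurements of mcms are not yet supported.`, which implies that a mid-circuit measurement (MCM) is being sampled. The traceback shows why. `PLxPRToQuantumJaxprInterpreter.interpret_measurement` in `catalyst/from_plxpr/from_plxpr.py` raises this error for any measurement that has an `mv` or whose `obs` is not a `qml.operation.Operator`. An `ExpectationMP` passed to `qml.sample` falls into that branch even though no MCM is involved.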

Additional information

No response

Source code

# Minimal reproduction: passing an ExpectationMP (the result of qml.expval) to
# qml.sample inside a qml.qjit-compiled QNode with program capture enabled
# produces a misleading "Measurements of mcms" error.
import pennylane as qml

dev = qml.device("lightning.qubit", wires=1)

# Program capture must be enabled before the decorators below are applied.
# The traceback shows the error is raised while applying @qml.qjit
# (QJIT.__init__ -> aot_compile -> capture), i.e. at decoration time.
qml.capture.enable()

@qml.qjit
@qml.set_shots(1000)
@qml.qnode(dev, mcm_method="one-shot")
def circuit():
    """Apply H, postselect a mid-circuit measurement on 1, then sample an expval."""
    qml.H(0)
    # Mid-circuit measurement with postselection; m0 is not used afterwards.
    m0 = qml.measure(0, postselect=1)
    # The construct under test: qml.sample is given an ExpectationMP instead of
    # an observable. Catalyst's interpret_measurement raises
    # "Measurements of mcms are not yet supported." for measurements whose
    # `mv` is set or whose `obs` is not a qml.operation.Operator, and this
    # call ends up in that branch even though no MCM is being sampled.
    return qml.sample(qml.expval(qml.Y(0)))

# Per the traceback, the error is raised when @qml.qjit is applied above,
# so execution does not reach this call.
circuit()

Tracebacks

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[6], line 7
      3 dev = qml.device("lightning.qubit", wires=1)
      5 qml.capture.enable()
----> 7 @qml.qjit
      8 @qml.set_shots(1000)
      9 #@qml.qnode(dev, mcm_method="one-shot", postselect_mode="fill-shots")
     10 @qml.qnode(dev, mcm_method="one-shot")
     11 def circuit():
     12     qml.H(0)
     13     m0 = qml.measure(0, postselect=1)

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/compiler/qjit_api.py:297, in qjit(fn, compiler, *args, **kwargs)
    295 compilers = AvailableCompilers.names_entrypoints
    296 qjit_loader = compilers[compiler]["qjit"].load()
--> 297 return qjit_loader(fn=fn, *args, **kwargs)

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/catalyst/frontend/catalyst/jit.py:496, in qjit(fn, autograph, autograph_include, async_qnodes, target, keep_intermediate, verbose, logfile, pipelines, static_argnums, static_argnames, abstracted_axes, disable_assertions, seed, circuit_transform_pipeline, pass_plugins, dialect_plugins)
    493 if fn is None:
    494     return functools.partial(qjit, **kwargs)
--> 496 return QJIT(fn, CompileOptions(**kwargs))

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File ~/Documents/catalyst/frontend/catalyst/jit.py:559, in QJIT.__init__(self, fn, compile_options)
    557 # Static arguments require values, so we cannot AOT compile.
    558 if self.user_sig is not None and not self.compile_options.static_argnums:
--> 559     self.aot_compile()
    561 super().__init__("user_function")

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/catalyst/frontend/catalyst/jit.py:612, in QJIT.aot_compile(self)
    610 # TODO: awkward, refactor or redesign the target feature
    611 if self.compile_options.target in ("jaxpr", "mlir", "binary"):
--> 612     self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
    613         self.user_sig or ()
    614     )
    616 if self.compile_options.target in ("mlir", "binary"):
    617     self.mlir_module = self.generate_ir()

File ~/Documents/catalyst/frontend/catalyst/debug/instruments.py:145, in instrument.<locals>.wrapper(*args, **kwargs)
    142 @functools.wraps(fn)
    143 def wrapper(*args, **kwargs):
    144     if not InstrumentSession.active:
--> 145         return fn(*args, **kwargs)
    147     with ResultReporter(stage_name, has_finegrained) as reporter:
    148         self = args[0]

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/catalyst/frontend/catalyst/jit.py:730, in QJIT.capture(self, args, **kwargs)
    722 if qml.capture.enabled():
    723     with Patcher(
    724         (
    725             jax._src.interpreters.partial_eval,  # pylint: disable=protected-access
   (...)    728         ),
    729     ):
--> 730         return trace_from_pennylane(
    731             self.user_function,
    732             static_argnums,
    733             dynamic_args,
    734             abstracted_axes,
    735             full_sig,
    736             kwargs,
    737             debug_info=dbg,
    738         )
    740 def closure(qnode, *args, **kwargs):
    741     params = {}

File ~/Documents/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py:927, in trace_from_pennylane(fn, static_argnums, dynamic_args, abstracted_axes, sig, kwargs, debug_info)
    924         fn.static_argnums = static_argnums
    926     plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)
--> 927     jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs)
    929 return jaxpr, out_type, out_treedef, sig

    [... skipping hidden 15 frame]

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/capture/base_interpreter.py:370, in PlxprInterpreter.eval(self, jaxpr, consts, *args)
    368 if custom_handler:
    369     invals = [self.read(invar) for invar in eqn.invars]
--> 370     outvals = custom_handler(self, *invals, **eqn.params)
    371 elif getattr(primitive, "prim_type", "") == "operator":
    372     outvals = self.interpret_operation_eqn(eqn)

File ~/Documents/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py:257, in handle_qnode(self, qnode, device, shots_len, execution_config, qfunc_jaxpr, n_consts, batch_dims, *args)
    254     gateset = [_get_operator_name(op) for op in self.decompose_tkwargs.get("gate_set", [])]
    255     setattr(qnode, "decompose_gatesets", [gateset])
--> 257 return quantum_kernel_p.bind(
    258     wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info),
    259     *non_const_args,
    260     qnode=qnode,
    261     pipeline=self._pass_pipeline,
    262 )

    [... skipping hidden 7 frame]

File ~/Documents/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py:241, in handle_qnode.<locals>.calling_convention(*args)
    237 self.init_qreg = QubitHandler(qreg, self.qubit_index_recorder)
    238 converter = PLxPRToQuantumJaxprInterpreter(
    239     device, shots, self.init_qreg, {}, self.qubit_index_recorder
    240 )
--> 241 retvals = converter(closed_jaxpr, *args)
    242 self.init_qreg.insert_all_dangling_qubits()
    243 qdealloc_p.bind(self.init_qreg.get())

File ~/Documents/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py:553, in PLxPRToQuantumJaxprInterpreter.__call__(self, jaxpr, *args)
    547 def __call__(self, jaxpr, *args):
    548     """
    549     Execute this interpreter with this arguments.
    550     We expect this to be a flat function (i.e., always takes *args as inputs
    551     and no **kwargs) and the results is a sequence of values
    552     """
--> 553     return self.eval(jaxpr.jaxpr, jaxpr.consts, *args)

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/capture/base_interpreter.py:374, in PlxprInterpreter.eval(self, jaxpr, consts, *args)
    372     outvals = self.interpret_operation_eqn(eqn)
    373 elif getattr(primitive, "prim_type", "") == "measurement":
--> 374     outvals = self.interpret_measurement_eqn(eqn)
    375 else:
    376     invals = [self.read(invar) for invar in eqn.invars]

File ~/.virtualenvs/catalyst-latest/lib/python3.11/site-packages/pennylane/capture/base_interpreter.py:330, in PlxprInterpreter.interpret_measurement_eqn(self, eqn)
    328 with qml.QueuingManager.stop_recording():
    329     mp = eqn.primitive.impl(*invals, **eqn.params)
--> 330 return self.interpret_measurement(mp)

File ~/Documents/catalyst/frontend/catalyst/from_plxpr/from_plxpr.py:508, in PLxPRToQuantumJaxprInterpreter.interpret_measurement(self, measurement)
    500     raise NotImplementedError(
    501         "from_plxpr does not yet support measurements with manual eigvals."
    502     )
    503 if (
    504     measurement.mv is not None
    505     or measurement.obs is not None
    506     and not isinstance(measurement.obs, qml.operation.Operator)
    507 ):
--> 508     raise NotImplementedError("Measurements of mcms are not yet supported.")
    510 if measurement.obs:
    511     obs = self._obs(measurement.obs)

NotImplementedError: Measurements of mcms are not yet supported.

System information

Name: pennylane
Version: 0.43.0.dev68
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: 
Author: 
Author-email: 
License-Expression: Apache-2.0
Location: /Users/isaac/.virtualenvs/catalyst-latest/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing_extensions
Required-by: amazon-braket-pennylane-plugin, pennylane_catalyst, pennylane_lightning, pennylane_lightning_kokkos

Platform info:           macOS-15.7-arm64-arm-64bit
Python version:          3.11.13
Numpy version:           2.3.1
Scipy version:           1.16.0
JAX version:             0.6.2
Installed devices:
- nvidia.custatevec (pennylane_catalyst-0.13.0.dev65)
- nvidia.cutensornet (pennylane_catalyst-0.13.0.dev65)
- oqc.cloud (pennylane_catalyst-0.13.0.dev65)
- softwareq.qpp (pennylane_catalyst-0.13.0.dev65)
- default.clifford (pennylane-0.43.0.dev68)
- default.gaussian (pennylane-0.43.0.dev68)
- default.mixed (pennylane-0.43.0.dev68)
- default.qubit (pennylane-0.43.0.dev68)
- default.qutrit (pennylane-0.43.0.dev68)
- default.qutrit.mixed (pennylane-0.43.0.dev68)
- default.tensor (pennylane-0.43.0.dev68)
- null.qubit (pennylane-0.43.0.dev68)
- reference.qubit (pennylane-0.43.0.dev68)
- braket.aws.ahs (amazon-braket-pennylane-plugin-1.31.3)
- braket.aws.qubit (amazon-braket-pennylane-plugin-1.31.3)
- braket.local.ahs (amazon-braket-pennylane-plugin-1.31.3)
- braket.local.qubit (amazon-braket-pennylane-plugin-1.31.3)
- lightning.kokkos (pennylane_lightning_kokkos-0.43.0.dev36)
- lightning.qubit (pennylane_lightning-0.43.0.dev36)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug 🐛Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions