Skip to content

Commit e0a331d

Browse files
committed
fixed expval
1 parent 1e24ca9 commit e0a331d

File tree

3 files changed

+25
-13
lines changed

3 files changed

+25
-13
lines changed

aer_plugin/backends/aer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@ def __init__(
1717
def _exec_expval(self, circuit: QuantumCircuit) -> Results:
1818
"""Extracts expval using EstimatorV2"""
1919

20-
obs = self._metadata.get("obs")
20+
obs_list = self._metadata.get("obs")
2121
assert (
22-
obs is not None and len(obs) > 0
22+
obs_list is not None and len(obs_list) > 0 and isinstance(obs_list, list)
2323
), "You need to provide the observables to get the expectation value!"
2424

2525
# pylint: disable=import-outside-toplevel
2626
from qiskit_aer.primitives import EstimatorV2
2727
from qiskit.quantum_info import SparsePauliOp
2828

2929
estimator = EstimatorV2()
30-
result = estimator.run([(circuit, SparsePauliOp.from_list(obs))]).result()
30+
results = estimator.run(
31+
[(circuit, SparsePauliOp.from_list(obs)) for obs in obs_list]
32+
).result()
3133

32-
# to ensure the expval result will be a flot (for mypy)
33-
return result[0].data.evs # type: ignore
34+
return [result.data.evs for result in results] # type: ignore
3435

3536
def _exec_counts(self, circuit: QuantumCircuit) -> Results:
3637
"""Extracts counts using AerSimulator directly"""

aer_plugin/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Metadata = Dict[Any, Any]
1313
ResultType = str
1414
QasmFilePath = str
15-
Results = Dict[str | int, float] | float
15+
Results = Dict[str | int, float] | List[float]
1616

1717

1818
def check_backend(func):

tests/test_run.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,32 @@ def test_incorrect_quasi_dist(self):
6868
def test_correct_expval_one_obs(self):
6969
"""should raise no error and evaluate the observable ZZ to 1"""
7070
result = Plugin().execute(
71-
"aer", "./tests/valid_expval.qasm", {"obs": [("ZZ", 1)]}, "expval"
71+
"aer", "./tests/valid_expval.qasm", {"obs": [[("ZZ", 1)]]}, "expval"
7272
)
7373

74-
assert result == 1.0
74+
assert result == [1.0]
7575

76-
def test_correct_expval_two_obs(self):
76+
def test_correct_expval_two_obs_same_pub(self):
7777
"""should raise no error and sum up all expectation values (total=3)"""
7878
result = Plugin().execute(
7979
"aer",
8080
"./tests/valid_expval.qasm",
81-
{"obs": [("ZI", 1), ("IZ", 1), ("ZZ", 1)]},
81+
{"obs": [[("ZI", 1), ("IZ", 1), ("ZZ", 1)]]},
8282
"expval",
8383
)
8484

85-
assert result == 3.0
85+
assert result == [3.0]
86+
87+
def test_correct_expval_one_obs_per_pub(self):
88+
"""should return a list of expectation values with size 2"""
89+
result = Plugin().execute(
90+
"aer",
91+
"./tests/valid_expval.qasm",
92+
{"obs": [[("ZI", 1)], [("ZI", 1)]]},
93+
"expval",
94+
)
95+
96+
assert result == [1.0, 1.0]
8697

8798
def test_incorrect_expval_no_obs(self):
8899
"""should raise an error, once there's no observables defined."""
@@ -94,6 +105,6 @@ def test_incorrect_expval_circuit_with_measurements(self):
94105
"""should raise no error, once the circuit has measurements but we can evaluate the expval as well."""
95106

96107
result = Plugin().execute(
97-
"aer", "./tests/expval_measurements.qasm", {"obs": [("II", 1)]}, "expval"
108+
"aer", "./tests/expval_measurements.qasm", {"obs": [[("II", 1)]]}, "expval"
98109
)
99-
assert result == 1.0
110+
assert result == [1.0]

0 commit comments

Comments
 (0)