Skip to content

Commit 2c6428c

Browse files
authored
feature: Program sets (#214)
1 parent c84b7bb commit 2c6428c

File tree

7 files changed

+343
-82
lines changed

7 files changed

+343
-82
lines changed

qiskit_braket_provider/providers/braket_backend.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from collections.abc import Iterable
99
from typing import Optional, Union
1010

11-
from braket.aws import AwsDevice, AwsQuantumTask, AwsQuantumTaskBatch
11+
from braket.aws import AwsDevice, AwsQuantumTask
1212
from braket.aws.queue_information import QueueDepthInfo
1313
from braket.circuits import Circuit
1414
from braket.device_schema import DeviceActionType
1515
from braket.devices import Device, LocalSimulator
16+
from braket.program_sets import ProgramSet
1617
from braket.tasks.local_quantum_task import LocalQuantumTask
1718
from qiskit import QuantumCircuit
1819
from qiskit.providers import BackendV2, Options, Provider, QubitProperties
@@ -175,7 +176,7 @@ def run(
175176
task_id=task_id,
176177
tasks=tasks,
177178
backend=self,
178-
shots=shots, # type: ignore[arg-type]
179+
shots=shots,
179180
)
180181

181182

@@ -230,6 +231,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
230231
f"QiskitBraketProvider/{version.__version__}"
231232
)
232233
self._target = aws_device_to_target(device=self._aws_device)
234+
self._supports_program_sets = (
235+
DeviceActionType.OPENQASM_PROGRAM_SET in self._aws_device.properties.action
236+
)
233237

234238
def retrieve_job(self, task_id: str) -> BraketQuantumTask:
235239
"""Return a single job submitted to AWS backend.
@@ -355,15 +359,28 @@ def run(self, run_input, verbatim: bool = False, native: bool = False, **options
355359
)
356360
for circ in circuits
357361
]
362+
shots = options.pop("shots", None)
363+
return (
364+
self._run_program_set(braket_circuits, shots, **options)
365+
if self._supports_program_sets and shots != 0
366+
else self._run_batch(braket_circuits, shots, **options)
367+
)
358368

359-
batch_task: AwsQuantumTaskBatch = self._device.run_batch(
360-
braket_circuits, **options
369+
def _run_program_set(
370+
self, braket_circuits: list[Circuit], shots: Optional[int], **options
371+
):
372+
program_set = ProgramSet(braket_circuits, shots_per_executable=shots)
373+
task = self._aws_device.run(program_set, **options)
374+
return BraketQuantumTask(
375+
task_id=task.id, tasks=task, backend=self, shots=program_set.total_shots
361376
)
377+
378+
def _run_batch(self, braket_circuits: list[Circuit], shots: int, **options):
379+
batch_task = self._aws_device.run_batch(braket_circuits, shots=shots, **options)
362380
tasks: list[AwsQuantumTask] = batch_task.tasks
363381
task_id = _TASK_ID_DIVIDER.join(task.id for task in tasks)
364-
365382
return BraketQuantumTask(
366-
task_id=task_id, tasks=tasks, backend=self, shots=options.get("shots")
383+
task_id=task_id, tasks=tasks, backend=self, shots=shots
367384
)
368385

369386

qiskit_braket_provider/providers/braket_quantum_task.py

Lines changed: 93 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,69 @@
11
"""Amazon Braket task."""
22

33
from datetime import datetime
4-
from typing import List, Optional, Union
4+
from typing import List, Union
55

66
from braket.aws import AwsQuantumTask, AwsQuantumTaskBatch
77
from braket.aws.queue_information import QuantumTaskQueueInfo
8+
from braket.tasks import GateModelQuantumTaskResult, QuantumTask
89
from braket.tasks.local_quantum_task import LocalQuantumTask
910
from qiskit.providers import BackendV2, JobStatus, JobV1
1011
from qiskit.quantum_info import Statevector
1112
from qiskit.result import Result
1213
from qiskit.result.models import ExperimentResult, ExperimentResultData
1314

15+
_TASK_STATUS_MAP = {
16+
"INITIALIZED": JobStatus.INITIALIZING,
17+
"QUEUED": JobStatus.INITIALIZING,
18+
"FAILED": JobStatus.ERROR,
19+
"CANCELLING": JobStatus.CANCELLED,
20+
"CANCELLED": JobStatus.CANCELLED,
21+
"COMPLETED": JobStatus.DONE,
22+
"RUNNING": JobStatus.RUNNING,
23+
}
24+
1425

1526
def retry_if_result_none(result):
1627
"""Retry on result function."""
1728
return result is None
1829

1930

20-
def _get_result_from_tasks(
21-
tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]],
22-
) -> Optional[List[ExperimentResult]]:
23-
"""Returns experiment results of AWS tasks.
24-
25-
Args:
26-
tasks: AWS Quantum tasks
27-
shots: number of shots
28-
29-
Returns:
30-
List of experiment results.
31-
"""
32-
experiment_results: List[ExperimentResult] = []
33-
34-
results = AwsQuantumTaskBatch._retrieve_results(
35-
tasks, AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT
36-
)
37-
38-
# For each task we create an ExperimentResult object with the downloaded results.
39-
for task, result in zip(tasks, results):
40-
if not result:
41-
return None
42-
43-
if result.task_metadata.shots == 0:
44-
braket_statevector = result.values[
45-
result._result_types_indices[
46-
"{'type': <Type.statevector: 'statevector'>}"
47-
]
48-
]
49-
data = ExperimentResultData(
50-
statevector=Statevector(braket_statevector).reverse_qargs().data,
51-
)
52-
else:
53-
counts = {
54-
k[::-1]: v for k, v in dict(result.measurement_counts).items()
55-
} # convert to little-endian
56-
57-
data = ExperimentResultData(
58-
counts=counts,
59-
memory=[
60-
"".join(shot_result[::-1].astype(str))
61-
for shot_result in result.measurements
62-
],
63-
)
31+
def _result_from_circuit_task(
32+
task: Union[LocalQuantumTask, AwsQuantumTask], result: GateModelQuantumTaskResult
33+
) -> ExperimentResult:
34+
if not result:
35+
return None
6436

65-
experiment_result = ExperimentResult(
66-
shots=result.task_metadata.shots,
67-
success=True,
68-
status=(
69-
task.state()
70-
if isinstance(task, LocalQuantumTask)
71-
else result.task_metadata.status
72-
),
73-
data=data,
37+
if result.task_metadata.shots == 0:
38+
braket_statevector = result.values[
39+
result._result_types_indices["{'type': <Type.statevector: 'statevector'>}"]
40+
]
41+
data = ExperimentResultData(
42+
statevector=Statevector(braket_statevector).reverse_qargs().data,
43+
)
44+
else:
45+
counts = {
46+
k[::-1]: v for k, v in dict(result.measurement_counts).items()
47+
} # convert to little-endian
48+
49+
data = ExperimentResultData(
50+
counts=counts,
51+
memory=[
52+
"".join(shot_result[::-1].astype(str))
53+
for shot_result in result.measurements
54+
],
7455
)
75-
experiment_results.append(experiment_result)
7656

77-
return experiment_results
57+
return ExperimentResult(
58+
shots=result.task_metadata.shots,
59+
success=True,
60+
status=(
61+
task.state()
62+
if isinstance(task, LocalQuantumTask)
63+
else result.task_metadata.status
64+
),
65+
data=data,
66+
)
7867

7968

8069
class BraketQuantumTask(JobV1):
@@ -84,8 +73,8 @@ def __init__(
8473
self,
8574
task_id: str,
8675
backend: BackendV2,
87-
tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]],
88-
**metadata: Optional[dict],
76+
tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask], AwsQuantumTask],
77+
**metadata,
8978
):
9079
"""BraketQuantumTask for execution of circuits on Amazon Braket or locally.
9180
@@ -147,6 +136,8 @@ def queue_position(self) -> QuantumTaskQueueInfo:
147136
queue_position=None, message='Task is in COMPLETED status. AmazonBraket does
148137
not show queue position for this status.')
149138
"""
139+
if isinstance(self._tasks, QuantumTask):
140+
return self._tasks.queue_position()
150141
for task in self._tasks:
151142
if isinstance(task, LocalQuantumTask):
152143
raise NotImplementedError(
@@ -159,7 +150,43 @@ def task_id(self) -> str:
159150
return self._task_id
160151

161152
def result(self) -> Result:
162-
experiment_results = _get_result_from_tasks(tasks=self._tasks)
153+
tasks = self._tasks
154+
if isinstance(tasks, QuantumTask):
155+
# Guaranteed to be program set result
156+
experiment_results = [
157+
ExperimentResult(
158+
shots=len((executable_result := program_result[0]).measurements),
159+
success=True,
160+
data=ExperimentResultData(
161+
counts=executable_result.counts,
162+
memory=[
163+
"".join(shot_result[::-1].astype(str))
164+
for shot_result in executable_result.measurements
165+
],
166+
),
167+
)
168+
for program_result in tasks.result()
169+
]
170+
status = tasks.state()
171+
return Result(
172+
backend_name=self._backend.name,
173+
backend_version=self._backend.version,
174+
job_id=self._task_id,
175+
qobj_id=0,
176+
success=status not in AwsQuantumTask.NO_RESULT_TERMINAL_STATES,
177+
results=experiment_results,
178+
status=status,
179+
)
180+
181+
experiment_results = [
182+
_result_from_circuit_task(task, result)
183+
for task, result in zip(
184+
tasks,
185+
AwsQuantumTaskBatch._retrieve_results(
186+
tasks, AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT
187+
),
188+
)
189+
]
163190
status = self.status(use_cached_value=True)
164191

165192
return Result(
@@ -173,10 +200,15 @@ def result(self) -> Result:
173200
)
174201

175202
def cancel(self):
176-
for task in self._tasks:
177-
task.cancel()
203+
if isinstance(self._tasks, QuantumTask):
204+
self._tasks.cancel()
205+
else:
206+
for task in self._tasks:
207+
task.cancel()
178208

179209
def status(self, use_cached_value: bool = False):
210+
if isinstance(self._tasks, QuantumTask):
211+
return _TASK_STATUS_MAP[self._tasks.state()]
180212
braket_tasks_states = [
181213
(
182214
task.state()

qiskit_braket_provider/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Qiskit-Braket provider version."""
22

3-
__version__ = "0.4.6"
3+
__version__ = "0.5.0"

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
certifi>=2021.5.30
22
qiskit>=0.34.2, <2.0
33
qiskit-ionq>=0.5.2
4-
amazon-braket-sdk>=1.76.0
4+
amazon-braket-sdk>=1.97.0
55

66
setuptools>=40.1.0
77
numpy>=1.3

0 commit comments

Comments
 (0)