-
Notifications
You must be signed in to change notification settings - Fork 1
Update PyQrack.run
to match RFC
#225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import List, TypeVar, ParamSpec | ||
from typing import Any, TypeVar, Iterable | ||
from dataclasses import field, dataclass | ||
|
||
from kirin import ir | ||
|
@@ -13,9 +13,6 @@ | |
) | ||
from bloqade.analysis.address import AnyAddress, AddressAnalysis | ||
|
||
Params = ParamSpec("Params") | ||
RetType = TypeVar("RetType") | ||
|
||
|
||
@dataclass | ||
class PyQrack: | ||
|
@@ -36,7 +33,9 @@ | |
{**_default_pyqrack_args(), **self.pyqrack_options} | ||
) | ||
|
||
def _get_interp(self, mt: ir.Method[Params, RetType]): | ||
RetType = TypeVar("RetType") | ||
|
||
def _get_interp(self, mt: ir.Method[..., RetType]): | ||
if self.dynamic_qubits: | ||
|
||
options = self.pyqrack_options.copy() | ||
|
@@ -64,49 +63,51 @@ | |
|
||
def run( | ||
self, | ||
mt: ir.Method[Params, RetType], | ||
*args: Params.args, | ||
**kwargs: Params.kwargs, | ||
) -> RetType: | ||
mt: ir.Method[..., RetType], | ||
*, | ||
shots: int = 1, | ||
args: tuple[Any, ...] = (), | ||
kwargs: dict[str, Any] = {}, | ||
return_iterator: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the use case of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This goes back to the discussion of memory conservation, if you return a register that will mean the full state will persist in memory if you store the simulator object in a list whereas if you return an iterator you have the option of only using the simulator object for that iteration and the gc can clean it up. Also, we get this for free since we need to loop for each shot anyway its just a slight difference in implementation. |
||
) -> RetType | list[RetType] | Iterable[RetType]: | ||
"""Run the given kernel method on the PyQrack simulator. | ||
|
||
Args | ||
mt (Method): | ||
The kernel method to run. | ||
shots (int): | ||
The number of shots to run the simulation for. | ||
Defaults to 1. | ||
args (tuple[Any, ...]): | ||
Positional arguments to pass to the kernel method. | ||
Defaults to (). | ||
kwargs (dict[str, Any]): | ||
Keyword arguments to pass to the kernel method. | ||
Defaults to {}. | ||
return_iterator (bool): | ||
Whether to return an iterator that yields results for each shot. | ||
Defaults to False. if False, a list of results is returned. | ||
|
||
Returns | ||
The result of the kernel method, if any. | ||
|
||
""" | ||
fold = Fold(mt.dialects) | ||
fold(mt) | ||
return self._get_interp(mt).run(mt, args, kwargs) | ||
|
||
def multi_run( | ||
self, | ||
mt: ir.Method[Params, RetType], | ||
_shots: int, | ||
*args: Params.args, | ||
**kwargs: Params.kwargs, | ||
) -> List[RetType]: | ||
"""Run the given kernel method on the PyQrack `_shots` times, caching analysis results. | ||
|
||
Args | ||
mt (Method): | ||
The kernel method to run. | ||
_shots (int): | ||
The number of times to run the kernel method. | ||
|
||
Returns | ||
List of results of the kernel method, one for each shot. | ||
RetType | list[RetType] | Iterable[RetType]: | ||
The result of the simulation. If `return_iterator` is True, | ||
an iterator that yields results for each shot is returned. | ||
Otherwise, a list of results is returned if `shots > 1`, or | ||
a single result is returned if `shots == 1`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if returning different data type here is a good idea. should we always return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a typo It should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One option is we just always return the
|
||
|
||
""" | ||
fold = Fold(mt.dialects) | ||
fold(mt) | ||
|
||
interpreter = self._get_interp(mt) | ||
batched_results = [] | ||
for _ in range(_shots): | ||
batched_results.append(interpreter.run(mt, args, kwargs)) | ||
|
||
return batched_results | ||
def run_shots(): | ||
for _ in range(shots): | ||
yield interpreter.run(mt, args, kwargs) | ||
|
||
if shots == 1: | ||
return interpreter.run(mt, args, kwargs) | ||
elif return_iterator: | ||
return run_shots() | ||
else: | ||
return list(run_shots()) |
Uh oh!
There was an error while loading. Please reload this page.