Skip to content

Commit 91b13ae

Browse files
committed
Use foldl and function to implement condition with list of measurements
1 parent 17dc7d5 commit 91b13ae

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/bloqade/squin/cirq/lowering.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,28 @@ def visit_ClassicallyControlledOperation(
173173
measurement_outcome = state.current_frame.defs[key]
174174

175175
if measurement_outcome.type.is_subseteq(ilist.IListType):
176-
# TODO: need to represent Any(measurement_outcome) here
177-
raise NotImplementedError("TODO")
176+
# NOTE: there is currently no convenient ilist.any method, so we need to use foldl
177+
# with a simple function that just does an or
178+
179+
def bool_op_or(x: bool, y: bool) -> bool:
180+
return x or y
181+
182+
f_code = state.current_frame.push(
183+
lowering.Python(self.dialects).python_function(bool_op_or)
184+
)
185+
fn = ir.Method(
186+
mod=None,
187+
py_func=bool_op_or,
188+
sym_name="bool_op_or",
189+
arg_names=[],
190+
dialects=self.dialects,
191+
code=f_code,
192+
)
193+
f_const = state.current_frame.push(py.constant.Constant(fn))
194+
init_val = state.current_frame.push(py.Constant(False)).result
195+
condition = state.current_frame.push(
196+
ilist.Foldl(f_const.result, measurement_outcome, init=init_val)
197+
).result
178198
else:
179199
condition = measurement_outcome
180200

test/squin/cirq/test_cirq_to_squin.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,12 @@ def test_classical_control_register():
196196

197197

198198
def test_multiple_classical_controls(run_sim: bool = False):
199-
# TODO: test combination of list & single measurement here
200199
q = cirq.LineQubit.range(2)
201200
q2 = cirq.GridQubit(0, 1)
202201
circuit = cirq.Circuit(
203202
cirq.H(q[0]),
204203
cirq.H(q2),
205-
cirq.measure(q[0], key="test"),
204+
cirq.measure(q, key="test"),
206205
cirq.measure(q2),
207206
cirq.X(q[1]).with_classical_controls("test", "q(0, 1)"),
208207
cirq.measure(q[1]),

0 commit comments

Comments
 (0)