Skip to content

Commit

Permalink
perf: Minor improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
kylechui committed Jun 4, 2024
1 parent f225de0 commit 7b67629
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 93 deletions.
6 changes: 3 additions & 3 deletions Appendix/A3_Swaps.v
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ Proof.
rewrite control_decomp.
rewrite Mmult_plus_distr_l.
rewrite Mmult_plus_distr_r.
repeat rewrite kron_mixed_product.
repeat rewrite (@kron_mixed_product 2 2 2 4 4 4).
repeat rewrite Mmult_1_l.
repeat rewrite Mmult_1_r.
rewrite swap_swap.
rewrite swap_swap at 1.
assert (swap × control (diag2 C1 c1) × swap = control (diag2 C1 c1)) by lma'.
rewrite H.
rewrite H at 1.
all: solve_WF_matrix.
Qed.

Expand Down
7 changes: 3 additions & 4 deletions Appendix/A5_ControlledUnitaries.v
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,10 @@ assert (U_block_decomp: exists (P0 P1 : Square 2), U = P0 ⊗ ∣0⟩⟨0∣ .+
assert (swap_inverse_helper: swap × (swap × U × swap) × swap = U).
{
repeat rewrite <- Mmult_assoc.
rewrite swap_swap.
rewrite Mmult_1_l. 2: apply U_unitary.
rewrite Mmult_assoc.
(* TODO: Figure out why swap_swap doesn't work here *)
lma'; solve_WF_matrix.
rewrite swap_swap.
rewrite swap_swap at 1.
rewrite Mmult_1_l, Mmult_1_r; solve_WF_matrix.
}
rewrite swap_inverse_helper in SUS_block_decomp.
rewrite SUS_block_decomp.
Expand Down
2 changes: 1 addition & 1 deletion Appendix/A6_TensorProducts.v
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ assert (outer_prod_equiv : acgate U × (a ⊗ b ⊗ g) × (a ⊗ b ⊗ g)† ×
rewrite <- Mmult_adjoint.
assert (app_helper: acgate U × (a ⊗ b ⊗ g) = psi ⊗ phi). apply acU_app.
rewrite app_helper at 1. clear app_helper.
rewrite kron_adjoint.
rewrite (@kron_adjoint 2 1 4 1).
reflexivity.
}
(* trace out ac qubits *)
Expand Down
2 changes: 1 addition & 1 deletion Appendix/A7_OtherProperties.v
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ rewrite <- acv3_id at 1.
rewrite <- Mmult_1_r with (A := acgate V3). 2: apply WF_acgate. 2: assumption.
assert (temp: WF_Unitary (bcgate V4)†). apply adjoint_unitary. apply bcgate_unitary. assumption.
destruct temp as [WF_bcv4dag bcv4dag_inv].
replace (2*2)%nat with 4%nat by lia.
replace 8%nat with (2 * 4)%nat by lia.
rewrite <- bcv4dag_inv.
rewrite adjoint_involutive.
repeat rewrite <- Mmult_assoc.
Expand Down
14 changes: 7 additions & 7 deletions Helpers/GateHelpers.v
Original file line number Diff line number Diff line change
Expand Up @@ -586,16 +586,16 @@ apply (f_equal (fun f => swapbc × f)) in acgate_c.
repeat rewrite <- Mmult_assoc in acgate_c.
rewrite swapbc_inverse in acgate_c. rewrite Mmult_1_l in acgate_c. 2: apply WF_abgate; solve_WF_matrix.
rewrite Mmult_plus_distr_l in acgate_c.
unfold swapbc in acgate_c at 2. rewrite kron_mixed_product in acgate_c.
unfold swapbc in acgate_c at 2. rewrite kron_mixed_product in acgate_c.
unfold swapbc in acgate_c at 2. rewrite (@kron_mixed_product 2 2 2 4 4 4) in acgate_c.
unfold swapbc in acgate_c at 2. rewrite (@kron_mixed_product 2 2 2 4 4 4) in acgate_c.
rewrite Mmult_1_l in acgate_c. 2: solve_WF_matrix.
rewrite Mmult_1_l in acgate_c. 2: solve_WF_matrix.
apply (f_equal (fun f => f × swapbc)) in acgate_c.
rewrite Mmult_assoc in acgate_c.
rewrite swapbc_inverse in acgate_c at 1. rewrite Mmult_1_r in acgate_c. 2: apply WF_abgate; solve_WF_matrix.
rewrite Mmult_plus_distr_r in acgate_c.
unfold swapbc in acgate_c at 2. rewrite kron_mixed_product in acgate_c.
unfold swapbc in acgate_c at 1. rewrite kron_mixed_product in acgate_c.
unfold swapbc in acgate_c at 2. rewrite (@kron_mixed_product 2 2 2 4 4 4) in acgate_c.
unfold swapbc in acgate_c at 1. rewrite (@kron_mixed_product 2 2 2 4 4 4) in acgate_c.
rewrite Mmult_1_r in acgate_c. 2: solve_WF_matrix.
rewrite Mmult_1_r in acgate_c. 2: solve_WF_matrix.
assert (swapW0_unit: WF_Unitary (swap × W0 × swap)).
Expand Down Expand Up @@ -1229,10 +1229,10 @@ acgate U = ∣0⟩⟨0∣ ⊗ TL .+ ∣1⟩⟨1∣ ⊗ BR).
rewrite abs.
unfold swapbc.
rewrite Mmult_plus_distr_l.
repeat rewrite kron_mixed_product.
repeat rewrite (@kron_mixed_product 2 2 2 4 4 4).
repeat rewrite Mmult_1_l. 2,3: solve_WF_matrix.
rewrite Mmult_plus_distr_r.
repeat rewrite kron_mixed_product.
repeat rewrite (@kron_mixed_product 2 2 2 4 4 4).
repeat rewrite Mmult_1_r. 2,3: solve_WF_matrix.
reflexivity.
}
Expand All @@ -1256,7 +1256,7 @@ rewrite Mmult_assoc in zeropassthrough.
rewrite swapbc_3q in zeropassthrough. 2: solve_WF_matrix. 2: apply x_qubit. 2: apply y_qubit.
rewrite zeropassthrough at 1.
unfold swapbc.
rewrite kron_mixed_product.
rewrite (@kron_mixed_product 2 2 1 4 4 1).
rewrite Mmult_1_l. 2: solve_WF_matrix.
reflexivity.
Qed.
Expand Down
80 changes: 32 additions & 48 deletions Helpers/SwapHelpers.v
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ Require Import QuantumLib.Quantum.
Require Import WFHelpers.
Require Import MatrixHelpers.

Definition swapab := swap ⊗ I 2.
Definition swapbc := I 2 ⊗ swap.
Definition swapac := swapab × swapbc × swapab.
Definition swapab : Square 8 := swap ⊗ I 2.
Definition swapbc : Square 8 := I 2 ⊗ swap.
Definition swapac : Square 8 := swapab × swapbc × swapab.

#[global] Hint Unfold swapab swapbc swapac : M_db.

Expand All @@ -28,68 +28,57 @@ Qed.

Lemma swapab_unitary : WF_Unitary swapab.
Proof.
autounfold with M_db; auto with unit_db.
solve_WF_matrix.
Qed.

Lemma swapbc_unitary : WF_Unitary swapbc.
Proof.
autounfold with M_db; auto with unit_db.
solve_WF_matrix.
Qed.

Lemma swapac_unitary : WF_Unitary swapac.
Proof.
apply Mmult_unitary.
apply Mmult_unitary.
apply swapab_unitary.
apply swapbc_unitary.
apply swapab_unitary.
solve_WF_matrix.
Qed.

#[export] Hint Resolve swapab_unitary swapbc_unitary swapac_unitary : unit_db.

Lemma swapab_inverse : swapab × swapab = I 8.
Proof.
unfold swapab.
rewrite kron_mixed_product, swap_swap.
Msimpl_light.
rewrite id_kron.
replace (2 * 2 * 2)%nat with 8%nat by lia.
rewrite (@kron_mixed_product 4 4 4 2 2 2).
rewrite swap_swap at 1.
rewrite Mmult_1_l, id_kron.
reflexivity.
apply WF_I.
Qed.

Lemma swapbc_inverse : swapbc × swapbc = I 8.
Proof.
unfold swapbc.
rewrite kron_mixed_product, swap_swap.
Msimpl_light.
rewrite id_kron.
replace (2 * (2 * 2))%nat with 8%nat by lia.
rewrite (@kron_mixed_product 2 2 2 4 4 4).
rewrite swap_swap at 1.
rewrite Mmult_1_l, id_kron.
reflexivity.
apply WF_I.
Qed.

Lemma swapac_inverse : swapac × swapac = I 8.
Proof.
apply mat_equiv_eq.
apply WF_mult. apply WF_swapac. apply WF_swapac.
apply WF_I.
unfold swapac.
repeat rewrite Mmult_assoc.
rewrite <- Mmult_assoc with (A := swapab) (B := swapab) (C:= swapbc × swapab).
rewrite <- Mmult_assoc with (A := swapbc) (B := swapab × swapab) (C:= swapbc × swapab).
rewrite swapab_inverse.
rewrite Mmult_1_r. 2: apply WF_swapbc.
rewrite <- Mmult_assoc with (A := swapbc) (B:= swapbc) (C:=swapab).
rewrite <- Mmult_assoc with (A := swapab) (B:= swapbc × swapbc) (C:=swapab).
rewrite swapbc_inverse.
rewrite Mmult_1_r. 2: apply WF_swapab.
rewrite <- swapab_inverse.
apply mat_equiv_refl.
rewrite <- Mmult_assoc with (A := swapab) (B := swapab).
rewrite swapab_inverse, Mmult_1_l.
rewrite <- Mmult_assoc with (A := swapbc).
rewrite swapbc_inverse, Mmult_1_l.
exact swapab_inverse.
all: solve_WF_matrix.
Qed.

Lemma swapab_hermitian : swapab† = swapab.
Proof.
unfold swapab.
rewrite kron_adjoint.
rewrite (@kron_adjoint 4 4 2 2).
rewrite swap_hermitian, id_adjoint_eq.
reflexivity.
Qed.
Expand All @@ -116,7 +105,7 @@ WF_Matrix a -> WF_Matrix b -> WF_Matrix c ->
Proof.
intros.
unfold swapab.
rewrite kron_mixed_product.
rewrite (@kron_mixed_product 4 4 1 2 2 1).
rewrite Mmult_1_l. 2: assumption.
rewrite swap_2q. 2,3: assumption.
reflexivity.
Expand All @@ -128,26 +117,21 @@ Lemma swapab_3gate : forall (A B C : Square 2),
Proof.
intros.
unfold swapab.
rewrite kron_mixed_product.
rewrite Mmult_1_l. 2: assumption.
rewrite kron_mixed_product.
rewrite Mmult_1_r. 2: assumption.
rewrite swap_2gate. 2: assumption. 2: assumption.
reflexivity.
do 2 rewrite (@kron_mixed_product 4 4 4 2 2 2).
rewrite swap_2gate, Mmult_1_l, Mmult_1_r.
all: solve_WF_matrix.
Qed.

Lemma swapbc_3q : forall (a b c : Vector 2),
WF_Matrix a -> WF_Matrix b -> WF_Matrix c ->
swapbc × (a ⊗ b ⊗ c) = (a ⊗ c ⊗ b).
Proof.
intros.
unfold swapbc.
rewrite kron_assoc. 2,3,4: assumption.
rewrite kron_mixed_product.
rewrite Mmult_1_l. 2: assumption.
rewrite swap_2q. 2,3: assumption.
rewrite kron_assoc. 2,3,4: assumption.
reflexivity.
intros.
unfold swapbc.
repeat rewrite kron_assoc.
rewrite (@kron_mixed_product 2 2 1 4 4 1).
rewrite swap_2q, Mmult_1_l.
all: solve_WF_matrix.
Qed.

Lemma swapbc_3gate : forall (A B C : Square 2),
Expand Down Expand Up @@ -193,7 +177,7 @@ Qed.
Lemma swapbc_sa: swapbc = (swapbc) †.
Proof.
unfold swapbc.
rewrite kron_adjoint.
rewrite (@kron_adjoint 2 2 4 4).
rewrite id_adjoint_eq.
rewrite swap_hermitian.
reflexivity.
Expand Down
57 changes: 28 additions & 29 deletions Main.v
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,24 @@ Proof.
rewrite <- kron_plus_distr_l.
unfold beta, beta_perp.
rewrite a8; auto.
assert (forall {n}, I 2 ⊗ I n = I n .⊕ I n).
{
intro n.
rewrite <- Mplus01, kron_plus_distr_r, <- (direct_sum_decomp _ _ 0 0).
all: solve_WF_matrix.
}
unfold ccu.
rewrite kron_assoc.
rewrite id_kron.
rewrite H2, control_direct_sum.
rewrite direct_sum_simplify.
split. reflexivity.
rewrite <- id_kron.
rewrite H2, control_direct_sum.
rewrite direct_sum_simplify.
split. reflexivity.
lma'.
all: solve_WF_matrix.
}
Qed.

Expand Down Expand Up @@ -1825,9 +1842,9 @@ Proof.
rewrite <- a12 with (U := V4).
repeat rewrite Mmult_adjoint.
repeat rewrite <- Mmult_assoc.
(* PERF: Why does it lag here? *)
repeat rewrite swapab_hermitian at 1.
rewrite swapab_inverse at 1.
restore_dims.
repeat rewrite swapab_hermitian.
rewrite swapab_inverse.
rewrite Mmult_1_l.
repeat rewrite Mmult_assoc.
rewrite <- Mmult_assoc with (A := swapab).
Expand Down Expand Up @@ -1888,8 +1905,9 @@ Proof.
rewrite <- a12 with (U := V4).
repeat rewrite Mmult_adjoint.
repeat rewrite <- Mmult_assoc.
repeat rewrite swapab_hermitian at 1.
rewrite swapab_inverse at 1.
restore_dims.
repeat rewrite swapab_hermitian.
rewrite swapab_inverse.
rewrite Mmult_1_l.
repeat rewrite Mmult_assoc.
rewrite <- Mmult_assoc with (A := swapab).
Expand Down Expand Up @@ -2323,7 +2341,8 @@ Proof.
{
repeat rewrite Mmult_assoc.
repeat rewrite <- Mmult_assoc with (A := swapab) (B := swapab).
repeat rewrite swapab_inverse at 1.
restore_dims.
repeat rewrite swapab_inverse.
repeat rewrite Mmult_1_l.
rewrite H1 at 1; clear H1.
repeat rewrite <- Mmult_assoc.
Expand All @@ -2334,7 +2353,7 @@ Proof.
do 2 rewrite a12 in H2.
unfold swapab in H2.
rewrite Mmult_assoc with (A := bcgate U1) in H2.
rewrite kron_mixed_product in H2.
rewrite (@kron_mixed_product 4 4 1 2 2 1) in H2.
rewrite Mmult_1_l in H2.
rewrite a10 in H2.
rewrite H2 at 1; clear H2.
Expand All @@ -2346,23 +2365,13 @@ Proof.
rewrite <- Mmult_assoc with (A := swapab).
rewrite swapab_inverse at 1.
unfold swapab.
rewrite kron_mixed_product.
rewrite (@kron_mixed_product 4 4 1 2 2 1).
rewrite a10.
Msimpl_light.
all: solve_WF_matrix.
}
assert (exists (Q0 Q1 : Square 2), U4† = ∣0⟩⟨0∣ ⊗ Q0 .+ ∣1⟩⟨1∣ ⊗ Q1 /\ WF_Unitary Q0 /\ WF_Unitary Q1).
{
(* TODO(Kyle): Use the one in the refactoring PR!! *)
assert (inner_product_kron : forall {m n} (u : Vector m) (v : Vector n),
⟨u ⊗ v, u ⊗ v⟩ = ⟨u, u⟩ * ⟨v, v⟩).
{
intros.
unfold inner_product.
rewrite (@kron_adjoint m 1 n 1).
rewrite (@kron_mixed_product 1 m 1 1 n 1).
unfold kron; reflexivity.
}
assert (⟨ x ⊗ ∣0⟩, x ⊗ ∣0⟩ ⟩ = C1).
{
rewrite inner_product_kron.
Expand All @@ -2389,7 +2398,7 @@ Proof.
unfold bcgate in H3.
rewrite kron_assoc in H3.
rewrite Mmult_assoc in H3.
repeat rewrite kron_mixed_product in H3.
repeat rewrite (@kron_mixed_product 2 2 1 4 4 1) in H3.
rewrite Mmult_1_l in H3.
symmetry.
all: solve_WF_matrix.
Expand Down Expand Up @@ -2914,16 +2923,6 @@ Proof.
assert (exists (P0 P1 : Square 2),
V1 = ∣0⟩⟨0∣ ⊗ P0 .+ ∣1⟩⟨1∣ ⊗ P1 /\ WF_Unitary P0 /\ WF_Unitary P1).
{
(* TODO(Kyle): Use the one in the refactoring PR!! *)
assert (inner_product_kron : forall {m n} (u : Vector m) (v : Vector n),
⟨u ⊗ v, u ⊗ v⟩ = ⟨u, u⟩ * ⟨v, v⟩).
{
intros.
unfold inner_product.
rewrite (@kron_adjoint m 1 n 1).
rewrite (@kron_mixed_product 1 m 1 1 n 1).
unfold kron; reflexivity.
}
assert (⟨ x ⊗ ∣0⟩, x ⊗ ∣0⟩ ⟩ = C1).
{
rewrite inner_product_kron.
Expand Down

0 comments on commit 7b67629

Please sign in to comment.