Skip to content

Commit 7e8ea02

Browse files
rewrite
1 parent d8057b5 commit 7e8ea02

File tree

2 files changed

+111
-30
lines changed

2 files changed

+111
-30
lines changed

easycrypt.project

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@ idirs = code/jasmin/mlkem_avx2/extraction
1313
idirs = crypto-specs/common
1414
rdirs = crypto-specs/fips202
1515
rdirs = crypto-specs/ml-kem
16-
rdirs = ~/Desktop/Repos/easycrypt/examples

proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec

Lines changed: 111 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ import MLKEM_PolyAVXVec.
2727
import NTT_Avx2.
2828
import WArray136 WArray32 WArray128.
2929
import WArray512 WArray256.
30-
(*
30+
3131
(* shake assumptions *)
3232

33-
(*
33+
3434
op SHAKE256_ABSORB4x_33 : W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W256.t Array25.t.
3535
op SHAKE256_SQUEEZENBLOCKS4x : W256.t Array25.t -> W256.t Array25.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t.
3636

@@ -1264,10 +1264,10 @@ do split.
12641264
+ smt().
12651265
by smt(unpackvK).
12661266
qed.
1267-
*)
1267+
12681268
(***************************************************)
12691269

1270-
*)
1270+
12711271
import WArray960 WArray1536 Array4.
12721272

12731273
module AuxPolyVecCompress10 = {
@@ -1336,13 +1336,6 @@ rp <-
13361336
return rp;
13371337
}
13381338

1339-
proc avx2_dummy(ctp : W8.t Array1088.t, bp : W16.t Array768.t) : W8.t Array960.t = {
1340-
var rr : W8.t Array960.t;
1341-
bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp);
1342-
rr <@ __polyvec_compress_avx2(ctp,bp);
1343-
return rr;
1344-
}
1345-
13461339
proc avx2(bp : W16.t Array768.t) : W8.t Array960.t = {
13471340
var rr : W8.t Array960.t;
13481341
var ctp : W8.t Array1088.t <- (init (fun (i_0 : int) => W8.zero))%Array1088;
@@ -1579,13 +1572,32 @@ swap {2} 2 -1;seq 1 1 : #pre; 1: by conseq />;inline *;sim.
15791572
inline {1} 1; inline {2} 2.
15801573
wp.
15811574
while (Glob.mem{1} = stores _mem (to_uint _ctp) (take (i{2}*20) (to_list rp{2})) /\ aux{1} = 48 /\
1575+
valid_ptr (to_uint r{1}) (128 + 3 * 320) /\ r{1} = _ctp /\
15821576
={i,a,aux,sllv_indx, shuffle, shift, mask10, b2, b1, b0} /\ 0 <= i{2} <= 48); last
15831577
by auto => />;smt(Array960.size_to_list List.take_size List.take0 storesE iota0).
15841578

15851579
seq 3 3 : (#pre /\ ={lo,hi});
15861580
1: by conseq />; sim.
1587-
auto => /> &1 &2 ???;split;last by smt().
1588-
admit.
1581+
auto => /> &1 &2 ????;split;last by smt().
1582+
rewrite /storeW32 /storeW128.
1583+
apply mem_eq_ext => add.
1584+
rewrite !get_storesE !to_uintD_small /= !of_uintK /= 1,2:/# !modz_small 1..2:/#.
1585+
rewrite !size_take 1,2:/# /= !size_to_list.
1586+
case ((to_uint _ctp <= add && add < to_uint _ctp + MIN ((i{1} + 1) * 20) 960)); last by smt().
1587+
move => *.
1588+
case ((to_uint _ctp + MIN (i{1} * 20) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 16) 960).
1589+
+ move => *; rewrite ifF 1:/# ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list .
1590+
have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt().
1591+
rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota).
1592+
rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#.
1593+
case ((to_uint _ctp + MIN (i{1} * 20+16) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 20) 960).
1594+
+ move => *; rewrite ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list .
1595+
have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt().
1596+
rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota).
1597+
rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#.
1598+
case (to_uint _ctp <= add && add < to_uint _ctp + MIN (i{1} * 20) 960); last by smt().
1599+
move => *; rewrite ifF 1:/# ifF 1:/# mulrDl /= /to_list !take_mkseq 1,2:/# /= /mkseq !(nth_map witness); 1,2: smt(size_iota).
1600+
rewrite !nth_iota 1,2:/# initiE 1:/# get8_set32_directE 1,2:/# /get8 !initiE 1,2:/# /= -/WArray960.get8 get8_set128_directE 1,2:/# /get8 initiE /#.
15891601
qed.
15901602

15911603
lemma poly_reduce_noloops :
@@ -1627,23 +1639,92 @@ seq 1 1 : #pre; 1: by call polyvec_reduce_noloops => />.
16271639
inline {1} 1; inline {2} 1.
16281640
swap {1} 3 -1;swap {2} [2..3] -1; swap {1} 7 -4; swap {2} 7 -4.
16291641
seq 3 3 : (#pre /\ ={aa});1:by call polyvec_csubq_noloops;auto => />.
1630-
wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\
1631-
0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\
1642+
wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\ valid_ptr (to_uint rp{2}) (128 + 3 * 320) /\
1643+
0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\ rp{2} = _ctp /\
16321644
j{1} *4 = i{1} * 5 /\ ={aa} /\
16331645
Glob.mem{2} = stores _mem (to_uint _ctp) (take j{1} (to_list rr0{1}))); last
16341646
by auto => />; smt(Array960.size_to_list List.take_size List.take0 storesE iota0).
1635-
unroll for {1} 2; unroll for {2} 2;auto => /> &1 &2;rewrite !ultE /= => ???????;do split;1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt().
1636-
admit.
1647+
unroll for {1} 2; unroll for {2} 2;auto => /> &1 &2;rewrite !ultE /= => ?????????;do split;1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt().
1648+
rewrite /storeW8 /=.
1649+
apply mem_eq_ext => adr.
1650+
rewrite !to_uintD_small /= 1..16:/# !addrA.
1651+
pose x1 := truncateu8
1652+
(truncateu16
1653+
((((zeroextu64 aa{2}.[to_uint i{2}] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1654+
(of_int 32)%W8) `&`
1655+
(of_int 1023)%W64) `&`
1656+
(of_int 255)%W16).
1657+
pose x2 := truncateu8
1658+
((truncateu16
1659+
((((zeroextu64 aa{2}.[to_uint i{2} + 1] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1660+
(of_int 32)%W8) `&`
1661+
(of_int 1023)%W64) `<<`
1662+
(of_int 2)%W8) `|`
1663+
(truncateu16
1664+
((((zeroextu64 aa{2}.[to_uint i{2}] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1665+
(of_int 32)%W8) `&`
1666+
(of_int 1023)%W64) `>>`
1667+
(of_int 8)%W8)).
1668+
pose x3 := truncateu8
1669+
((truncateu16
1670+
((((zeroextu64 aa{2}.[to_uint i{2} + 2] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1671+
(of_int 32)%W8) `&`
1672+
(of_int 1023)%W64) `<<`
1673+
(of_int 4)%W8) `|`
1674+
(truncateu16
1675+
((((zeroextu64 aa{2}.[to_uint i{2} + 1] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1676+
(of_int 32)%W8) `&`
1677+
(of_int 1023)%W64) `>>`
1678+
(of_int 6)%W8)).
1679+
pose x4 := truncateu8
1680+
((truncateu16
1681+
((((zeroextu64 aa{2}.[to_uint i{2} + 3] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1682+
(of_int 32)%W8) `&`
1683+
(of_int 1023)%W64) `<<`
1684+
(of_int 6)%W8) `|`
1685+
(truncateu16
1686+
((((zeroextu64 aa{2}.[to_uint i{2} + 2] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1687+
(of_int 32)%W8) `&`
1688+
(of_int 1023)%W64) `>>`
1689+
(of_int 4)%W8)).
1690+
pose x5 := truncateu8
1691+
((((zeroextu64 aa{2}.[to_uint i{2} + 3] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
1692+
(of_int 32)%W8) `&`
1693+
(of_int 1023)%W64 `>>` (of_int 2)%W8).
1694+
rewrite !get_storesE.
1695+
case (to_uint _ctp + to_uint j{2} <= adr < to_uint _ctp + to_uint j{2} + 5); last first.
1696+
case (to_uint _ctp <= adr < to_uint _ctp + to_uint j{2}); last first.
1697+
+ move => *;do 5!(rewrite get_set_neqE_s 1:/#).
1698+
rewrite !size_take 1:/# size_to_list /= ifF 1:/# get_storesE /= size_take 1:/# size_to_list /#.
1699+
+ move => *.
1700+
move => *;do 5!(rewrite get_set_neqE_s 1:/#).
1701+
rewrite !size_take 1:/# size_to_list /= ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /= get_storesE size_take 1:/# size_mkseq /= ifT 1:/#.
1702+
by rewrite nth_take 1,2:/# nth_mkseq 1:/# /=; smt(Array960.get_setE).
1703+
move => *. rewrite size_take 1:/# size_to_list ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /=.
1704+
by smt(Array960.get_setE get_set_neqE_s get_set_eqE_s).
16371705
qed.
16381706

1639-
lemma compress10_equiv_avx2i_dummy :
1640-
equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2_dummy :
1641-
={bp,ctp} ==> ={res} ] by proc;inline *;sim;auto => />.
1642-
16431707
lemma compress10_equiv_avx2i :
16441708
equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2 :
1645-
={bp} ==> ={res} ] by admit.
1646-
1709+
={bp} ==> ={res} ].
1710+
proc => /=.
1711+
swap {2} 2 -1.
1712+
seq 1 1 : #pre; 1: by sim.
1713+
inline *;wp.
1714+
while (={i,a,bp,b0,b1,b2,mask10,shift,sllv_indx,shuffle,aux} /\ 0<=i{1} <= 48 /\ aux{1}=48 /\ (forall k, 0<=k<i{1}*20 => rp{1}.[k] = rp{2}.[k]));
1715+
last by auto => /> *; split;[ smt() | move => *; rewrite tP => *;smt()].
1716+
auto => /> &1 &2 *;split;1:smt().
1717+
move => k kbl kbh; rewrite !initiE 1,2:/# /=.
1718+
rewrite !get8_set32_directE 1..4:/#.
1719+
case (0<=k<i{2}*20).
1720+
+ move => *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=.
1721+
rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifF 1,2:/#.
1722+
by rewrite /get8 !initiE /#.
1723+
move => *;case (i{2}*20<=k<i{2}*20+16).
1724+
+ move => *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=.
1725+
by rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifT /#.
1726+
by smt().
1727+
qed.
16471728

16481729
lemma compress10_equiv_refi :
16491730
equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig_i :
@@ -1781,7 +1862,7 @@ proc change ^while{3}.11 : (sliceset256_16_256 ap1 i2 a3). by admit.
17811862
proc change 26 : (init_768_16 (fun i => if 2 * 256 <= i < 3 * 256 then aux.[i - 2 * 256] else r.[i])). by auto.
17821863

17831864
proc change 30 : (init_960_8 (fun i_0 => ctp0.[i_0 + 0])). by done.
1784-
proc change 36 : (sliceget32_8_256 pvc_shufbidx_s 0). by admit.
1865+
proc change 37 : (sliceget32_8_256 pvc_shufbidx_s 0). by admit.
17851866

17861867
proc change ^while{4}.1 : (sliceget768_16_256 a i). by admit.
17871868

@@ -1793,11 +1874,12 @@ cfold 38.
17931874
unroll for 39.
17941875
cfold 38. unroll for 24. cfold 23.
17951876
unroll for 16. cfold 15. unroll for 8. cfold 7.
1796-
1797-
proc change 552 : (init_960_8 (fun i => ctp0.[i])). by done.
1877+
admit. (*
17981878
bdep 16 16 [_bp] [bp] [ap] lane pcond.
17991879

18001880
print get256_direct.
1881+
*)
1882+
qed.
18011883

18021884
(* MAP REDUCE GOAL *)
18031885
lemma compress10_mr :
@@ -1812,7 +1894,7 @@ lemma compress10_equiv :
18121894
proof.
18131895
proc* => /=.
18141896
exlim Glob.mem{1}, ctp{1} => _mem _ctp.
1815-
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(witness,bp); }
1897+
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(bp); }
18161898
(={bp} /\ ctp{1} = _ctp /\ Glob.mem{1} = _mem /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) ==>
18171899
Glob.mem{1} = stores _mem (to_uint _ctp) (to_list r{2}))
18181900
(lift_array768 bp{1} = lift_array768 bp{2} /\ ctp{2} = _ctp /\ Glob.mem{2} = _mem /\ valid_ptr (to_uint ctp{2}) (128 + 3 * 320) ==> Glob.mem{2} = stores _mem (to_uint _ctp) (to_list r{1}));
@@ -1830,7 +1912,7 @@ lemma compress10_equiv_i :
18301912
proof.
18311913
proc* => /=.
18321914
exlim Glob.mem{1}, ctp{1} => _mem _ctp.
1833-
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(ctp,bp); }
1915+
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(bp); }
18341916
(={bp,ctp} ==> ={r})
18351917
(lift_array768 bp{1} = lift_array768 bp{2} /\ ={ctp} ==> ={r});
18361918
[ by smt() | by smt() | by call (compress10_equiv_avx2i); auto => /> |].
@@ -1840,7 +1922,7 @@ transitivity {2} { r <@ AuxPolyVecCompress10.ref(bp); }
18401922
[ by smt() | by smt() | | by call (compress10_equiv_refi); auto => />].
18411923
by call compress10_mr; auto => />.
18421924
qed.
1843-
1925+
import InnerPKE.
18441926
lemma mlkem_correct_enc_0_avx2 mem _ctp _pkp :
18451927
equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_enc_0 ~ InnerPKE.enc_derand:
18461928
valid_ptr _pkp (384*3 + 32) /\

0 commit comments

Comments
 (0)