Skip to content

Commit

Permalink
rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Oct 25, 2024
1 parent d8057b5 commit 7e8ea02
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 30 deletions.
1 change: 0 additions & 1 deletion easycrypt.project
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@ idirs = code/jasmin/mlkem_avx2/extraction
idirs = crypto-specs/common
rdirs = crypto-specs/fips202
rdirs = crypto-specs/ml-kem
rdirs = ~/Desktop/Repos/easycrypt/examples
140 changes: 111 additions & 29 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ import MLKEM_PolyAVXVec.
import NTT_Avx2.
import WArray136 WArray32 WArray128.
import WArray512 WArray256.
(*

(* shake assumptions *)

(*

op SHAKE256_ABSORB4x_33 : W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W256.t Array25.t.
op SHAKE256_SQUEEZENBLOCKS4x : W256.t Array25.t -> W256.t Array25.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t.

Expand Down Expand Up @@ -1264,10 +1264,10 @@ do split.
+ smt().
by smt(unpackvK).
qed.
*)

(***************************************************)

*)

import WArray960 WArray1536 Array4.

module AuxPolyVecCompress10 = {
Expand Down Expand Up @@ -1336,13 +1336,6 @@ rp <-
return rp;
}

proc avx2_dummy(ctp : W8.t Array1088.t, bp : W16.t Array768.t) : W8.t Array960.t = {
var rr : W8.t Array960.t;
bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp);
rr <@ __polyvec_compress_avx2(ctp,bp);
return rr;
}

proc avx2(bp : W16.t Array768.t) : W8.t Array960.t = {
var rr : W8.t Array960.t;
var ctp : W8.t Array1088.t <- (init (fun (i_0 : int) => W8.zero))%Array1088;
Expand Down Expand Up @@ -1579,13 +1572,32 @@ swap {2} 2 -1;seq 1 1 : #pre; 1: by conseq />;inline *;sim.
inline {1} 1; inline {2} 2.
wp.
while (Glob.mem{1} = stores _mem (to_uint _ctp) (take (i{2}*20) (to_list rp{2})) /\ aux{1} = 48 /\
valid_ptr (to_uint r{1}) (128 + 3 * 320) /\ r{1} = _ctp /\
={i,a,aux,sllv_indx, shuffle, shift, mask10, b2, b1, b0} /\ 0 <= i{2} <= 48); last
by auto => />;smt(Array960.size_to_list List.take_size List.take0 storesE iota0).

seq 3 3 : (#pre /\ ={lo,hi});
1: by conseq />; sim.
auto => /> &1 &2 ???;split;last by smt().
admit.
auto => /> &1 &2 ????;split;last by smt().
rewrite /storeW32 /storeW128.
apply mem_eq_ext => add.
rewrite !get_storesE !to_uintD_small /= !of_uintK /= 1,2:/# !modz_small 1..2:/#.
rewrite !size_take 1,2:/# /= !size_to_list.
case ((to_uint _ctp <= add && add < to_uint _ctp + MIN ((i{1} + 1) * 20) 960)); last by smt().
move => *.
case ((to_uint _ctp + MIN (i{1} * 20) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 16) 960).
+ move => *; rewrite ifF 1:/# ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list .
have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt().
rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota).
rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#.
case ((to_uint _ctp + MIN (i{1} * 20+16) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 20) 960).
+ move => *; rewrite ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list .
have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt().
rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota).
rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#.
case (to_uint _ctp <= add && add < to_uint _ctp + MIN (i{1} * 20) 960); last by smt().
move => *; rewrite ifF 1:/# ifF 1:/# mulrDl /= /to_list !take_mkseq 1,2:/# /= /mkseq !(nth_map witness); 1,2: smt(size_iota).
rewrite !nth_iota 1,2:/# initiE 1:/# get8_set32_directE 1,2:/# /get8 !initiE 1,2:/# /= -/WArray960.get8 get8_set128_directE 1,2:/# /get8 initiE /#.
qed.

lemma poly_reduce_noloops :
Expand Down Expand Up @@ -1627,23 +1639,92 @@ seq 1 1 : #pre; 1: by call polyvec_reduce_noloops => />.
inline {1} 1; inline {2} 1.
swap {1} 3 -1;swap {2} [2..3] -1; swap {1} 7 -4; swap {2} 7 -4.
seq 3 3 : (#pre /\ ={aa});1:by call polyvec_csubq_noloops;auto => />.
wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\
0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\
wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\ valid_ptr (to_uint rp{2}) (128 + 3 * 320) /\
0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\ rp{2} = _ctp /\
j{1} *4 = i{1} * 5 /\ ={aa} /\
Glob.mem{2} = stores _mem (to_uint _ctp) (take j{1} (to_list rr0{1}))); last
by auto => />; smt(Array960.size_to_list List.take_size List.take0 storesE iota0).
unroll for {1} 2; unroll for {2} 2;auto => /> &1 &2;rewrite !ultE /= => ???????;do split;1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt().
admit.
unroll for {1} 2; unroll for {2} 2;auto => /> &1 &2;rewrite !ultE /= => ?????????;do split;1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt().
rewrite /storeW8 /=.
apply mem_eq_ext => adr.
rewrite !to_uintD_small /= 1..16:/# !addrA.
pose x1 := truncateu8
(truncateu16
((((zeroextu64 aa{2}.[to_uint i{2}] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64) `&`
(of_int 255)%W16).
pose x2 := truncateu8
((truncateu16
((((zeroextu64 aa{2}.[to_uint i{2} + 1] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64) `<<`
(of_int 2)%W8) `|`
(truncateu16
((((zeroextu64 aa{2}.[to_uint i{2}] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64) `>>`
(of_int 8)%W8)).
pose x3 := truncateu8
((truncateu16
((((zeroextu64 aa{2}.[to_uint i{2} + 2] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64) `<<`
(of_int 4)%W8) `|`
(truncateu16
((((zeroextu64 aa{2}.[to_uint i{2} + 1] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64) `>>`
(of_int 6)%W8)).
pose x4 := truncateu8
((truncateu16
((((zeroextu64 aa{2}.[to_uint i{2} + 3] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64) `<<`
(of_int 6)%W8) `|`
(truncateu16
((((zeroextu64 aa{2}.[to_uint i{2} + 2] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64) `>>`
(of_int 4)%W8)).
pose x5 := truncateu8
((((zeroextu64 aa{2}.[to_uint i{2} + 3] `<<` (of_int 10)%W8) + (of_int 1665)%W64) * (of_int 1290167)%W64 `>>`
(of_int 32)%W8) `&`
(of_int 1023)%W64 `>>` (of_int 2)%W8).
rewrite !get_storesE.
case (to_uint _ctp + to_uint j{2} <= adr < to_uint _ctp + to_uint j{2} + 5); last first.
case (to_uint _ctp <= adr < to_uint _ctp + to_uint j{2}); last first.
+ move => *;do 5!(rewrite get_set_neqE_s 1:/#).
rewrite !size_take 1:/# size_to_list /= ifF 1:/# get_storesE /= size_take 1:/# size_to_list /#.
+ move => *.
move => *;do 5!(rewrite get_set_neqE_s 1:/#).
rewrite !size_take 1:/# size_to_list /= ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /= get_storesE size_take 1:/# size_mkseq /= ifT 1:/#.
by rewrite nth_take 1,2:/# nth_mkseq 1:/# /=; smt(Array960.get_setE).
move => *. rewrite size_take 1:/# size_to_list ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /=.
by smt(Array960.get_setE get_set_neqE_s get_set_eqE_s).
qed.

lemma compress10_equiv_avx2i_dummy :
equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2_dummy :
={bp,ctp} ==> ={res} ] by proc;inline *;sim;auto => />.

lemma compress10_equiv_avx2i :
equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2 :
={bp} ==> ={res} ] by admit.

={bp} ==> ={res} ].
proc => /=.
swap {2} 2 -1.
seq 1 1 : #pre; 1: by sim.
inline *;wp.
while (={i,a,bp,b0,b1,b2,mask10,shift,sllv_indx,shuffle,aux} /\ 0<=i{1} <= 48 /\ aux{1}=48 /\ (forall k, 0<=k<i{1}*20 => rp{1}.[k] = rp{2}.[k]));
last by auto => /> *; split;[ smt() | move => *; rewrite tP => *;smt()].
auto => /> &1 &2 *;split;1:smt().
move => k kbl kbh; rewrite !initiE 1,2:/# /=.
rewrite !get8_set32_directE 1..4:/#.
case (0<=k<i{2}*20).
+ move => *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=.
rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifF 1,2:/#.
by rewrite /get8 !initiE /#.
move => *;case (i{2}*20<=k<i{2}*20+16).
+ move => *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=.
by rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifT /#.
by smt().
qed.

lemma compress10_equiv_refi :
equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig_i :
Expand Down Expand Up @@ -1781,7 +1862,7 @@ proc change ^while{3}.11 : (sliceset256_16_256 ap1 i2 a3). by admit.
proc change 26 : (init_768_16 (fun i => if 2 * 256 <= i < 3 * 256 then aux.[i - 2 * 256] else r.[i])). by auto.

proc change 30 : (init_960_8 (fun i_0 => ctp0.[i_0 + 0])). by done.
proc change 36 : (sliceget32_8_256 pvc_shufbidx_s 0). by admit.
proc change 37 : (sliceget32_8_256 pvc_shufbidx_s 0). by admit.

proc change ^while{4}.1 : (sliceget768_16_256 a i). by admit.

Expand All @@ -1793,11 +1874,12 @@ cfold 38.
unroll for 39.
cfold 38. unroll for 24. cfold 23.
unroll for 16. cfold 15. unroll for 8. cfold 7.

proc change 552 : (init_960_8 (fun i => ctp0.[i])). by done.
admit. (*
bdep 16 16 [_bp] [bp] [ap] lane pcond.

print get256_direct.
*)
qed.

(* MAP REDUCE GOAL *)
lemma compress10_mr :
Expand All @@ -1812,7 +1894,7 @@ lemma compress10_equiv :
proof.
proc* => /=.
exlim Glob.mem{1}, ctp{1} => _mem _ctp.
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(witness,bp); }
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(bp); }
(={bp} /\ ctp{1} = _ctp /\ Glob.mem{1} = _mem /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) ==>
Glob.mem{1} = stores _mem (to_uint _ctp) (to_list r{2}))
(lift_array768 bp{1} = lift_array768 bp{2} /\ ctp{2} = _ctp /\ Glob.mem{2} = _mem /\ valid_ptr (to_uint ctp{2}) (128 + 3 * 320) ==> Glob.mem{2} = stores _mem (to_uint _ctp) (to_list r{1}));
Expand All @@ -1830,7 +1912,7 @@ lemma compress10_equiv_i :
proof.
proc* => /=.
exlim Glob.mem{1}, ctp{1} => _mem _ctp.
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(ctp,bp); }
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(bp); }
(={bp,ctp} ==> ={r})
(lift_array768 bp{1} = lift_array768 bp{2} /\ ={ctp} ==> ={r});
[ by smt() | by smt() | by call (compress10_equiv_avx2i); auto => /> |].
Expand All @@ -1840,7 +1922,7 @@ transitivity {2} { r <@ AuxPolyVecCompress10.ref(bp); }
[ by smt() | by smt() | | by call (compress10_equiv_refi); auto => />].
by call compress10_mr; auto => />.
qed.

import InnerPKE.
lemma mlkem_correct_enc_0_avx2 mem _ctp _pkp :
equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_enc_0 ~ InnerPKE.enc_derand:
valid_ptr _pkp (384*3 + 32) /\
Expand Down

0 comments on commit 7e8ea02

Please sign in to comment.