Skip to content

Commit

Permalink
slicing madness
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Nov 2, 2024
1 parent 8e40414 commit e21494d
Showing 1 changed file with 76 additions and 20 deletions.
96 changes: 76 additions & 20 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1819,7 +1819,7 @@ op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = W256.bit

lemma flatten_take_drop_16 (l : W16.t list) (csize offset bit : int) :
0 <= offset =>
offset + csize < 16 * size l =>
offset + csize <= 16 * size l =>
0 <= bit < csize =>
nth false (take csize (drop offset (flatten (map W16.w2bits l)))) bit =
(nth witness l ((offset + bit) %/ 16)).[(offset + bit) %% 16].
Expand All @@ -1833,17 +1833,21 @@ rewrite -get_w2bits;congr.
by rewrite (nth_map witness) 1:/#.
qed.

lemma size_flatten_W16_w2bits (a : W16.t list) :
(size (flatten (map W16.w2bits (a)))) = 16 * size a.
proof.
rewrite size_flatten -map_comp /(\o) /=.
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite StdBigop.Bigint.big_constz count_predT /#.
qed.

lemma aligned_get256_16_256 arr offset :
0 <= offset < 16*256 - 256 =>
0 <= offset <= 16*256 - 256 =>
256 %| offset =>
sliceget256_16_256 arr offset =
WArray512.get256 (WArray512.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 256).
move => Ho1 Ho2; rewrite /sliceget256_16_256.
have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256.
+ rewrite size_take 1:/# size_drop 1:/# /max /=.
rewrite size_flatten -map_comp /(\o) /=.
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite StdBigop.Bigint.big_constz count_predT;smt(Array256.size_to_list).
have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256 by rewrite size_take 1:/# size_drop 1:/# /max /=;smt(Array256.size_to_list size_flatten_W16_w2bits).
rewrite wordP => i ib; rewrite get_bits2w //.
rewrite flatten_take_drop_16;1..3:smt(Array256.size_to_list).
rewrite nth_mkseq 1:/# /=.
Expand All @@ -1857,9 +1861,7 @@ bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget".
realize bvaslicegetP.
move => *; rewrite /sliceget256_16_256 bits2wK // size_take //= size_drop //=.
admit. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *)
rewrite size_flatten -map_comp /(\o) /=.
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz count_predT;smt(Array256.size_to_list).
by smt(Array256.size_to_list size_flatten_W16_w2bits).
qed.

import BitEncoding BS2Int BitChunking.
Expand All @@ -1869,29 +1871,83 @@ op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t


lemma aligned_set256_16_256 arr offset bv :
0 <= offset < 16*256 - 256 =>
0 <= offset <= 16*256 - 256 =>
256 %| offset =>
sliceset256_16_256 arr offset bv =
Array256.init (fun (i3 : int) => get16 (set256 ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 256) bv) i3).
rewrite /sliceset256_16_256 tP /= => ?? i ib.
rewrite !initiE 1,2:/# /=.
rewrite get16_set256E 1,2:/# /= (nth_map witness).
+ admit.
admitted.
rewrite get16_set256E 1,2:/# /= (nth_map []).
+ rewrite size_chunk // !size_cat !size_take 1:/# !size_drop 1:/# /max /=.
by smt(Array256.size_to_list size_flatten_W16_w2bits).
rewrite JWordList.nth_chunk //= 1:/#.
rewrite !size_cat !size_take 1:/# !size_drop 1:/# /max /=.
by smt(Array256.size_to_list size_flatten_W16_w2bits).
case (32 * (offset %/ 256) <= 2 * i);last first.
+ move => ? /=. have ? : 16*i < offset. smt().
rewrite get16_init16 1:/# -catA drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite take_cat_le ifT;1: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
have -> : offset = 16 * (offset %/ 16) by smt().
rewrite take_flatten_ctt; 1: by smt(mapP W16.size_w2bits).
rewrite -map_take.
rewrite -(W16.w2bitsK arr.[i]);congr.
apply (eq_from_nth false).
+ rewrite size_w2bits size_take // size_drop 1:/# /= /max /=;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
move => k kb; rewrite flatten_take_drop_16 1:/#.
+ rewrite size_take 1:/# size_to_list //= 1:/#.
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite nth_take 1:/#. smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
case (2 * i < 32 * (offset %/ 256 + 1));last first.
+ move => ? /=. have ? : offset + 256 <= 16*i . smt().
rewrite get16_init16 1:/# -catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#.
have -> : offset + 256 = 16 * ((offset + 256) %/ 16) by smt().
rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits).
have -> : 16 * i - offset - 256 = 16 * (i - offset %/ 16 - 16) by smt().
rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits mem_drop).
rewrite drop_drop 1,2:/# /= => ?.
rewrite -(W16.w2bitsK arr.[i]);congr.
apply (eq_from_nth false).
+ rewrite -map_drop size_take // size_flatten_W16_w2bits size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits).
move => k kb.
have -> : i - offset %/ 16 - 16 + (offset + 256) %/ 16 = i by smt().
rewrite -(drop_flatten_ctt 16); 1: smt(mapP W16.size_w2bits).
rewrite flatten_take_drop_16; 1..3: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).

+ move => ?? /=. have ? : offset <= 16*i < offset + 256. smt().
rewrite -!catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite !drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#.
rewrite take_cat_le ifT;1: by rewrite size_drop 1:/# size_w2bits /= /max ifT /#.
rewrite -(W16.w2bitsK ((bv \bits16 i - 16 * (offset %/ 256))));congr.
apply (eq_from_nth false).
+ rewrite size_take // size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits).
move => k kb.
rewrite nth_take; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite nth_drop; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite !get_w2bits get_bits16;by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
qed.



bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset".
realize bvaslicesetP.
move => arr offset bv *.
realize bvaslicesetP. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *)
move => arr offset bv *. have ? : 0 <= offset by admit.
rewrite /sliceset256_16_256 of_listK.
+ admit.
+ rewrite size_map size_chunk // !size_cat size_take 1:/#.
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
rewrite -(map_comp W16.w2bits W16.bits2w) /(\o).
have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16
(take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++
drop (offset + 256) (flatten (map W16.w2bits (to_list arr))))).
rewrite iffE => [#] -> *.
+ admit.
rewrite iffE => [#] -> * /=.
+ by smt(in_chunk_size W16.bits2wK).
rewrite map_id /= chunkK //.
+ admit.
+ rewrite !size_cat size_take 1:/#.
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
qed.

op sliceget32_8_256 (arr: W8.t Array32.t) (i: int) : W256.t = get256 (WArray32.init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])) (i%/256).
Expand Down

0 comments on commit e21494d

Please sign in to comment.