Skip to content

Commit e21494d

Browse files
slicing madness
1 parent 8e40414 commit e21494d

File tree

1 file changed

+76
-20
lines changed

1 file changed

+76
-20
lines changed

proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,7 +1819,7 @@ op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = W256.bit
18191819

18201820
lemma flatten_take_drop_16 (l : W16.t list) (csize offset bit : int) :
18211821
0 <= offset =>
1822-
offset + csize < 16 * size l =>
1822+
offset + csize <= 16 * size l =>
18231823
0 <= bit < csize =>
18241824
nth false (take csize (drop offset (flatten (map W16.w2bits l)))) bit =
18251825
(nth witness l ((offset + bit) %/ 16)).[(offset + bit) %% 16].
@@ -1833,17 +1833,21 @@ rewrite -get_w2bits;congr.
18331833
by rewrite (nth_map witness) 1:/#.
18341834
qed.
18351835

1836+
lemma size_flatten_W16_w2bits (a : W16.t list) :
1837+
(size (flatten (map W16.w2bits (a)))) = 16 * size a.
1838+
proof.
1839+
rewrite size_flatten -map_comp /(\o) /=.
1840+
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
1841+
rewrite StdBigop.Bigint.big_constz count_predT /#.
1842+
qed.
1843+
18361844
lemma aligned_get256_16_256 arr offset :
1837-
0 <= offset < 16*256 - 256 =>
1845+
0 <= offset <= 16*256 - 256 =>
18381846
256 %| offset =>
18391847
sliceget256_16_256 arr offset =
18401848
WArray512.get256 (WArray512.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 256).
18411849
move => Ho1 Ho2; rewrite /sliceget256_16_256.
1842-
have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256.
1843-
+ rewrite size_take 1:/# size_drop 1:/# /max /=.
1844-
rewrite size_flatten -map_comp /(\o) /=.
1845-
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
1846-
rewrite StdBigop.Bigint.big_constz count_predT;smt(Array256.size_to_list).
1850+
have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256 by rewrite size_take 1:/# size_drop 1:/# /max /=;smt(Array256.size_to_list size_flatten_W16_w2bits).
18471851
rewrite wordP => i ib; rewrite get_bits2w //.
18481852
rewrite flatten_take_drop_16;1..3:smt(Array256.size_to_list).
18491853
rewrite nth_mkseq 1:/# /=.
@@ -1857,9 +1861,7 @@ bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget".
18571861
realize bvaslicegetP.
18581862
move => *; rewrite /sliceget256_16_256 bits2wK // size_take //= size_drop //=.
18591863
admit. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *)
1860-
rewrite size_flatten -map_comp /(\o) /=.
1861-
rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
1862-
by rewrite StdBigop.Bigint.big_constz count_predT;smt(Array256.size_to_list).
1864+
by smt(Array256.size_to_list size_flatten_W16_w2bits).
18631865
qed.
18641866

18651867
import BitEncoding BS2Int BitChunking.
@@ -1869,29 +1871,83 @@ op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t
18691871

18701872

18711873
lemma aligned_set256_16_256 arr offset bv :
1872-
0 <= offset < 16*256 - 256 =>
1874+
0 <= offset <= 16*256 - 256 =>
18731875
256 %| offset =>
18741876
sliceset256_16_256 arr offset bv =
18751877
Array256.init (fun (i3 : int) => get16 (set256 ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 256) bv) i3).
18761878
rewrite /sliceset256_16_256 tP /= => ?? i ib.
18771879
rewrite !initiE 1,2:/# /=.
1878-
rewrite get16_set256E 1,2:/# /= (nth_map witness).
1879-
+ admit.
1880-
admitted.
1880+
rewrite get16_set256E 1,2:/# /= (nth_map []).
1881+
+ rewrite size_chunk // !size_cat !size_take 1:/# !size_drop 1:/# /max /=.
1882+
by smt(Array256.size_to_list size_flatten_W16_w2bits).
1883+
rewrite JWordList.nth_chunk //= 1:/#.
1884+
rewrite !size_cat !size_take 1:/# !size_drop 1:/# /max /=.
1885+
by smt(Array256.size_to_list size_flatten_W16_w2bits).
1886+
case (32 * (offset %/ 256) <= 2 * i);last first.
1887+
+ move => ? /=. have ? : 16*i < offset. smt().
1888+
rewrite get16_init16 1:/# -catA drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1889+
rewrite take_cat_le ifT;1: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1890+
have -> : offset = 16 * (offset %/ 16) by smt().
1891+
rewrite take_flatten_ctt; 1: by smt(mapP W16.size_w2bits).
1892+
rewrite -map_take.
1893+
rewrite -(W16.w2bitsK arr.[i]);congr.
1894+
apply (eq_from_nth false).
1895+
+ rewrite size_w2bits size_take // size_drop 1:/# /= /max /=;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1896+
move => k kb; rewrite flatten_take_drop_16 1:/#.
1897+
+ rewrite size_take 1:/# size_to_list //= 1:/#.
1898+
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1899+
rewrite nth_take 1:/#. smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1900+
rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1901+
case (2 * i < 32 * (offset %/ 256 + 1));last first.
1902+
+ move => ? /=. have ? : offset + 256 <= 16*i . smt().
1903+
rewrite get16_init16 1:/# -catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1904+
rewrite drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1905+
rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#.
1906+
have -> : offset + 256 = 16 * ((offset + 256) %/ 16) by smt().
1907+
rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits).
1908+
have -> : 16 * i - offset - 256 = 16 * (i - offset %/ 16 - 16) by smt().
1909+
rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits mem_drop).
1910+
rewrite drop_drop 1,2:/# /= => ?.
1911+
rewrite -(W16.w2bitsK arr.[i]);congr.
1912+
apply (eq_from_nth false).
1913+
+ rewrite -map_drop size_take // size_flatten_W16_w2bits size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits).
1914+
move => k kb.
1915+
have -> : i - offset %/ 16 - 16 + (offset + 256) %/ 16 = i by smt().
1916+
rewrite -(drop_flatten_ctt 16); 1: smt(mapP W16.size_w2bits).
1917+
rewrite flatten_take_drop_16; 1..3: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1918+
rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1919+
1920+
+ move => ?? /=. have ? : offset <= 16*i < offset + 256. smt().
1921+
rewrite -!catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1922+
rewrite !drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1923+
rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#.
1924+
rewrite take_cat_le ifT;1: by rewrite size_drop 1:/# size_w2bits /= /max ifT /#.
1925+
rewrite -(W16.w2bitsK ((bv \bits16 i - 16 * (offset %/ 256))));congr.
1926+
apply (eq_from_nth false).
1927+
+ rewrite size_take // size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits).
1928+
move => k kb.
1929+
rewrite nth_take; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1930+
rewrite nth_drop; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1931+
rewrite !get_w2bits get_bits16;by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
1932+
qed.
1933+
1934+
18811935

18821936
bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset".
1883-
realize bvaslicesetP.
1884-
move => arr offset bv *.
1937+
realize bvaslicesetP. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *)
1938+
move => arr offset bv *. have ? : 0 <= offset by admit.
18851939
rewrite /sliceset256_16_256 of_listK.
1886-
+ admit.
1940+
+ rewrite size_map size_chunk // !size_cat size_take 1:/#.
1941+
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
18871942
rewrite -(map_comp W16.w2bits W16.bits2w) /(\o).
18881943
have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16
18891944
(take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++
18901945
drop (offset + 256) (flatten (map W16.w2bits (to_list arr))))).
1891-
rewrite iffE => [#] -> *.
1892-
+ admit.
1946+
rewrite iffE => [#] -> * /=.
1947+
+ by smt(in_chunk_size W16.bits2wK).
18931948
rewrite map_id /= chunkK //.
1894-
+ admit.
1949+
+ rewrite !size_cat size_take 1:/#.
1950+
by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0).
18951951
qed.
18961952

18971953
op sliceget32_8_256 (arr: W8.t Array32.t) (i: int) : W256.t = get256 (WArray32.init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])) (i%/256).

0 commit comments

Comments
 (0)