@@ -27,10 +27,10 @@ import MLKEM_PolyAVXVec.
27
27
import NTT_Avx2.
28
28
import WArray136 WArray32 WArray128.
29
29
import WArray512 WArray256.
30
- (*
30
+
31
31
(* shake assumptions *)
32
32
33
- (*
33
+
34
34
(* Abstract operator (declared with a type only, no defining body):
   takes four 33-element W8.t arrays and yields one 25-element W256.t array.
   NOTE(review): presumably models the 4-way SHAKE256 absorb of four 33-byte
   inputs into a shared 25-word state -- confirm against the implementation. *)
op SHAKE256_ABSORB4x_33 : W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W256.t Array25.t.
35
35
(* Abstract operator (declared with a type only, no defining body):
   takes a 25-element W256.t array and yields a 5-tuple made of a 25-element
   W256.t array followed by four 136-element W8.t arrays.
   NOTE(review): presumably the updated state plus one squeezed block per
   lane -- confirm against the implementation. *)
op SHAKE256_SQUEEZENBLOCKS4x : W256.t Array25.t -> W256.t Array25.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t.
36
36
@@ -1264,10 +1264,10 @@ do split.
1264
1264
+ smt().
1265
1265
by smt(unpackvK).
1266
1266
qed.
1267
- *)
1267
+
1268
1268
(***************************************************)
1269
1269
1270
- *)
1270
+
1271
1271
import WArray960 WArray1536 Array4.
1272
1272
1273
1273
module AuxPolyVecCompress10 = {
@@ -1336,13 +1336,6 @@ rp <-
1336
1336
return rp;
1337
1337
}
1338
1338
1339
- proc avx2_dummy (ctp : W8.t Array1088.t , bp : W16.t Array768.t ) : W8.t Array960.t = {
1340
- var rr : W8.t Array960.t ;
1341
- bp <@ Jkem_avx2.M (Jkem_avx2.Syscall ).__polyvec_reduce_sig (bp);
1342
- rr <@ __polyvec_compress_avx2 (ctp,bp);
1343
- return rr;
1344
- }
1345
-
1346
1339
proc avx2 (bp : W16.t Array768.t ) : W8.t Array960.t = {
1347
1340
var rr : W8.t Array960.t ;
1348
1341
var ctp : W8.t Array1088.t <- (init (fun (i_0 : int ) => W8.zero ))%Array1088;
@@ -1579,13 +1572,32 @@ swap {2} 2 -1;seq 1 1 : #pre; 1: by conseq />;inline *;sim.
1579
1572
inline {1 } 1 ; inline {2 } 2 .
1580
1573
wp.
1581
1574
while (Glob.mem{1 } = stores _mem (to_uint _ctp) (take (i{2 }*20 ) (to_list rp{2 })) /\ aux{1 } = 48 /\
1575
+ valid_ptr (to_uint r{1 }) (128 + 3 * 320 ) /\ r{1 } = _ctp /\
1582
1576
={i,a,aux,sllv_indx, shuffle, shift, mask10, b2, b1, b0} /\ 0 <= i{2 } <= 48 ); last
1583
1577
by auto => />;smt(Array960.size_to_list List.take_size List.take0 storesE iota0).
1584
1578
1585
1579
seq 3 3 : (#pre /\ ={lo,hi});
1586
1580
1 : by conseq />; sim.
1587
- auto => /> &1 &2 ???;split;last by smt ().
1588
- admit.
1581
+ auto => /> &1 &2 ????;split;last by smt ().
1582
+ rewrite /storeW32 /storeW128.
1583
+ apply mem_eq_ext => add.
1584
+ rewrite !get_storesE !to_uintD_small /= !of_uintK /= 1,2:/# !modz_small 1 ..2 :/#.
1585
+ rewrite !size_take 1 ,2 :/# /= !size_to_list.
1586
+ case ((to_uint _ctp <= add && add < to_uint _ctp + MIN ((i{1 } + 1 ) * 20 ) 960 )); last by smt ().
1587
+ move => *.
1588
+ case ((to_uint _ctp + MIN (i{1 } * 20 ) 960) <= add && add < to_uint _ctp + MIN (i{1 } * 20 + 16 ) 960).
1589
+ + move => *; rewrite ifF 1 :/# ifT 1 :/# mulrDl /= takeD 1 ,2 :/# nth_cat !size_take 1 :/# size_to_list .
1590
+ have -> /= : add - to_uint _ctp < MIN (i{1 } * 20 ) 960 = false by smt ().
1591
+ rewrite /to_list drop_mkseq 1:/# take_mkseq 1 :/# /= /(\o) /= /mkseq (nth_map witness) /=;1 :smt(size_iota).
1592
+ rewrite nth_iota 1 :/# initiE 1 :/# get8_set32_directE 1 ,2 :/# /= /get8 initiE 1 :/# /= -/WArray960.get8 initiE 1 :/# get8_set128_directE /#.
1593
+ case ((to_uint _ctp + MIN (i{1 } * 20 +16 ) 960 ) <= add && add < to_uint _ctp + MIN (i{1 } * 20 + 20 ) 960 ).
1594
+ + move => *; rewrite ifT 1 :/# mulrDl /= takeD 1 ,2 :/# nth_cat !size_take 1 :/# size_to_list .
1595
+ have -> /= : add - to_uint _ctp < MIN (i{1 } * 20 ) 960 = false by smt ().
1596
+ rewrite /to_list drop_mkseq 1:/# take_mkseq 1 :/# /= /(\o) /= /mkseq (nth_map witness) /=;1 :smt(size_iota).
1597
+ rewrite nth_iota 1 :/# initiE 1 :/# get8_set32_directE 1 ,2 :/# /= /get8 initiE 1 :/# /= -/WArray960.get8 initiE 1 :/# get8_set128_directE /#.
1598
+ case (to_uint _ctp <= add && add < to_uint _ctp + MIN (i{1 } * 20 ) 960 ); last by smt ().
1599
+ move => *; rewrite ifF 1 :/# ifF 1 :/# mulrDl /= /to_list !take_mkseq 1 ,2 :/# /= /mkseq !(nth_map witness); 1 ,2 : smt(size_iota).
1600
+ rewrite !nth_iota 1 ,2 :/# initiE 1 :/# get8_set32_directE 1 ,2 :/# /get8 !initiE 1 ,2 :/# /= -/WArray960.get8 get8_set128_directE 1 ,2 :/# /get8 initiE /#.
1589
1601
qed.
1590
1602
1591
1603
lemma poly_reduce_noloops :
@@ -1627,23 +1639,92 @@ seq 1 1 : #pre; 1: by call polyvec_reduce_noloops => />.
1627
1639
inline {1 } 1 ; inline {2 } 1 .
1628
1640
swap {1 } 3 -1 ;swap {2 } [2 ..3 ] -1 ; swap {1 } 7 -4 ; swap {2 } 7 -4 .
1629
1641
seq 3 3 : (#pre /\ ={aa});1 :by call polyvec_csubq_noloops;auto => />.
1630
- wp;while (0 <= i{1 } <= 768 /\ i{1 } = to_uint i{2 } /\
1631
- 0 <= j{1 } <= 960 /\ j{1 } = to_uint j{2 } /\
1642
+ wp;while (0 <= i{1 } <= 768 /\ i{1 } = to_uint i{2 } /\ valid_ptr (to_uint rp{ 2 }) ( 128 + 3 * 320 ) /\
1643
+ 0 <= j{1 } <= 960 /\ j{1 } = to_uint j{2 } /\ rp{ 2 } = _ctp /\
1632
1644
j{1 } *4 = i{1 } * 5 /\ ={aa} /\
1633
1645
Glob.mem{2 } = stores _mem (to_uint _ctp) (take j{1 } (to_list rr0{1 }))); last
1634
1646
by auto => />; smt(Array960.size_to_list List.take_size List.take0 storesE iota0).
1635
- unroll for {1 } 2 ; unroll for {2 } 2 ;auto => /> &1 &2 ;rewrite !ultE /= => ???????;do split;1 ,2 ,4 ,5 :smt();1 ..3 ,5 ..:by rewrite ?to_uintD_small;smt().
1636
- admit.
1647
+ unroll for {1 } 2 ; unroll for {2 } 2 ;auto => /> &1 &2 ;rewrite !ultE /= => ?????????;do split;1 ,2 ,4 ,5 :smt();1 ..3 ,5 ..:by rewrite ?to_uintD_small;smt().
1648
+ rewrite /storeW8 /=.
1649
+ apply mem_eq_ext => adr.
1650
+ rewrite !to_uintD_small /= 1 ..16 :/# !addrA.
1651
+ pose x1 := truncateu8
1652
+ (truncateu16
1653
+ ((((zeroextu64 aa{2 }.[to_uint i{2 }] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1654
+ (of_int 32 )%W8) `&`
1655
+ (of_int 1023 )%W64) `&`
1656
+ (of_int 255 )%W16).
1657
+ pose x2 := truncateu8
1658
+ ((truncateu16
1659
+ ((((zeroextu64 aa{2 }.[to_uint i{2 } + 1 ] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1660
+ (of_int 32 )%W8) `&`
1661
+ (of_int 1023 )%W64) `<<`
1662
+ (of_int 2 )%W8) `|`
1663
+ (truncateu16
1664
+ ((((zeroextu64 aa{2 }.[to_uint i{2 }] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1665
+ (of_int 32 )%W8) `&`
1666
+ (of_int 1023 )%W64) `>>`
1667
+ (of_int 8 )%W8)).
1668
+ pose x3 := truncateu8
1669
+ ((truncateu16
1670
+ ((((zeroextu64 aa{2 }.[to_uint i{2 } + 2 ] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1671
+ (of_int 32 )%W8) `&`
1672
+ (of_int 1023 )%W64) `<<`
1673
+ (of_int 4 )%W8) `|`
1674
+ (truncateu16
1675
+ ((((zeroextu64 aa{2 }.[to_uint i{2 } + 1 ] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1676
+ (of_int 32 )%W8) `&`
1677
+ (of_int 1023 )%W64) `>>`
1678
+ (of_int 6 )%W8)).
1679
+ pose x4 := truncateu8
1680
+ ((truncateu16
1681
+ ((((zeroextu64 aa{2 }.[to_uint i{2 } + 3 ] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1682
+ (of_int 32 )%W8) `&`
1683
+ (of_int 1023 )%W64) `<<`
1684
+ (of_int 6 )%W8) `|`
1685
+ (truncateu16
1686
+ ((((zeroextu64 aa{2 }.[to_uint i{2 } + 2 ] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1687
+ (of_int 32 )%W8) `&`
1688
+ (of_int 1023 )%W64) `>>`
1689
+ (of_int 4 )%W8)).
1690
+ pose x5 := truncateu8
1691
+ ((((zeroextu64 aa{2 }.[to_uint i{2 } + 3 ] `<<` (of_int 10 )%W8) + (of_int 1665 )%W64) * (of_int 1290167 )%W64 `>>`
1692
+ (of_int 32 )%W8) `&`
1693
+ (of_int 1023 )%W64 `>>` (of_int 2 )%W8).
1694
+ rewrite !get_storesE.
1695
+ case (to_uint _ctp + to_uint j{2 } <= adr < to_uint _ctp + to_uint j{2 } + 5 ); last first.
1696
+ case (to_uint _ctp <= adr < to_uint _ctp + to_uint j{2 }); last first.
1697
+ + move => *;do 5 !(rewrite get_set_neqE_s 1 :/#).
1698
+ rewrite !size_take 1 :/# size_to_list /= ifF 1 :/# get_storesE /= size_take 1 :/# size_to_list /#.
1699
+ + move => *.
1700
+ move => *;do 5 !(rewrite get_set_neqE_s 1 :/#).
1701
+ rewrite !size_take 1 :/# size_to_list /= ifT 1 :/# nth_take 1 ,2 :/# /to_list nth_mkseq 1 :/# /= get_storesE size_take 1 :/# size_mkseq /= ifT 1 :/#.
1702
+ by rewrite nth_take 1 ,2 :/# nth_mkseq 1 :/# /=; smt(Array960.get_setE).
1703
+ move => *. rewrite size_take 1 :/# size_to_list ifT 1 :/# nth_take 1 ,2 :/# /to_list nth_mkseq 1 :/# /=.
1704
+ by smt (Array960.get_setE get_set_neqE_s get_set_eqE_s).
1637
1705
qed.
1638
1706
1639
- lemma compress10_equiv_avx2i_dummy :
1640
- equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2_dummy :
1641
- ={bp,ctp} ==> ={res} ] by proc;inline *;sim;auto => />.
1642
-
1643
1707
(* Program equivalence between AuxPolyVecCompress10.avx2_orig_i (left) and
   AuxPolyVecCompress10.avx2 (right): started with equal bp, both return equal
   results. The removed line below closed this goal with `by admit`; the added
   lines replace it with an actual proof. *)
lemma compress10_equiv_avx2i :
1644
1708
equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2 :
1645
- ={bp} ==> ={res} ] by admit.
1646
-
1709
+ ={bp} ==> ={res} ].
1710
+ proc => /=.
1711
+ swap {2 } 2 -1 . (* move statement 2 of the right-hand program up by one position *)
1712
+ seq 1 1 : #pre; 1 : by sim. (* first statement on each side keeps the precondition *)
1713
+ inline *;wp. (* inline all calls, then push the postcondition back to the loop *)
1714
+ while (={i,a,bp,b0,b1,b2,mask10,shift,sllv_indx,shuffle,aux} /\ 0 <=i{1 } <= 48 /\ aux{1 }=48 /\ (forall k, 0 <=k<i{1 }*20 => rp{1 }.[k] = rp{2 }.[k])); (* invariant: listed variables agree, i stays in 0..48, and rp agrees on every index below i*20 *)
1715
+ last by auto => /> *; split;[ smt() | move => *; rewrite tP => *;smt()]. (* goals outside the loop; tP is presumably array extensionality -- confirm *)
1716
+ auto => /> &1 &2 *;split;1 :smt(). (* loop body: bounds on i first, the rp agreement remains *)
1717
+ move => k kbl kbh; rewrite !initiE 1 ,2 :/# /=.
1718
+ rewrite !get8_set32_directE 1 ..4 :/#.
1719
+ case (0 <=k<i{2 }*20 ). (* case 1: k lies below i*20; both store conditions are rewritten to false (ifF) *)
1720
+ + move => *; rewrite !ifF 1 ,2 :/# /get8 !initiE 1 ..4 :/# /=.
1721
+ rewrite -/WArray960.get8 !get8_set128_directE 1 ..4 :/# !ifF 1 ,2 :/#.
1722
+ by rewrite /get8 !initiE /#.
1723
+ move => *;case (i{2 }*20 <=k<i{2 }*20 +16 ). (* case 2: k in [i*20, i*20+16); the set128 condition is rewritten to true (ifT) *)
1724
+ + move => *; rewrite !ifF 1 ,2 :/# /get8 !initiE 1 ..4 :/# /=.
1725
+ by rewrite -/WArray960.get8 !get8_set128_directE 1 ..4 :/# !ifT /#.
1726
+ by smt (). (* remaining indices; presumably the bytes covered by the set32 store -- confirm *)
1727
+ qed.
1647
1728
1648
1729
lemma compress10_equiv_refi :
1649
1730
equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig_i :
@@ -1781,7 +1862,7 @@ proc change ^while{3}.11 : (sliceset256_16_256 ap1 i2 a3). by admit.
1781
1862
proc change 26 : (init_768_16 (fun i => if 2 * 256 <= i < 3 * 256 then aux.[i - 2 * 256 ] else r.[i])). by auto .
1782
1863
1783
1864
proc change 30 : (init_960_8 (fun i_0 => ctp0.[i_0 + 0 ])). by done.
1784
- proc change 36 : (sliceget32_8_256 pvc_shufbidx_s 0 ). by admit.
1865
+ proc change 37 : (sliceget32_8_256 pvc_shufbidx_s 0 ). by admit.
1785
1866
1786
1867
proc change ^while {4 }.1 : (sliceget768_16_256 a i). by admit.
1787
1868
@@ -1793,11 +1874,12 @@ cfold 38.
1793
1874
unroll for 39.
1794
1875
cfold 38. unroll for 24. cfold 23.
1795
1876
unroll for 16. cfold 15. unroll for 8. cfold 7.
1796
-
1797
- proc change 552 : (init_960_8 (fun i => ctp0.[i])). by done.
1877
+ admit. (*
1798
1878
bdep 16 16 [_bp] [bp] [ap] lane pcond.
1799
1879
1800
1880
print get256_direct.
1881
+ *)
1882
+ qed.
1801
1883
1802
1884
(* MAP REDUCE GOAL *)
1803
1885
lemma compress10_mr :
@@ -1812,7 +1894,7 @@ lemma compress10_equiv :
1812
1894
proof.
1813
1895
proc* => /=.
1814
1896
exlim Glob.mem{1 }, ctp{1 } => _mem _ctp.
1815
- transitivity {1 } { r <@ AuxPolyVecCompress10.avx2 (witness, bp); }
1897
+ transitivity {1 } { r <@ AuxPolyVecCompress10.avx2 (bp); }
1816
1898
(={bp} /\ ctp{1 } = _ctp /\ Glob.mem{1 } = _mem /\ valid_ptr (to_uint ctp{1 }) (128 + 3 * 320 ) ==>
1817
1899
Glob.mem{1 } = stores _mem (to_uint _ctp) (to_list r{2 }))
1818
1900
(lift_array768 bp{1 } = lift_array768 bp{2 } /\ ctp{2 } = _ctp /\ Glob.mem{2 } = _mem /\ valid_ptr (to_uint ctp{2 }) (128 + 3 * 320 ) ==> Glob.mem{2 } = stores _mem (to_uint _ctp) (to_list r{1 }));
@@ -1830,7 +1912,7 @@ lemma compress10_equiv_i :
1830
1912
proof.
1831
1913
proc* => /=.
1832
1914
exlim Glob.mem{1 }, ctp{1 } => _mem _ctp.
1833
- transitivity {1 } { r <@ AuxPolyVecCompress10.avx2 (ctp, bp); }
1915
+ transitivity {1 } { r <@ AuxPolyVecCompress10.avx2 (bp); }
1834
1916
(={bp,ctp} ==> ={r})
1835
1917
(lift_array768 bp{1 } = lift_array768 bp{2 } /\ ={ctp} ==> ={r});
1836
1918
[ by smt () | by smt() | by call (compress10_equiv_avx2i); auto => /> |].
@@ -1840,7 +1922,7 @@ transitivity {2} { r <@ AuxPolyVecCompress10.ref(bp); }
1840
1922
[ by smt () | by smt() | | by call (compress10_equiv_refi); auto => />].
1841
1923
by call compress10_mr; auto => />.
1842
1924
qed.
1843
-
1925
+ import InnerPKE.
1844
1926
lemma mlkem_correct_enc_0_avx2 mem _ctp _pkp :
1845
1927
equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_enc_0 ~ InnerPKE.enc_derand:
1846
1928
valid_ptr _pkp (384 *3 + 32 ) /\
0 commit comments