@@ -4,89 +4,170 @@ import (
44 "github.com/consensys/linea-monorepo/prover/crypto/mimc"
55 "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
66 "github.com/consensys/linea-monorepo/prover/maths/field"
7+ "github.com/consensys/linea-monorepo/prover/protocol/ifaces"
78 "github.com/consensys/linea-monorepo/prover/protocol/wizard"
89 "github.com/consensys/linea-monorepo/prover/utils/parallel"
910)
1011
11- // assign assigns the columns to the prover runtime
12+ // assign: Assigns the columns to the prover runtime using PaddedCircularWindow
1213func (ctx * mimcCtx ) assign (run * wizard.ProverRuntime ) {
1314
1415 var (
15- oldState = ctx .oldStates .GetColAssignment (run ).IntoRegVecSaveAlloc ()
16- blocks = ctx .blocks .GetColAssignment (run ).IntoRegVecSaveAlloc ()
17-
18- // Initialize slices to hold intermediate results and intermediatePow4
19- // The first entry is left empty for consistency with ctx.intermediateResult
20- // We don't need to assign it because it is assigned already.
21- intermediateRes = make ([][]field.Element , len (ctx .intermediateResult ))
22- intermediatePow4 = make ([][]field.Element , len (ctx .intermediateResult ))
16+ oldStateSV = ctx .oldStates .GetColAssignment (run )
17+ blocksSV = ctx .blocks .GetColAssignment (run )
18+ totalRows = oldStateSV .Len ()
19+ numRounds = len (ctx .intermediateResult )
2320 )
2421
25- // Initialize intermediateRes and intermediatePow4 with correct lengths
26- for i := range intermediateRes {
27- // For each intermediate result, create a slice of field.Elements with length numRows
28- intermediateRes [i ] = make ([]field.Element , len (oldState ))
29- intermediatePow4 [i ] = make ([]field.Element , len (oldState ))
30- }
22+ offset , windowLen := identifyActiveWindow (oldStateSV , blocksSV , totalRows )
23+ oldStateWindow , blocksWindow := extractWindowSlices (oldStateSV , blocksSV , offset , windowLen )
24+ intermediateResWindow , intermediatePow4Window := computeIntermediateValues (numRounds , oldStateWindow , blocksWindow , windowLen )
3125
32- // Set the initial intermediate res as the block itself
33- intermediateRes [0 ] = blocks
26+ var resPad , pow4Pad []field.Element
3427
35- // Compute intermediate values for each round
36- for i := range ctx .intermediateResult {
37- computeIntermediateValues (i , oldState , intermediateRes , intermediatePow4 )
28+ // Precompute padding only when `PaddedCircularWindow` is tobe used
29+ // i.e. Whenever there is sparsity in the (oldState, blocks) pair
30+ if windowLen != totalRows {
31+ resPad , pow4Pad = precomputePaddingValues (numRounds )
3832 }
33+ assignOptimizedVectors (run , ctx , intermediateResWindow , intermediatePow4Window , resPad , pow4Pad , offset , totalRows )
34+ }
35+
36+ // identifyActiveWindow finds the smallest active window scanning through the oldState and blocks
37+ func identifyActiveWindow (oldStateSV , blocksSV smartvectors.SmartVector , totalRows int ) (offset int , windowLen int ) {
38+ // Convert to regular vectors to scan all elements
39+ var (
40+ oldState = smartvectors .IntoRegVec (oldStateSV )
41+ blocks = smartvectors .IntoRegVec (blocksSV )
42+ )
3943
40- // Assign columns
41- for i := range ctx .intermediateResult {
42- // Assign computed values to the runtime
43- if i > 0 {
44- // Skip the first intermediate result
45- // Recall that the first intermediate res is the block itself
46- run .AssignColumn (
47- ctx .intermediateResult [i ].GetColID (),
48- smartvectors .NewRegular (intermediateRes [i ]),
49- )
44+ // Initialize firstNonZero and lastNonZero indices to default values
45+ firstNonZero , lastNonZero := totalRows , - 1
46+ for i := 0 ; i < totalRows ; i ++ {
47+ if ! oldState [i ].IsZero () || ! blocks [i ].IsZero () {
48+ firstNonZero = min (firstNonZero , i )
49+ lastNonZero = max (lastNonZero , i )
5050 }
51+ }
52+
53+ if firstNonZero <= lastNonZero {
54+ offset = firstNonZero
55+ windowLen = lastNonZero - firstNonZero + 1
56+ return offset , windowLen
57+ }
58+ // Default window => Full window
59+ return 0 , totalRows
60+ }
5161
52- // Assign intermediatePow4 to the runtime
53- run .AssignColumn (
54- ctx .intermediatePow4 [i ].GetColID (),
55- smartvectors .NewRegular (intermediatePow4 [i ]),
56- )
62+ // computeIntermediateValues computes intermediate values for the window
63+ func computeIntermediateValues (numRounds int , oldStateWindow , blocksWindow []field.Element , windowLen int ) ([][]field.Element , [][]field.Element ) {
64+ intermediateResWindow := make ([][]field.Element , numRounds )
65+ intermediatePow4Window := make ([][]field.Element , numRounds )
66+ for i := range intermediateResWindow {
67+ intermediateResWindow [i ] = make ([]field.Element , windowLen )
68+ intermediatePow4Window [i ] = make ([]field.Element , windowLen )
5769 }
5870
71+ // Initalize intermediateResWindow to the blocksWindow
72+ copy (intermediateResWindow [0 ], blocksWindow )
73+
74+ // r => round
75+ for r := 0 ; r < numRounds ; r ++ {
76+ parallel .Execute (windowLen , func (start , stop int ) {
77+ for k := start ; k < stop ; k ++ {
78+ if r == 0 {
79+ tmp := intermediateResWindow [0 ][k ]
80+ tmp .Add (& tmp , & mimc .Constants [0 ]).Add (& tmp , & oldStateWindow [k ])
81+ intermediatePow4Window [0 ][k ].Square (& tmp ).Square (& intermediatePow4Window [0 ][k ])
82+ } else {
83+ // For subsequent rounds, compute intermediate values based on previous results
84+ ark := mimc .Constants [r - 1 ]
85+ nextArk := mimc .Constants [r ]
86+
87+ tmp := intermediatePow4Window [r - 1 ][k ]
88+ tmp .Square (& tmp ).Square (& tmp )
89+
90+ // Compute intermediate result using previous result and oldState
91+ intermediateResWindow [r ][k ] = intermediateResWindow [r - 1 ][k ]
92+ intermediateResWindow [r ][k ].Add (& intermediateResWindow [r ][k ], & ark ).Add (& intermediateResWindow [r ][k ], & oldStateWindow [k ])
93+ intermediateResWindow [r ][k ].Mul (& intermediateResWindow [r ][k ], & tmp )
94+
95+ // Compute intermediatePow4
96+ tmp = intermediateResWindow [r ][k ]
97+ tmp .Add (& tmp , & nextArk ).Add (& tmp , & oldStateWindow [k ])
98+ intermediatePow4Window [r ][k ].Square (& tmp ).Square (& intermediatePow4Window [r ][k ])
99+ }
100+ }
101+ })
102+ }
103+ return intermediateResWindow , intermediatePow4Window
59104}
60105
61- // computeIntermediateValues computes intermediate values for the given round
62- func computeIntermediateValues (round int , oldState []field.Element , intermediateRes , intermediatePow4 [][]field.Element ) {
63- parallel .Execute (len (oldState ), func (start , stop int ) {
64- for k := start ; k < stop ; k ++ {
65- if round == 0 {
66- // For the first round, compute initial intermediatePow4
67- tmp := intermediateRes [0 ][k ]
68- tmp .Add (& tmp , & mimc .Constants [0 ]).Add (& tmp , & oldState [k ])
69- intermediatePow4 [0 ][k ].Square (& tmp ).Square (& intermediatePow4 [0 ][k ])
106+ // assignOptimizedVectors assigns optimized vectors to the prover runtime
107+ func assignOptimizedVectors (run * wizard.ProverRuntime , ctx * mimcCtx , intermediateResWindow , intermediatePow4Window [][]field.Element , resPad , pow4Pad []field.Element , offset , totalRows int ) {
108+ for round := range ctx .intermediateResult {
109+ windowLen := len (intermediateResWindow [round ])
110+
111+ // Full-length window: use Regular vector
112+ isRegSmartVec := windowLen == totalRows
113+
114+ // Helper function to assign a column with the appropriate smart vector
115+ assignColumn := func (colID ifaces.ColID , window []field.Element , padVal field.Element ) {
116+ if isRegSmartVec {
117+ fullVec := make ([]field.Element , totalRows )
118+ copy (fullVec [offset :offset + windowLen ], window )
119+ run .AssignColumn (colID , smartvectors .NewRegular (fullVec ))
70120 } else {
71- // For subsequent rounds, compute intermediate values based on previous results
72- ark := mimc .Constants [round - 1 ]
73- nextArk := mimc .Constants [round ]
74-
75- tmp := intermediatePow4 [round - 1 ][k ]
76- tmp .Square (& tmp ).Square (& tmp )
77-
78- // Compute intermediate result using previous result and oldState
79- intermediateRes [round ][k ] = intermediateRes [round - 1 ][k ]
80- intermediateRes [round ][k ].Add (& intermediateRes [round ][k ], & ark ).Add (& intermediateRes [round ][k ], & oldState [k ])
81- intermediateRes [round ][k ].Mul (& intermediateRes [round ][k ], & tmp )
82-
83- // Compute intermediatePow4
84- tmp = intermediateRes [round ][k ]
85- tmp .Add (& tmp , & nextArk ).Add (& tmp , & oldState [k ])
86- tmp .Square (& tmp ).Square (& tmp )
87- intermediatePow4 [round ][k ] = tmp
121+ // Partial window: use PaddedCircularWindow with lazily evaluated padding
122+ run .AssignColumn (colID , smartvectors .NewPaddedCircularWindow (window , padVal , offset , totalRows ))
88123 }
89124 }
90- })
91125
126+ // Determine padding values
127+ var resPadVal , pow4PadVal field.Element
128+ if resPad != nil && len (resPad ) > round {
129+ resPadVal = resPad [round ]
130+ }
131+ if pow4Pad != nil && len (pow4Pad ) > round {
132+ pow4PadVal = pow4Pad [round ]
133+ }
134+
135+ // Assign intermediateResult (skip round=0 as it is initialized to the blocks)
136+ if round > 0 {
137+ assignColumn (ctx .intermediateResult [round ].GetColID (), intermediateResWindow [round ], resPadVal )
138+ }
139+
140+ // Assign intermediatePow4
141+ assignColumn (ctx .intermediatePow4 [round ].GetColID (), intermediatePow4Window [round ], pow4PadVal )
142+ }
143+ }
144+
145+ // precomputePaddingValues precomputes padding values for constant regions
146+ func precomputePaddingValues (numRounds int ) ([]field.Element , []field.Element ) {
147+ resPad := make ([]field.Element , numRounds )
148+ pow4Pad := make ([]field.Element , numRounds )
149+ resPad [0 ].SetZero ()
150+
151+ var tmp field.Element
152+ tmp .Add (& resPad [0 ], & mimc .Constants [0 ])
153+ pow4Pad [0 ].Square (& tmp ).Square (& pow4Pad [0 ])
154+
155+ for r := 1 ; r < numRounds ; r ++ {
156+ tmp .Square (& pow4Pad [r - 1 ]).Square (& tmp )
157+ resPad [r ].Add (& resPad [r - 1 ], & mimc .Constants [r - 1 ])
158+ resPad [r ].Mul (& resPad [r ], & tmp )
159+ tmp .Add (& resPad [r ], & mimc .Constants [r ])
160+ pow4Pad [r ].Square (& tmp ).Square (& pow4Pad [r ])
161+ }
162+
163+ return resPad , pow4Pad
164+ }
165+
166+ // extractWindowSlices extracts window slices from the smart vectors
167+ func extractWindowSlices (oldStateSV , blocksSV smartvectors.SmartVector , l , h int ) ([]field.Element , []field.Element ) {
168+ var (
169+ oldStateWindow = smartvectors .IntoRegVec (oldStateSV )[l : l + h ]
170+ blocksWindow = smartvectors .IntoRegVec (blocksSV )[l : l + h ]
171+ )
172+ return oldStateWindow , blocksWindow
92173}
0 commit comments