Skip to content

Commit e418d45

Browse files
srinathln7AlexandreBelling
authored andcommitted
Perf(prover): uses padded-circular windows for the MiMC assignment (#770)
* use padded circular window to reduce runtime memory * (feat): use padded circular window for mimc col assignment
1 parent 9c60dbb commit e418d45

File tree

2 files changed

+147
-62
lines changed

2 files changed

+147
-62
lines changed

prover/backend/execution/prove.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ type Witness struct {
2727
}
2828

2929
func Prove(cfg *config.Config, req *Request, large bool) (*Response, error) {
30+
31+
// Set MonitorParams before any proving happens
32+
profiling.SetMonitorParams(cfg)
33+
3034
traces := &cfg.TracesLimits
3135
if large {
3236
traces = &cfg.TracesLimitsLarge

prover/protocol/compiler/mimc/assignment.go

Lines changed: 143 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,89 +4,170 @@ import (
44
"github.com/consensys/linea-monorepo/prover/crypto/mimc"
55
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
66
"github.com/consensys/linea-monorepo/prover/maths/field"
7+
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
78
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
89
"github.com/consensys/linea-monorepo/prover/utils/parallel"
910
)
1011

11-
// assign assigns the columns to the prover runtime
12+
// assign: Assigns the columns to the prover runtime using PaddedCircularWindow
1213
func (ctx *mimcCtx) assign(run *wizard.ProverRuntime) {
1314

1415
var (
15-
oldState = ctx.oldStates.GetColAssignment(run).IntoRegVecSaveAlloc()
16-
blocks = ctx.blocks.GetColAssignment(run).IntoRegVecSaveAlloc()
17-
18-
// Initialize slices to hold intermediate results and intermediatePow4
19-
// The first entry is left empty for consistency with ctx.intermediateResult
20-
// We don't need to assign it because it is assigned already.
21-
intermediateRes = make([][]field.Element, len(ctx.intermediateResult))
22-
intermediatePow4 = make([][]field.Element, len(ctx.intermediateResult))
16+
oldStateSV = ctx.oldStates.GetColAssignment(run)
17+
blocksSV = ctx.blocks.GetColAssignment(run)
18+
totalRows = oldStateSV.Len()
19+
numRounds = len(ctx.intermediateResult)
2320
)
2421

25-
// Initialize intermediateRes and intermediatePow4 with correct lengths
26-
for i := range intermediateRes {
27-
// For each intermediate result, create a slice of field.Elements with length numRows
28-
intermediateRes[i] = make([]field.Element, len(oldState))
29-
intermediatePow4[i] = make([]field.Element, len(oldState))
30-
}
22+
offset, windowLen := identifyActiveWindow(oldStateSV, blocksSV, totalRows)
23+
oldStateWindow, blocksWindow := extractWindowSlices(oldStateSV, blocksSV, offset, windowLen)
24+
intermediateResWindow, intermediatePow4Window := computeIntermediateValues(numRounds, oldStateWindow, blocksWindow, windowLen)
3125

32-
// Set the initial intermediate res as the block itself
33-
intermediateRes[0] = blocks
26+
var resPad, pow4Pad []field.Element
3427

35-
// Compute intermediate values for each round
36-
for i := range ctx.intermediateResult {
37-
computeIntermediateValues(i, oldState, intermediateRes, intermediatePow4)
28+
// Precompute padding only when `PaddedCircularWindow` is tobe used
29+
// i.e. Whenever there is sparsity in the (oldState, blocks) pair
30+
if windowLen != totalRows {
31+
resPad, pow4Pad = precomputePaddingValues(numRounds)
3832
}
33+
assignOptimizedVectors(run, ctx, intermediateResWindow, intermediatePow4Window, resPad, pow4Pad, offset, totalRows)
34+
}
35+
36+
// identifyActiveWindow finds the smallest active window scanning through the oldState and blocks
37+
func identifyActiveWindow(oldStateSV, blocksSV smartvectors.SmartVector, totalRows int) (offset int, windowLen int) {
38+
// Convert to regular vectors to scan all elements
39+
var (
40+
oldState = smartvectors.IntoRegVec(oldStateSV)
41+
blocks = smartvectors.IntoRegVec(blocksSV)
42+
)
3943

40-
// Assign columns
41-
for i := range ctx.intermediateResult {
42-
// Assign computed values to the runtime
43-
if i > 0 {
44-
// Skip the first intermediate result
45-
// Recall that the first intermediate res is the block itself
46-
run.AssignColumn(
47-
ctx.intermediateResult[i].GetColID(),
48-
smartvectors.NewRegular(intermediateRes[i]),
49-
)
44+
// Initialize firstNonZero and lastNonZero indices to default values
45+
firstNonZero, lastNonZero := totalRows, -1
46+
for i := 0; i < totalRows; i++ {
47+
if !oldState[i].IsZero() || !blocks[i].IsZero() {
48+
firstNonZero = min(firstNonZero, i)
49+
lastNonZero = max(lastNonZero, i)
5050
}
51+
}
52+
53+
if firstNonZero <= lastNonZero {
54+
offset = firstNonZero
55+
windowLen = lastNonZero - firstNonZero + 1
56+
return offset, windowLen
57+
}
58+
// Default window => Full window
59+
return 0, totalRows
60+
}
5161

52-
// Assign intermediatePow4 to the runtime
53-
run.AssignColumn(
54-
ctx.intermediatePow4[i].GetColID(),
55-
smartvectors.NewRegular(intermediatePow4[i]),
56-
)
62+
// computeIntermediateValues computes intermediate values for the window
63+
func computeIntermediateValues(numRounds int, oldStateWindow, blocksWindow []field.Element, windowLen int) ([][]field.Element, [][]field.Element) {
64+
intermediateResWindow := make([][]field.Element, numRounds)
65+
intermediatePow4Window := make([][]field.Element, numRounds)
66+
for i := range intermediateResWindow {
67+
intermediateResWindow[i] = make([]field.Element, windowLen)
68+
intermediatePow4Window[i] = make([]field.Element, windowLen)
5769
}
5870

71+
// Initalize intermediateResWindow to the blocksWindow
72+
copy(intermediateResWindow[0], blocksWindow)
73+
74+
// r => round
75+
for r := 0; r < numRounds; r++ {
76+
parallel.Execute(windowLen, func(start, stop int) {
77+
for k := start; k < stop; k++ {
78+
if r == 0 {
79+
tmp := intermediateResWindow[0][k]
80+
tmp.Add(&tmp, &mimc.Constants[0]).Add(&tmp, &oldStateWindow[k])
81+
intermediatePow4Window[0][k].Square(&tmp).Square(&intermediatePow4Window[0][k])
82+
} else {
83+
// For subsequent rounds, compute intermediate values based on previous results
84+
ark := mimc.Constants[r-1]
85+
nextArk := mimc.Constants[r]
86+
87+
tmp := intermediatePow4Window[r-1][k]
88+
tmp.Square(&tmp).Square(&tmp)
89+
90+
// Compute intermediate result using previous result and oldState
91+
intermediateResWindow[r][k] = intermediateResWindow[r-1][k]
92+
intermediateResWindow[r][k].Add(&intermediateResWindow[r][k], &ark).Add(&intermediateResWindow[r][k], &oldStateWindow[k])
93+
intermediateResWindow[r][k].Mul(&intermediateResWindow[r][k], &tmp)
94+
95+
// Compute intermediatePow4
96+
tmp = intermediateResWindow[r][k]
97+
tmp.Add(&tmp, &nextArk).Add(&tmp, &oldStateWindow[k])
98+
intermediatePow4Window[r][k].Square(&tmp).Square(&intermediatePow4Window[r][k])
99+
}
100+
}
101+
})
102+
}
103+
return intermediateResWindow, intermediatePow4Window
59104
}
60105

61-
// computeIntermediateValues computes intermediate values for the given round
62-
func computeIntermediateValues(round int, oldState []field.Element, intermediateRes, intermediatePow4 [][]field.Element) {
63-
parallel.Execute(len(oldState), func(start, stop int) {
64-
for k := start; k < stop; k++ {
65-
if round == 0 {
66-
// For the first round, compute initial intermediatePow4
67-
tmp := intermediateRes[0][k]
68-
tmp.Add(&tmp, &mimc.Constants[0]).Add(&tmp, &oldState[k])
69-
intermediatePow4[0][k].Square(&tmp).Square(&intermediatePow4[0][k])
106+
// assignOptimizedVectors assigns optimized vectors to the prover runtime
107+
func assignOptimizedVectors(run *wizard.ProverRuntime, ctx *mimcCtx, intermediateResWindow, intermediatePow4Window [][]field.Element, resPad, pow4Pad []field.Element, offset, totalRows int) {
108+
for round := range ctx.intermediateResult {
109+
windowLen := len(intermediateResWindow[round])
110+
111+
// Full-length window: use Regular vector
112+
isRegSmartVec := windowLen == totalRows
113+
114+
// Helper function to assign a column with the appropriate smart vector
115+
assignColumn := func(colID ifaces.ColID, window []field.Element, padVal field.Element) {
116+
if isRegSmartVec {
117+
fullVec := make([]field.Element, totalRows)
118+
copy(fullVec[offset:offset+windowLen], window)
119+
run.AssignColumn(colID, smartvectors.NewRegular(fullVec))
70120
} else {
71-
// For subsequent rounds, compute intermediate values based on previous results
72-
ark := mimc.Constants[round-1]
73-
nextArk := mimc.Constants[round]
74-
75-
tmp := intermediatePow4[round-1][k]
76-
tmp.Square(&tmp).Square(&tmp)
77-
78-
// Compute intermediate result using previous result and oldState
79-
intermediateRes[round][k] = intermediateRes[round-1][k]
80-
intermediateRes[round][k].Add(&intermediateRes[round][k], &ark).Add(&intermediateRes[round][k], &oldState[k])
81-
intermediateRes[round][k].Mul(&intermediateRes[round][k], &tmp)
82-
83-
// Compute intermediatePow4
84-
tmp = intermediateRes[round][k]
85-
tmp.Add(&tmp, &nextArk).Add(&tmp, &oldState[k])
86-
tmp.Square(&tmp).Square(&tmp)
87-
intermediatePow4[round][k] = tmp
121+
// Partial window: use PaddedCircularWindow with lazily evaluated padding
122+
run.AssignColumn(colID, smartvectors.NewPaddedCircularWindow(window, padVal, offset, totalRows))
88123
}
89124
}
90-
})
91125

126+
// Determine padding values
127+
var resPadVal, pow4PadVal field.Element
128+
if resPad != nil && len(resPad) > round {
129+
resPadVal = resPad[round]
130+
}
131+
if pow4Pad != nil && len(pow4Pad) > round {
132+
pow4PadVal = pow4Pad[round]
133+
}
134+
135+
// Assign intermediateResult (skip round=0 as it is initialized to the blocks)
136+
if round > 0 {
137+
assignColumn(ctx.intermediateResult[round].GetColID(), intermediateResWindow[round], resPadVal)
138+
}
139+
140+
// Assign intermediatePow4
141+
assignColumn(ctx.intermediatePow4[round].GetColID(), intermediatePow4Window[round], pow4PadVal)
142+
}
143+
}
144+
145+
// precomputePaddingValues precomputes padding values for constant regions
146+
func precomputePaddingValues(numRounds int) ([]field.Element, []field.Element) {
147+
resPad := make([]field.Element, numRounds)
148+
pow4Pad := make([]field.Element, numRounds)
149+
resPad[0].SetZero()
150+
151+
var tmp field.Element
152+
tmp.Add(&resPad[0], &mimc.Constants[0])
153+
pow4Pad[0].Square(&tmp).Square(&pow4Pad[0])
154+
155+
for r := 1; r < numRounds; r++ {
156+
tmp.Square(&pow4Pad[r-1]).Square(&tmp)
157+
resPad[r].Add(&resPad[r-1], &mimc.Constants[r-1])
158+
resPad[r].Mul(&resPad[r], &tmp)
159+
tmp.Add(&resPad[r], &mimc.Constants[r])
160+
pow4Pad[r].Square(&tmp).Square(&pow4Pad[r])
161+
}
162+
163+
return resPad, pow4Pad
164+
}
165+
166+
// extractWindowSlices extracts window slices from the smart vectors
167+
func extractWindowSlices(oldStateSV, blocksSV smartvectors.SmartVector, l, h int) ([]field.Element, []field.Element) {
168+
var (
169+
oldStateWindow = smartvectors.IntoRegVec(oldStateSV)[l : l+h]
170+
blocksWindow = smartvectors.IntoRegVec(blocksSV)[l : l+h]
171+
)
172+
return oldStateWindow, blocksWindow
92173
}

0 commit comments

Comments
 (0)