Skip to content

Commit ace3ebd

Browse files
authored
fix: race condition during propagation expansion (#1181)
* minor tweaks * Split out FunctionTrace from Executor This splits out a separate object for capturing a function's trace within the executor. The benefit of this is that we can now lock data within an individual trace. * implement thread-safe FunctionTrace This puts in place a thread-safe implementation of the executor's FunctionTrace object. This allows locking on a function-by-function basis (to reduce contention) and also uses a read-write lock (since we have a cache).
1 parent 2cb6099 commit ace3ebd

File tree

7 files changed

+206
-108
lines changed

7 files changed

+206
-108
lines changed

pkg/asm/compiler/compiler.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -282,17 +282,22 @@ func (p *Compiler[F, T, E, M]) initBuses(caller uint, fn MicroFunction) bit.Set
282282
for _, bus := range fn.Buses() {
283283
// Callee represents the function being called by this Bus.
284284
var (
285-
name = fmt.Sprintf("%s=>%s", fn.Name(), bus.Name)
286-
callerBus = p.buses[caller].ColumnsOf(bus.AddressData()...)
287-
callerLines = make([]E, len(callerBus))
288-
calleeBus = p.buses[bus.BusId].Bus()
289-
calleeLines = make([]E, len(calleeBus))
290-
calleeEnable *E
285+
name = fmt.Sprintf("%s=>%s", fn.Name(), bus.Name)
286+
callerAddress = p.buses[caller].ColumnsOf(bus.Address()...)
287+
callerData = p.buses[caller].ColumnsOf(bus.Data()...)
288+
callerLines = make([]E, len(callerAddress)+len(callerData))
289+
calleeBus = p.buses[bus.BusId].Bus()
290+
calleeLines = make([]E, len(calleeBus))
291+
calleeEnable *E
291292
)
292-
// Initialise caller lines
293-
for i, r := range callerBus {
293+
// Initialise caller address lines
294+
for i, r := range callerAddress {
294295
callerLines[i] = Variable[T, E](r, 0)
295296
}
297+
// Initialise caller data lines
298+
for i, r := range callerData {
299+
callerLines[i+len(callerAddress)] = Variable[T, E](r, 0)
300+
}
296301
// Initialise callee lines
297302
for i, r := range calleeBus {
298303
calleeLines[i] = Variable[T, E](r, 0)
@@ -305,7 +310,11 @@ func (p *Compiler[F, T, E, M]) initBuses(caller uint, fn MicroFunction) bit.Set
305310
// Add lookup constraint
306311
module.NewLookup(name, callerLines, bus.BusId, calleeLines, calleeEnable)
307312
// Mark caller address / data lines as io registers
308-
for _, r := range bus.AddressData() {
313+
for _, r := range bus.Address() {
314+
ioRegisters.Insert(r.Unwrap())
315+
}
316+
// Mark caller data lines as io registers
317+
for _, r := range bus.Data() {
309318
ioRegisters.Insert(r.Unwrap())
310319
}
311320
}

pkg/asm/io/bus.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,6 @@ func (p *Bus) Data() []RegisterId {
6868
return p.DataLines
6969
}
7070

71-
// AddressData returns the "address" and "data" lines for this bus (in that
72-
// order). That is, the registers which hold the various components of the
73-
// address.
74-
func (p *Bus) AddressData() []RegisterId {
75-
return append(p.AddressLines, p.DataLines...)
76-
}
77-
7871
// Split this micro code using registers of arbirary width into one or more
7972
// micro codes using registers of a fixed maximum width.
8073
func (p *Bus) Split(env schema.RegisterAllocator) Bus {

pkg/asm/io/executor.go

Lines changed: 113 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,61 +15,35 @@ package io
1515
import (
1616
"math"
1717
"math/big"
18+
"sync"
1819

1920
"github.com/consensys/go-corset/pkg/schema"
2021
"github.com/consensys/go-corset/pkg/util/collection/set"
2122
)
2223

23-
// FunctionInstance captures the mapping from inputs (i.e. parameters) to outputs (i.e.
24-
// returns) for a particular instance of a given function.
25-
type FunctionInstance struct {
26-
ninputs uint
27-
state []big.Int
28-
}
29-
30-
// Cmp comparator for the I/O registers of a particular function instance.
31-
// Observe that, since functions are always deterministic, this only considers
32-
// the inputs (as the outputs follow directly from this).
33-
func (p FunctionInstance) Cmp(other FunctionInstance) int {
34-
for i := range p.ninputs {
35-
if c := p.state[i].Cmp(&other.state[i]); c != 0 {
36-
return c
37-
}
38-
}
39-
//
40-
return 0
41-
}
42-
43-
// Outputs returns the output values for this function instance.
44-
func (p FunctionInstance) Outputs() []big.Int {
45-
return p.state[p.ninputs:]
46-
}
47-
48-
// Get value of given input or output argument for this instance.
49-
func (p FunctionInstance) Get(arg uint) big.Int {
50-
return p.state[arg]
51-
}
52-
5324
// Executor provides a mechanism for executing a program efficiently and
5425
// generating a suitable top-level trace. Executor implements the io.Map
5526
// interface.
5627
type Executor[T Instruction[T]] struct {
57-
program Program[T]
58-
states []set.AnySortedSet[FunctionInstance]
28+
functions []*FunctionTrace[T]
5929
}
6030

6131
// NewExecutor constructs a new executor.
6232
func NewExecutor[T Instruction[T]](program Program[T]) *Executor[T] {
63-
// Construct initially empty set of states
64-
states := make([]set.AnySortedSet[FunctionInstance], len(program.Functions()))
33+
// Initialise executor traces
34+
traces := make([]*FunctionTrace[T], len(program.Functions()))
35+
//
36+
for i := range traces {
37+
traces[i] = NewFunctionTrace(program.functions[i])
38+
}
6539
// Construct new executor
66-
return &Executor[T]{program, states}
40+
return &Executor[T]{traces}
6741
}
6842

6943
// Instance returns a valid instance of the given bus.
7044
func (p *Executor[T]) Instance(bus uint) FunctionInstance {
7145
var (
72-
fn = p.program.Function(bus)
46+
fn = p.functions[bus].fn
7347
inputs = make([]big.Int, fn.NumInputs())
7448
)
7549
// Intialise inputs values
@@ -82,27 +56,17 @@ func (p *Executor[T]) Instance(bus uint) FunctionInstance {
8256
inputs[i] = *ith.Set(&reg.Padding)
8357
}
8458
// Compute function instance
85-
return p.call(bus, inputs)
59+
return p.functions[bus].Call(inputs, p)
8660
}
8761

8862
// Read implementation for the io.Map interface.
8963
func (p *Executor[T]) Read(bus uint, address []big.Int) []big.Int {
90-
var (
91-
iostate = FunctionInstance{uint(len(address)), address}
92-
states = p.states[bus]
93-
)
94-
// Check whether this instance has already been computed.
95-
if index := states.Find(iostate); index != math.MaxUint {
96-
// Yes, therefore return precomputed outputs
97-
return states[index].Outputs()
98-
}
99-
// Execute function to determine new outputs.
100-
return p.call(bus, address).Outputs()
64+
return p.functions[bus].Call(address, p).Outputs()
10165
}
10266

10367
// Instances returns accrued function instances for the given bus.
10468
func (p *Executor[T]) Instances(bus uint) []FunctionInstance {
105-
return p.states[bus]
69+
return p.functions[bus].instances
10670
}
10771

10872
// Write implementation for the io.Map interface.
@@ -111,15 +75,72 @@ func (p *Executor[T]) Write(bus uint, address []big.Int, values []big.Int) {
11175
panic("unsupported operation")
11276
}
11377

114-
func (p *Executor[T]) call(bus uint, inputs []big.Int) FunctionInstance {
78+
// ============================================================================
79+
// FunctionTrace
80+
// ============================================================================
81+
82+
// FunctionTrace captures all instances for a given function, and provides a
83+
// (thread-safe) API for calling to compute its output for a given set of
84+
// inputs.
85+
type FunctionTrace[T Instruction[T]] struct {
86+
// Function whose instances are captured here
87+
fn *Function[T]
88+
// Cached instances of the given function
89+
instances set.AnySortedSet[FunctionInstance]
90+
// mutex required to ensure thread safety.
91+
mux sync.RWMutex
92+
}
93+
94+
// NewFunctionTrace constructs an empty trace for a given function.
95+
func NewFunctionTrace[T Instruction[T]](fn *Function[T]) *FunctionTrace[T] {
96+
instances := set.NewAnySortedSet[FunctionInstance]()
97+
//
98+
return &FunctionTrace[T]{
99+
fn: fn,
100+
instances: *instances,
101+
}
102+
}
103+
104+
// Call this function to determine its outputs for a given set of inputs. If
105+
// this instance has been seen before, it will simply return that. Otherwise,
106+
// it will execute the function to determine the correct outputs.
107+
func (p *FunctionTrace[T]) Call(inputs []big.Int, iomap Map) FunctionInstance {
108+
var iostate = FunctionInstance{uint(len(inputs)), inputs}
109+
// Obtain read lock
110+
p.mux.RLock()
111+
// Look for cached instance
112+
index := p.instances.Find(iostate)
113+
// Release read lock
114+
p.mux.RUnlock()
115+
// Check for cache hit.
116+
if index != math.MaxUint {
117+
// Yes, therefore return precomputed outputs
118+
return p.instances[index]
119+
}
120+
// Execute function to determine new outputs.
121+
return p.executeCall(inputs, iomap)
122+
}
123+
124+
// Execute this function for a given set of inputs to determine its outputs and
125+
// produce a given instance. The created instance is recorded within the trace
126+
// so it can be reused rather than recomputed in the future. This function is
127+
// thread-safe, and will acquire the write lock on the cached instances
128+
// momentarily to insert the new instance.
129+
//
130+
// NOTE: this does not attempt any form of thread blocking (e.g. when a desired
131+
// instance if being computed by another thread). Instead, it eagerly computes
132+
// instances --- even if that means, occasionally, an instance is computed more
133+
// than once. This is safe since instances are always deterministic (i.e. same
134+
// output for a given input).
135+
func (p *FunctionTrace[T]) executeCall(inputs []big.Int, iomap Map) FunctionInstance {
115136
var (
116-
fn = p.program.Function(bus)
137+
fn = p.fn
117138
// Determine how many I/O registers
118139
nio = fn.NumInputs() + fn.NumOutputs()
119140
//
120141
pc = uint(0)
121142
//
122-
state = InitialState(inputs, fn.Registers(), fn.Buses(), p)
143+
state = InitialState(inputs, fn.Registers(), fn.Buses(), iomap)
123144
)
124145
// Keep executing until we're done.
125146
for pc != RETURN && pc != FAIL {
@@ -131,7 +152,46 @@ func (p *Executor[T]) call(bus uint, inputs []big.Int) FunctionInstance {
131152
}
132153
// Cache I/O instance
133154
instance := FunctionInstance{fn.NumInputs(), state.state[:nio]}
134-
p.states[bus].Insert(instance)
155+
// Obtain write lock
156+
p.mux.Lock()
157+
// Insert new instance
158+
p.instances.Insert(instance)
159+
// Release write lock
160+
p.mux.Unlock()
135161
// Done
136162
return instance
137163
}
164+
165+
// ============================================================================
166+
// FunctionInstance
167+
// ============================================================================
168+
169+
// FunctionInstance captures the mapping from inputs (i.e. parameters) to outputs (i.e.
170+
// returns) for a particular instance of a given function.
171+
type FunctionInstance struct {
172+
ninputs uint
173+
state []big.Int
174+
}
175+
176+
// Cmp comparator for the I/O registers of a particular function instance.
177+
// Observe that, since functions are always deterministic, this only considers
178+
// the inputs (as the outputs follow directly from this).
179+
func (p FunctionInstance) Cmp(other FunctionInstance) int {
180+
for i := range p.ninputs {
181+
if c := p.state[i].Cmp(&other.state[i]); c != 0 {
182+
return c
183+
}
184+
}
185+
//
186+
return 0
187+
}
188+
189+
// Outputs returns the output values for this function instance.
190+
func (p FunctionInstance) Outputs() []big.Int {
191+
return p.state[p.ninputs:]
192+
}
193+
194+
// Get value of given input or output argument for this instance.
195+
func (p FunctionInstance) Get(arg uint) big.Int {
196+
return p.state[arg]
197+
}

pkg/asm/io/program.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,12 @@ func initialState(registers []Register, buses []Bus, iomap Map) State {
9393
}
9494
// Initialie I/O buses
9595
for _, bus := range buses {
96-
// Initialise state from padding
97-
for _, rid := range bus.AddressData() {
96+
// Initialise address lines from padding
97+
for _, rid := range bus.Address() {
98+
state[rid.Unwrap()] = registers[rid.Unwrap()].Padding
99+
}
100+
// Initialise data lines from padding
101+
for _, rid := range bus.Data() {
98102
state[rid.Unwrap()] = registers[rid.Unwrap()].Padding
99103
}
100104
}

pkg/asm/io/state.go

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ type State struct {
3939
pc uint
4040
// Terminate indicates this is a terminating state
4141
terminal bool
42+
// Number of input registers
43+
numInputs uint
44+
// Number of output registers
45+
numOutputs uint
4246
// Values for each register in this state excluding the program counter
4347
// (since this is held above). Thus, this array has one less item than
4448
// registers.
@@ -53,15 +57,39 @@ type State struct {
5357
// EmptyState constructs an initially empty state at the given PC value. One
5458
// can then set register values as needed via Store.
5559
func EmptyState(pc uint, registers []schema.Register, io Map) State {
56-
var state = make([]big.Int, len(registers))
60+
var (
61+
state = make([]big.Int, len(registers))
62+
numInputs uint
63+
numOutputs uint
64+
)
65+
//
66+
for _, r := range registers {
67+
if r.IsInput() {
68+
numInputs++
69+
} else if r.IsOutput() {
70+
numOutputs++
71+
}
72+
}
5773
// Construct state
58-
return State{pc, false, state, registers, io}
74+
return State{pc, false, numInputs, numOutputs, state, registers, io}
5975
}
6076

6177
// NewState constructs a new state instance from the given state values.
6278
func NewState(state []big.Int, registers []schema.Register, io Map) State {
79+
var (
80+
numInputs uint
81+
numOutputs uint
82+
)
83+
//
84+
for _, r := range registers {
85+
if r.IsInput() {
86+
numInputs++
87+
} else if r.IsOutput() {
88+
numOutputs++
89+
}
90+
}
6391
// Construct state
64-
return State{0, false, state, registers, io}
92+
return State{0, false, numInputs, numOutputs, state, registers, io}
6593
}
6694

6795
// InitialState constructs a suitable initial state for executing a given
@@ -72,8 +100,12 @@ func InitialState(inputs []big.Int, registers []schema.Register, buses []Bus, io
72100
copy(state, inputs)
73101
// Initialie I/O buses
74102
for _, bus := range buses {
75-
// Initialise state from padding
76-
for _, rid := range bus.AddressData() {
103+
// Initialise address lines from padding
104+
for _, rid := range bus.Address() {
105+
state[rid.Unwrap()] = registers[rid.Unwrap()].Padding
106+
}
107+
// Initialise data lines from padding
108+
for _, rid := range bus.Data() {
77109
state[rid.Unwrap()] = registers[rid.Unwrap()].Padding
78110
}
79111
}
@@ -86,6 +118,8 @@ func (p *State) Clone() State {
86118
return State{
87119
p.pc,
88120
p.terminal,
121+
p.numInputs,
122+
p.numOutputs,
89123
slices.Clone(p.state),
90124
p.registers,
91125
p.io,
@@ -120,12 +154,10 @@ func (p *State) In(bus Bus) {
120154
// Outputs extracts values from output registers of the given state.
121155
func (p *State) Outputs() []big.Int {
122156
// Construct outputs
123-
outputs := make([]big.Int, 0)
157+
outputs := make([]big.Int, p.numOutputs)
124158
//
125-
for i, reg := range p.registers {
126-
if reg.IsOutput() {
127-
outputs = append(outputs, p.state[i])
128-
}
159+
for i := range p.numOutputs {
160+
outputs[i] = p.state[i+p.numInputs]
129161
}
130162
//
131163
return outputs

0 commit comments

Comments
 (0)