Skip to content

Commit fb1aebd

Browse files
authored
feat: cherry pick defcall into v11 (#1331)
This adds support for parsing defcall declarations, along with an appropriate AST node. The syntax supports source selectors, and updates encoding / decoding as necessary. This adds support for the resolution, processing and typing stages for the DefCall construct. An initial implementation of translation is also included, but this does not yet actually translate the call. This updates the trace propagation algorithm to recognise function call's. These can then trigger calls into assembly functions. One problem is that these are evaluated in the base field, meaning they can overflow (in principle). To address this, a check is used to protect against unexpected overflow during trace propagation of function calls.
1 parent 0d25749 commit fb1aebd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+6977
-8
lines changed

pkg/asm/propagate.go

Lines changed: 160 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ import (
1919

2020
"github.com/consensys/go-corset/pkg/asm/io"
2121
"github.com/consensys/go-corset/pkg/ir"
22+
"github.com/consensys/go-corset/pkg/ir/mir"
2223
"github.com/consensys/go-corset/pkg/schema"
24+
sc "github.com/consensys/go-corset/pkg/schema"
25+
"github.com/consensys/go-corset/pkg/trace"
2326
"github.com/consensys/go-corset/pkg/trace/lt"
2427
"github.com/consensys/go-corset/pkg/util/collection/array"
2528
"github.com/consensys/go-corset/pkg/util/field"
@@ -78,7 +81,7 @@ func Propagate[F field.Element[F], T io.Instruction[T]](p MixedProgram[F, T], tr
7881
// Construct suitable executior for the given program
7982
var (
8083
errors []error
81-
n = len(p.program.Functions())
84+
n = uint(len(p.program.Functions()))
8285
//
8386
executor = io.NewExecutor(p.program)
8487
// Clone heap in trace file, since will mutate this.
@@ -91,7 +94,7 @@ func Propagate[F field.Element[F], T io.Instruction[T]](p MixedProgram[F, T], tr
9194
return lt.TraceFile{}, errors
9295
}
9396
// Write seed instances
94-
errors = writeInstances(p.program, trace.Modules[:n], executor)
97+
errors = writeInstances(p, n, trace.Modules, executor)
9598
// Read out generated instances
9699
modules := readInstances(&heap, p.program, executor)
97100
// Append external modules (which are unaffected by propagation).
@@ -103,15 +106,24 @@ func Propagate[F field.Element[F], T io.Instruction[T]](p MixedProgram[F, T], tr
103106
// WriteInstances writes all of the instances defined in the given trace columns
104107
// into the executor which, in turn, forces it to execute the relevant
105108
// functions, and functions they call, etc.
106-
func writeInstances[T io.Instruction[T]](p io.Program[T], trace []lt.Module[word.BigEndian],
107-
executor *io.Executor[T]) []error {
109+
func writeInstances[F field.Element[F], T io.Instruction[T]](p MixedProgram[F, T], n uint,
110+
trace []lt.Module[word.BigEndian], executor *io.Executor[T]) []error {
108111
//
109112
var errors []error
110113
//
111-
for i, m := range trace {
112-
errs := writeFunctionInstances(uint(i), p, m, executor)
114+
for i, m := range trace[:n] {
115+
errs := writeFunctionInstances(uint(i), p.program, m, executor)
113116
errors = append(errors, errs...)
114117
}
118+
// Write all from non-assembly modules
119+
for i, m := range trace[n:] {
120+
var extern = p.externs[i]
121+
// Write instances from any external calls
122+
for _, call := range extractExternalCalls(extern) {
123+
errs := writeExternCall(call, p.program, m, executor)
124+
errors = append(errors, errs...)
125+
}
126+
}
115127
//
116128
return errors
117129
}
@@ -142,6 +154,68 @@ func writeFunctionInstances[T io.Instruction[T]](fid uint, p io.Program[T], mod
142154
return errors
143155
}
144156

157+
// Extract any external function calls found within the given module, returning
158+
// them as an array.
159+
func extractExternalCalls[F field.Element[F], M sc.Module[F]](extern M) []mir.FunctionCall[F] {
160+
var calls []mir.FunctionCall[F]
161+
//
162+
for iter := extern.Constraints(); iter.HasNext(); {
163+
c := iter.Next()
164+
// This should always hold
165+
if hc, ok := c.(mir.Constraint[F]); ok {
166+
// Check whether its a call or not
167+
if call, ok := hc.Unwrap().(mir.FunctionCall[F]); ok {
168+
// Yes, so record it
169+
calls = append(calls, call)
170+
}
171+
}
172+
}
173+
//
174+
return calls
175+
}
176+
177+
// Write any function instances arising from the given call.
178+
func writeExternCall[F field.Element[F], T io.Instruction[T]](call mir.FunctionCall[F], p io.Program[T], mod RawModule,
179+
executor *io.Executor[T]) []error {
180+
//
181+
var (
182+
trMod = &ltModuleAdaptor[F]{mod}
183+
height = mod.Height()
184+
fn = p.Function(call.Callee)
185+
inputs = make([]big.Int, fn.NumInputs())
186+
outputs = make([]big.Int, fn.NumOutputs())
187+
errors []error
188+
)
189+
//
190+
if call.Selector.HasValue() {
191+
var selector = call.Selector.Unwrap()
192+
// Invoke each user-defined instance in turn
193+
for i := range height {
194+
// execute if selector enabled
195+
if enabled, _, err := selector.TestAt(int(i), trMod, nil); enabled {
196+
// Extract external columns
197+
extractExternColumns(int(i), call, trMod, inputs, outputs)
198+
// Execute function call to produce outputs
199+
errs := executeAndCheck(call.Callee, fn.Name(), inputs, outputs, executor)
200+
errors = append(errors, errs...)
201+
} else if err != nil {
202+
errors = append(errors, err)
203+
}
204+
}
205+
} else {
206+
// Invoke each user-defined instance in turn
207+
for i := range height {
208+
// Extract external columns
209+
extractExternColumns(int(i), call, trMod, inputs, outputs)
210+
// Execute function call to produce outputs
211+
errs := executeAndCheck(call.Callee, fn.Name(), inputs, outputs, executor)
212+
errors = append(errors, errs...)
213+
}
214+
}
215+
//
216+
return errors
217+
}
218+
145219
func executeAndCheck[T io.Instruction[T]](fid uint, name string, inputs, outputs []big.Int,
146220
executor *io.Executor[T]) []error {
147221
var (
@@ -195,6 +269,34 @@ func extractFunctionColumns(row uint, mod RawModule, inputs, outputs []big.Int)
195269
}
196270
}
197271

272+
func extractExternColumns[F field.Element[F]](row int, call mir.FunctionCall[F], mod trace.Module[F],
273+
inputs, outputs []big.Int) []error {
274+
//
275+
// Extract function arguments
276+
errs1 := extractExternTerms(row, call.Arguments, mod, inputs)
277+
// Extract function returns
278+
errs2 := extractExternTerms(row, call.Returns, mod, outputs)
279+
//
280+
return append(errs1, errs2...)
281+
}
282+
283+
func extractExternTerms[F field.Element[F]](row int, terms []mir.Term[F], mod trace.Module[F], values []big.Int,
284+
) []error {
285+
var errors []error
286+
//
287+
for i, arg := range terms {
288+
var (
289+
ith big.Int
290+
val, err = arg.EvalAt(row, mod, nil)
291+
)
292+
ith.SetBytes(val.Bytes())
293+
values[i] = ith
294+
//
295+
errors = append(errors, err)
296+
}
297+
//
298+
return errors
299+
}
198300
func extractFunctionPadding(registers []schema.Register, inputs, outputs []big.Int) {
199301
var numInputs = len(inputs)
200302
//
@@ -279,3 +381,55 @@ func toArgumentString(args []big.Int) string {
279381
//
280382
return builder.String()
281383
}
384+
385+
// The purpose of the lt adaptor is to make an lt.TraceFile look like a Trace.
386+
// In general, this is not safe. However, we use this once we already know that
387+
// the trace has been aligned. Also, it is only used in a specific context.
388+
type ltModuleAdaptor[F field.Element[F]] struct {
389+
module lt.Module[word.BigEndian]
390+
}
391+
392+
func (p *ltModuleAdaptor[F]) Name() string {
393+
return p.module.Name
394+
}
395+
396+
func (p *ltModuleAdaptor[F]) Width() uint {
397+
return uint(len(p.module.Columns))
398+
}
399+
400+
func (p *ltModuleAdaptor[F]) Height() uint {
401+
return p.module.Height()
402+
}
403+
404+
func (p *ltModuleAdaptor[F]) Column(cid uint) trace.Column[F] {
405+
return &ltColumnAdaptor[F]{p.module.Columns[cid]}
406+
}
407+
408+
func (p *ltModuleAdaptor[F]) ColumnOf(col string) trace.Column[F] {
409+
panic("unsupported operation")
410+
}
411+
412+
type ltColumnAdaptor[F field.Element[F]] struct {
413+
column lt.Column[word.BigEndian]
414+
}
415+
416+
func (p *ltColumnAdaptor[F]) Name() string {
417+
return p.column.Name
418+
}
419+
420+
func (p *ltColumnAdaptor[F]) Get(row int) F {
421+
var (
422+
v = p.column.Data.Get(uint(row))
423+
w F
424+
)
425+
// Convert
426+
return w.SetBytes(v.Bytes())
427+
}
428+
429+
func (p *ltColumnAdaptor[F]) Data() array.Array[F] {
430+
panic("unsupported operation")
431+
}
432+
433+
func (p *ltColumnAdaptor[F]) Padding() F {
434+
panic("unsupported operation")
435+
}

pkg/corset/ast/declaration.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,97 @@ func (p *DefAlias) Lisp() sexp.SExp {
147147
return sexp.NewSymbol(p.Name)
148148
}
149149

150+
// ============================================================================
151+
// defcall
152+
// ============================================================================
153+
154+
// DefCall captures a function call between a lisp module and an assembly
155+
// function. A key feature of this is that it triggers trace propagation.
156+
type DefCall struct {
157+
// Returns for the call
158+
Returns []Expr
159+
// Function being called
160+
Function string
161+
// Arguments for the call
162+
Arguments []Expr
163+
// Optional source selector
164+
Selector util.Option[Expr]
165+
// determines whether or not this has been finalised.
166+
finalised bool
167+
}
168+
169+
// NewDefCall creates a new (unfinalised) function call.
170+
func NewDefCall(returns []Expr, fun string, args []Expr, selector util.Option[Expr]) *DefCall {
171+
//
172+
return &DefCall{returns, fun, args, selector, false}
173+
}
174+
175+
// Definitions returns the set of symbols defined by this declaration. Observe
176+
// that these may not yet have been finalised.
177+
func (p *DefCall) Definitions() iter.Iterator[SymbolDefinition] {
178+
return iter.NewArrayIterator[SymbolDefinition](nil)
179+
}
180+
181+
// Dependencies needed to signal declaration.
182+
func (p *DefCall) Dependencies() iter.Iterator[Symbol] {
183+
var deps []Symbol
184+
//
185+
deps = append(deps, DependenciesOfExpressions(p.Arguments)...)
186+
deps = append(deps, DependenciesOfExpressions(p.Returns)...)
187+
// Include selector dependencies (if applicable)
188+
if p.Selector.HasValue() {
189+
deps = append(deps, p.Selector.Unwrap().Dependencies()...)
190+
}
191+
// Combine deps
192+
return iter.NewArrayIterator(deps)
193+
}
194+
195+
// Defines checks whether this declaration defines the given symbol. The symbol
196+
// in question needs to have been resolved already for this to make sense.
197+
func (p *DefCall) Defines(symbol Symbol) bool {
198+
return false
199+
}
200+
201+
// IsFinalised checks whether this declaration has already been finalised. If
202+
// so, then we don't need to finalise it again.
203+
func (p *DefCall) IsFinalised() bool {
204+
return p.finalised
205+
}
206+
207+
// Finalise this declaration, which means that all source and target expressions
208+
// have been resolved.
209+
func (p *DefCall) Finalise() {
210+
p.finalised = true
211+
}
212+
213+
// Lisp converts this node into its lisp representation. This is primarily used
214+
// for debugging purposes.
215+
func (p *DefCall) Lisp() sexp.SExp {
216+
returns := make([]sexp.SExp, len(p.Returns))
217+
args := make([]sexp.SExp, len(p.Arguments))
218+
// Returns
219+
for i, t := range p.Returns {
220+
returns[i] = t.Lisp()
221+
}
222+
// Arguments
223+
for i, t := range p.Arguments {
224+
args[i] = t.Lisp()
225+
}
226+
//
227+
list := sexp.NewList([]sexp.SExp{
228+
sexp.NewSymbol("defcall"),
229+
sexp.NewList(returns),
230+
sexp.NewSymbol(p.Function),
231+
sexp.NewList(args),
232+
})
233+
// Include selector (if applicable)
234+
if p.Selector.HasValue() {
235+
list.Append(p.Selector.Unwrap().Lisp())
236+
}
237+
//
238+
return list
239+
}
240+
150241
// ============================================================================
151242
// defcolumns
152243
// ============================================================================

pkg/corset/compiler/parser.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ func (p *Parser) parseDeclaration(module file.Path, s *sexp.List) (ast.Declarati
287287
//
288288
if s.MatchSymbols(1, "defalias") {
289289
decl, errors = p.parseDefAlias(s.Elements)
290+
} else if s.Len() == 4 && s.MatchSymbols(1, "defcall") {
291+
decl, errors = p.parseDefCall(false, s.Elements)
292+
} else if s.Len() == 5 && s.MatchSymbols(1, "defcall") {
293+
decl, errors = p.parseDefCall(true, s.Elements)
290294
} else if s.MatchSymbols(1, "defcolumns") {
291295
decl, errors = p.parseDefColumns(module, s)
292296
} else if s.Len() == 3 && s.MatchSymbols(1, "defcomputed") {
@@ -875,6 +879,37 @@ func (p *Parser) parseDefInterleavedSourceArray(source *sexp.Array) (ast.TypedSy
875879
return nil, errors
876880
}
877881

882+
func (p *Parser) parseDefCall(hasSelector bool, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
883+
var (
884+
errors []SyntaxError
885+
returns, retErrors = p.parseDefLookupSources("return", elements[1])
886+
args, argErrors = p.parseDefLookupSources("argument", elements[3])
887+
selector = util.None[ast.Expr]()
888+
)
889+
// Sanity check function name
890+
if !isIdentifier(elements[2]) {
891+
return nil, p.translator.SyntaxErrors(elements[2], "malformed function name")
892+
}
893+
// Extract function name
894+
fun := elements[2].AsSymbol().Value
895+
// Combine any and all errors
896+
errors = append(errors, argErrors...)
897+
errors = append(errors, retErrors...)
898+
// Parse selector (if applicable)
899+
if hasSelector {
900+
sel, errs := p.translator.Translate(elements[4])
901+
selector = util.Some(sel)
902+
//
903+
errors = append(errors, errs...)
904+
}
905+
// Error check
906+
if len(errors) != 0 {
907+
return nil, errors
908+
}
909+
//
910+
return ast.NewDefCall(returns, fun, args, selector), nil
911+
}
912+
878913
// Parse a lookup declaration
879914
func (p *Parser) parseDefLookup(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
880915
// Extract items

0 commit comments

Comments
 (0)