From c4c396a7300651c2e31249fe9d7a31b11e68ea81 Mon Sep 17 00:00:00 2001 From: ascandone Date: Tue, 24 Sep 2024 12:19:19 +0200 Subject: [PATCH] moved main file and updated goreleaser yaml WIP more test for get balances changed public API removed duplicate struct added ctx param to interpreter run re-export function export more types removed some todos added more tests for balance moved main file run go mod tidy updated gorelease config handle big ints --- .goreleaser.yaml | 1 + go.mod | 2 +- internal/cmd/run.go | 23 +- internal/interpreter/batch_balances_query.go | 142 +++++++ internal/interpreter/evaluate_lit.go | 74 ++++ internal/interpreter/interpreter.go | 373 ++++++++---------- internal/interpreter/interpreter_error.go | 18 + .../interpreter/interpreter_errors_test.go | 25 +- internal/interpreter/interpreter_test.go | 78 ++-- internal/interpreter/utils.go | 11 + main.go => internal/numscript/numscript.go | 0 .../parser/__snapshots__/parser_test.snap | 90 ++++- internal/parser/ast.go | 6 +- internal/parser/parser.go | 36 +- numscript.go | 60 +++ numscript_test.go | 279 +++++++++++++ 16 files changed, 925 insertions(+), 293 deletions(-) create mode 100644 internal/interpreter/batch_balances_query.go create mode 100644 internal/interpreter/evaluate_lit.go create mode 100644 internal/interpreter/utils.go rename main.go => internal/numscript/numscript.go (100%) create mode 100644 numscript.go create mode 100644 numscript_test.go diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 984b46f..c5dc479 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -16,6 +16,7 @@ before: builds: - env: - CGO_ENABLED=0 + main: ./internal/numscript/numscript.go goos: - linux - windows diff --git a/go.mod b/go.mod index c55dd35..e474b60 100644 --- a/go.mod +++ b/go.mod @@ -32,5 +32,5 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect - golang.org/x/exp v0.0.0-20240707233637-46b078467d37 // indirect + golang.org/x/exp v0.0.0-20240707233637-46b078467d37 ) diff --git a/internal/cmd/run.go b/internal/cmd/run.go index defe8b8..467afcd 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "encoding/json" "fmt" "io" @@ -26,10 +27,10 @@ var runStdinFlag bool var runOutFormatOpt string type inputOpts struct { - Script string `json:"script"` - Variables map[string]string `json:"variables"` - Meta map[string]interpreter.Metadata `json:"metadata"` - Balances interpreter.StaticStore `json:"balances"` + Script string `json:"script"` + Variables map[string]string `json:"variables"` + Meta interpreter.AccountsMetadata `json:"metadata"` + Balances interpreter.Balances `json:"balances"` } func (o *inputOpts) fromRaw() { @@ -100,8 +101,8 @@ func (o *inputOpts) fromOptions(path string) { func run(path string) { opt := inputOpts{ Variables: make(map[string]string), - Meta: make(map[string]interpreter.Metadata), - Balances: make(interpreter.StaticStore), + Meta: make(interpreter.AccountsMetadata), + Balances: make(interpreter.Balances), } opt.fromRaw() @@ -114,11 +115,11 @@ func run(path string) { os.Exit(1) } - result, err := interpreter.RunProgram(parseResult.Value, interpreter.RunProgramOptions{ - Vars: opt.Variables, - Store: opt.Balances, - Meta: opt.Meta, + result, err := interpreter.RunProgram(context.Background(), parseResult.Value, opt.Variables, interpreter.StaticStore{ + Balances: opt.Balances, + Meta: opt.Meta, }) + if err != nil { rng := err.GetRange() os.Stderr.Write([]byte(err.Error())) @@ -163,7 +164,7 @@ func showPretty(result *interpreter.ExecutionResult) { fmt.Println() fmt.Println(ansi.ColorCyan("Meta:")) - txMetaJson, err := json.MarshalIndent(result.TxMeta, "", " ") + txMetaJson, err := json.MarshalIndent(result.Metadata, "", " ") if err != nil { panic(err) } diff --git a/internal/interpreter/batch_balances_query.go b/internal/interpreter/batch_balances_query.go new file mode 100644 index 0000000..2a07ed9 --- /dev/null +++ b/internal/interpreter/batch_balances_query.go @@ -0,0 +1,142 @@ +package interpreter + +import ( + "slices" + + "github.com/formancehq/numscript/internal/parser" + "github.com/formancehq/numscript/internal/utils" + "golang.org/x/exp/maps" +) + +// traverse the script to batch in advance required balance queries + +func (st *programState) findBalancesQueriesInStatement(statement parser.Statement) InterpreterError { + switch statement := statement.(type) { + case *parser.FnCall: + return nil + + case *parser.SendStatement: + // set the current asset + switch sentValue := statement.SentValue.(type) { + case *parser.SentValueAll: + asset, err := evaluateLitExpecting(st, sentValue.Asset, expectAsset) + if err != nil { + return err + } + st.CurrentAsset = *asset + + case *parser.SentValueLiteral: + monetary, err := evaluateLitExpecting(st, sentValue.Monetary, expectMonetary) + if err != nil { + return err + } + st.CurrentAsset = string(monetary.Asset) + + default: + utils.NonExhaustiveMatchPanic[any](sentValue) + } + + // traverse source + return st.findBalancesQueries(statement.Source) + + default: + utils.NonExhaustiveMatchPanic[any](statement) + return nil + } +} + +func (st *programState) batchQuery(account string, asset string) { + if account == "world" { + return + } + + previousValues := st.CurrentBalanceQuery[account] + if !slices.Contains[[]string, string](previousValues, account) { + st.CurrentBalanceQuery[account] = append(previousValues, asset) + } +} + +func (st *programState) runBalancesQuery() error { + filteredQuery := BalanceQuery{} + for accountName, queriedCurrencies := range st.CurrentBalanceQuery { + + cachedCurrenciesForAccount := defaultMapGet(st.CachedBalances, accountName, func() AccountBalance { + return AccountBalance{} + }) + + for _, queriedCurrency := range queriedCurrencies { + isAlreadyCached := slices.Contains(maps.Keys(cachedCurrenciesForAccount), queriedCurrency) + if !isAlreadyCached { + filteredQuery[accountName] = queriedCurrencies + } + } + + } + + // avoid updating balances if we don't need to fetch new data + if len(filteredQuery) == 0 { + return nil + } + + balances, err := st.Store.GetBalances(st.ctx, filteredQuery) + if err != nil { + return err + } + // reset batch query + st.CurrentBalanceQuery = BalanceQuery{} + + st.CachedBalances = balances + return nil +} + +func (st *programState) findBalancesQueries(source parser.Source) InterpreterError { + switch source := source.(type) { + case *parser.SourceAccount: + account, err := evaluateLitExpecting(st, source.Literal, expectAccount) + if err != nil { + return err + } + + st.batchQuery(*account, st.CurrentAsset) + return nil + + case *parser.SourceOverdraft: + // Skip balance tracking when balance is overdraft + if source.Bounded == nil { + return nil + } + + account, err := evaluateLitExpecting(st, source.Address, expectAccount) + if err != nil { + return err + } + st.batchQuery(*account, st.CurrentAsset) + return nil + + case *parser.SourceInorder: + for _, subSource := range source.Sources { + err := st.findBalancesQueries(subSource) + if err != nil { + return err + } + } + return nil + + case *parser.SourceCapped: + // TODO can this be optimized in some cases? + return st.findBalancesQueries(source.From) + + case *parser.SourceAllotment: + for _, item := range source.Items { + err := st.findBalancesQueries(item.From) + if err != nil { + return err + } + } + return nil + + default: + utils.NonExhaustiveMatchPanic[error](source) + return nil + } +} diff --git a/internal/interpreter/evaluate_lit.go b/internal/interpreter/evaluate_lit.go new file mode 100644 index 0000000..2a9a723 --- /dev/null +++ b/internal/interpreter/evaluate_lit.go @@ -0,0 +1,74 @@ +package interpreter + +import ( + "math/big" + + "github.com/formancehq/numscript/internal/parser" + "github.com/formancehq/numscript/internal/utils" +) + +func (st *programState) evaluateLit(literal parser.Literal) (Value, InterpreterError) { + switch literal := literal.(type) { + case *parser.AssetLiteral: + return Asset(literal.Asset), nil + case *parser.AccountLiteral: + return AccountAddress(literal.Name), nil + case *parser.StringLiteral: + return String(literal.String), nil + case *parser.RatioLiteral: + return Portion(*literal.ToRatio()), nil + case *parser.NumberLiteral: + return MonetaryInt(*big.NewInt(int64(literal.Number))), nil + case *parser.MonetaryLiteral: + asset, err := evaluateLitExpecting(st, literal.Asset, expectAsset) + if err != nil { + return nil, err + } + + amount, err := evaluateLitExpecting(st, literal.Amount, expectNumber) + if err != nil { + return nil, err + } + + return Monetary{Asset: Asset(*asset), Amount: MonetaryInt(*amount)}, nil + + case *parser.VariableLiteral: + value, ok := st.ParsedVars[literal.Name] + if !ok { + return nil, UnboundVariableErr{ + Name: literal.Name, + Range: literal.Range, + } + } + return value, nil + default: + utils.NonExhaustiveMatchPanic[any](literal) + return nil, nil + } +} + +func evaluateLitExpecting[T any](st *programState, literal parser.Literal, expect func(Value, parser.Range) (*T, InterpreterError)) (*T, InterpreterError) { + value, err := st.evaluateLit(literal) + if err != nil { + return nil, err + } + + res, err := expect(value, literal.GetRange()) + if err != nil { + return nil, err + } + + return res, nil +} + +func (st *programState) evaluateLiterals(literals []parser.Literal) ([]Value, InterpreterError) { + var values []Value + for _, argLit := range literals { + value, err := st.evaluateLit(argLit) + if err != nil { + return nil, err + } + values = append(values, value) + } + return values, nil +} diff --git a/internal/interpreter/interpreter.go b/internal/interpreter/interpreter.go index 1d86ee0..39d60ae 100644 --- a/internal/interpreter/interpreter.go +++ b/internal/interpreter/interpreter.go @@ -1,6 +1,7 @@ package interpreter import ( + "context" "math/big" "strconv" "strings" @@ -10,26 +11,64 @@ import ( "github.com/formancehq/numscript/internal/utils" ) -type StaticStore map[string]map[string]*big.Int -type Metadata map[string]string +type VariablesMap map[string]string + +// For each account, list of the needed assets +type BalanceQuery map[string][]string + +// For each account, list of the needed keys +type MetadataQuery map[string][]string + +type AccountBalance = map[string]*big.Int +type Balances map[string]AccountBalance + +type AccountMetadata = map[string]string +type AccountsMetadata map[string]AccountMetadata + +type Store interface { + GetBalances(context.Context, BalanceQuery) (Balances, error) + GetAccountsMetadata(context.Context, MetadataQuery) (AccountsMetadata, error) +} + +type StaticStore struct { + Balances Balances + Meta AccountsMetadata +} + +func (s StaticStore) GetBalances(context.Context, BalanceQuery) (Balances, error) { + if s.Balances == nil { + s.Balances = Balances{} + } + return s.Balances, nil +} +func (s StaticStore) GetAccountsMetadata(context.Context, MetadataQuery) (AccountsMetadata, error) { + if s.Meta == nil { + s.Meta = AccountsMetadata{} + } + return s.Meta, nil +} type InterpreterError interface { error parser.Ranged } +type Metadata = map[string]Value + type ExecutionResult struct { - Postings []Posting `json:"postings"` - TxMeta map[string]Value `json:"txMeta"` - AccountsMeta map[string]Metadata `json:"accountsMeta"` + Postings []Posting `json:"postings"` + + Metadata Metadata `json:"txMeta"` + + AccountsMetadata AccountsMetadata `json:"accountsMeta"` } -func parsePercentage(p string) big.Rat { +func parsePercentage(p string) *big.Rat { num, den, err := parser.ParsePercentageRatio(p) if err != nil { panic(err) } - return *big.NewRat(int64(num), int64(den)) + return new(big.Rat).SetFrac(num, den) } func parseMonetary(source string) (Monetary, InterpreterError) { @@ -61,7 +100,7 @@ func parseVar(type_ string, rawValue string, r parser.Range) (Value, Interpreter case analysis.TypeAccount: return AccountAddress(rawValue), nil case analysis.TypePortion: - return Portion(parsePercentage(rawValue)), nil + return Portion(*parsePercentage(rawValue)), nil case analysis.TypeAsset: return Asset(rawValue), nil case analysis.TypeNumber: @@ -124,50 +163,54 @@ func (s *programState) parseVars(varDeclrs []parser.VarDeclaration, rawVars map[ if err != nil { return err } - s.Vars[varsDecl.Name.Name] = parsed + s.ParsedVars[varsDecl.Name.Name] = parsed } else { value, err := s.handleOrigin(varsDecl.Type.Name, *varsDecl.Origin) if err != nil { return err } - s.Vars[varsDecl.Name.Name] = value + s.ParsedVars[varsDecl.Name.Name] = value } } return nil } -type RunProgramOptions struct { - Vars map[string]string - Store StaticStore - Meta map[string]Metadata -} - func RunProgram( + ctx context.Context, program parser.Program, - options RunProgramOptions, + vars map[string]string, + store Store, ) (*ExecutionResult, InterpreterError) { - if options.Vars == nil { - options.Vars = make(map[string]string) - } - if options.Store == nil { - options.Store = make(StaticStore) - } - if options.Meta == nil { - options.Meta = make(map[string]Metadata) - } - st := programState{ - Vars: make(map[string]Value), - TxMeta: make(map[string]Value), - Store: options.Store, - Meta: options.Meta, + ParsedVars: make(map[string]Value), + TxMeta: make(map[string]Value), + CachedAccountsMeta: AccountsMetadata{}, + CachedBalances: Balances{}, + SetAccountsMeta: AccountsMetadata{}, + Store: store, + + CurrentBalanceQuery: BalanceQuery{}, + ctx: ctx, } - err := st.parseVars(program.Vars, options.Vars) + err := st.parseVars(program.Vars, vars) if err != nil { return nil, err } + // preload balances before executing the script + for _, statement := range program.Statements { + err := st.findBalancesQueriesInStatement(statement) + if err != nil { + return nil, err + } + } + + genericErr := st.runBalancesQuery() + if genericErr != nil { + return nil, QueryBalanceError{WrappedError: genericErr} + } + postings := make([]Posting, 0) for _, statement := range program.Statements { statementPostings, err := st.runStatement(statement) @@ -178,91 +221,34 @@ func RunProgram( } res := &ExecutionResult{ - Postings: postings, - TxMeta: st.TxMeta, - AccountsMeta: st.Meta, // TODO clone the map + Postings: postings, + Metadata: st.TxMeta, + AccountsMetadata: st.SetAccountsMeta, } return res, nil } type programState struct { + ctx context.Context + // Asset of the send statement currently being executed. // // it's value is undefined outside of send statements execution CurrentAsset string - Vars map[string]Value - TxMeta map[string]Value - Store StaticStore - Senders []Sender - Receivers []Receiver - Meta map[string]Metadata -} - -func (st *programState) evaluateLit(literal parser.Literal) (Value, InterpreterError) { - switch literal := literal.(type) { - case *parser.AssetLiteral: - return Asset(literal.Asset), nil - case *parser.AccountLiteral: - return AccountAddress(literal.Name), nil - case *parser.StringLiteral: - return String(literal.String), nil - case *parser.RatioLiteral: - return Portion(*literal.ToRatio()), nil - case *parser.NumberLiteral: - return MonetaryInt(*big.NewInt(int64(literal.Number))), nil - case *parser.MonetaryLiteral: - asset, err := evaluateLitExpecting(st, literal.Asset, expectAsset) - if err != nil { - return nil, err - } - - amount, err := evaluateLitExpecting(st, literal.Amount, expectNumber) - if err != nil { - return nil, err - } - - return Monetary{Asset: Asset(*asset), Amount: MonetaryInt(*amount)}, nil + ParsedVars map[string]Value + TxMeta map[string]Value + Senders []Sender + Receivers []Receiver - case *parser.VariableLiteral: - value, ok := st.Vars[literal.Name] - if !ok { - return nil, UnboundVariableErr{ - Name: literal.Name, - Range: literal.Range, - } - } - return value, nil - default: - utils.NonExhaustiveMatchPanic[any](literal) - return nil, nil - } -} + Store Store -func evaluateLitExpecting[T any](st *programState, literal parser.Literal, expect func(Value, parser.Range) (*T, InterpreterError)) (*T, InterpreterError) { - value, err := st.evaluateLit(literal) - if err != nil { - return nil, err - } + SetAccountsMeta AccountsMetadata - res, err := expect(value, literal.GetRange()) - if err != nil { - return nil, err - } + CachedAccountsMeta AccountsMetadata + CachedBalances Balances - return res, nil -} - -func (st *programState) evaluateLiterals(literals []parser.Literal) ([]Value, InterpreterError) { - var values []Value - for _, argLit := range literals { - value, err := st.evaluateLit(argLit) - if err != nil { - return nil, err - } - values = append(values, value) - } - return values, nil + CurrentBalanceQuery BalanceQuery } func (st *programState) runStatement(statement parser.Statement) ([]Posting, InterpreterError) { @@ -307,10 +293,10 @@ func (st *programState) getPostings() ([]Posting, InterpreterError) { } for _, posting := range postings { - srcBalance := st.getBalance(posting.Source, posting.Asset) + srcBalance := st.getCachedBalance(posting.Source, posting.Asset) srcBalance.Sub(srcBalance, posting.Amount) - destBalance := st.getBalance(posting.Destination, posting.Asset) + destBalance := st.getCachedBalance(posting.Destination, posting.Asset) destBalance.Add(destBalance, posting.Amount) } return postings, nil @@ -341,7 +327,7 @@ func (st *programState) runSendStatement(statement parser.SendStatement) ([]Post } st.CurrentAsset = string(monetary.Asset) - monetaryAmt := big.Int(monetary.Amount) + monetaryAmt := (*big.Int)(&monetary.Amount) if monetaryAmt.Cmp(big.NewInt(0)) == -1 { return nil, NegativeAmountErr{Amount: monetary.Amount} } @@ -366,20 +352,13 @@ func (st *programState) runSendStatement(statement parser.SendStatement) ([]Post } -func (s *programState) getBalance(account string, asset string) *big.Int { - balance, ok := s.Store[account] - if !ok { - m := make(map[string]*big.Int) - s.Store[account] = m - balance = m - } - - assetBalance, ok := balance[asset] - if !ok { - zero := big.NewInt(0) - balance[asset] = zero - assetBalance = zero - } +func (s *programState) getCachedBalance(account string, asset string) *big.Int { + balance := defaultMapGet(s.CachedBalances, account, func() AccountBalance { + return AccountBalance{} + }) + assetBalance := defaultMapGet(balance, asset, func() *big.Int { + return big.NewInt(0) + }) return assetBalance } @@ -395,17 +374,15 @@ func (s *programState) sendAllToAccount(accountLiteral parser.Literal, ovedraft } } - balance := s.getBalance(*account, s.CurrentAsset) + balance := s.getCachedBalance(*account, s.CurrentAsset) // we sent balance+overdraft - var sentAmt big.Int - sentAmt.Add(balance, ovedraft) - + sentAmt := new(big.Int).Add(balance, ovedraft) s.Senders = append(s.Senders, Sender{ Name: *account, - Monetary: &sentAmt, + Monetary: sentAmt, }) - return &sentAmt, nil + return sentAmt, nil } // Send as much as possible (and return the sent amt) @@ -443,7 +420,7 @@ func (s *programState) sendAll(source parser.Source) (*big.Int, InterpreterError } // We switch to the default sending evaluation for this subsource - return s.trySendingUpTo(source.From, *monetary) + return s.trySendingUpTo(source.From, monetary) case *parser.SourceAllotment: return nil, InvalidAllotmentInSendAll{} @@ -455,15 +432,15 @@ func (s *programState) sendAll(source parser.Source) (*big.Int, InterpreterError } // Fails if it doesn't manage to send exactly "amount" -func (s *programState) trySendingExact(source parser.Source, amount big.Int) InterpreterError { +func (s *programState) trySendingExact(source parser.Source, amount *big.Int) InterpreterError { sentAmt, err := s.trySendingUpTo(source, amount) if err != nil { return err } - if sentAmt.Cmp(&amount) != 0 { + if sentAmt.Cmp(amount) != 0 { return MissingFundsErr{ Asset: s.CurrentAsset, - Needed: amount, + Needed: *amount, Available: *sentAmt, Range: source.GetRange(), } @@ -471,7 +448,7 @@ func (s *programState) trySendingExact(source parser.Source, amount big.Int) Int return nil } -func (s *programState) trySendingToAccount(accountLiteral parser.Literal, amount big.Int, overdraft *big.Int) (*big.Int, InterpreterError) { +func (s *programState) trySendingToAccount(accountLiteral parser.Literal, amount *big.Int, overdraft *big.Int) (*big.Int, InterpreterError) { account, err := evaluateLitExpecting(s, accountLiteral, expectAccount) if err != nil { return nil, err @@ -480,30 +457,28 @@ func (s *programState) trySendingToAccount(accountLiteral parser.Literal, amount overdraft = nil } - var actuallySentAmt big.Int + var actuallySentAmt *big.Int if overdraft == nil { // unbounded overdraft: we send the required amount - actuallySentAmt.Set(&amount) + actuallySentAmt = new(big.Int).Set(amount) } else { - balance := s.getBalance(*account, s.CurrentAsset) + balance := s.getCachedBalance(*account, s.CurrentAsset) // that's the amount we are allowed to send (balance + overdraft) - var safeSendAmt big.Int - safeSendAmt.Add(balance, overdraft) - - actuallySentAmt = *utils.MinBigInt(&safeSendAmt, &amount) + safeSendAmt := new(big.Int).Add(balance, overdraft) + actuallySentAmt = utils.MinBigInt(safeSendAmt, amount) } s.Senders = append(s.Senders, Sender{ Name: *account, - Monetary: &actuallySentAmt, + Monetary: actuallySentAmt, }) - return &actuallySentAmt, nil + return actuallySentAmt, nil } // Tries sending "amount" and returns the actually sent amt. // Doesn't fail (unless nested sources fail) -func (s *programState) trySendingUpTo(source parser.Source, amount big.Int) (*big.Int, InterpreterError) { +func (s *programState) trySendingUpTo(source parser.Source, amount *big.Int) (*big.Int, InterpreterError) { switch source := source.(type) { case *parser.SourceAccount: return s.trySendingToAccount(source.Literal, amount, big.NewInt(0)) @@ -520,44 +495,40 @@ func (s *programState) trySendingUpTo(source parser.Source, amount big.Int) (*bi return s.trySendingToAccount(source.Address, amount, cap) case *parser.SourceInorder: - var totalLeft big.Int - totalLeft.Set(&amount) + totalLeft := new(big.Int).Set(amount) for _, source := range source.Sources { sentAmt, err := s.trySendingUpTo(source, totalLeft) if err != nil { return nil, err } - totalLeft.Sub(&totalLeft, sentAmt) + totalLeft.Sub(totalLeft, sentAmt) } - - var sentAmt big.Int - sentAmt.Sub(&amount, &totalLeft) - return &sentAmt, nil + return new(big.Int).Sub(amount, totalLeft), nil case *parser.SourceAllotment: var items []parser.AllotmentValue for _, i := range source.Items { items = append(items, i.Allotment) } - allot, err := s.makeAllotment(amount.Int64(), items) + allot, err := s.makeAllotment(amount, items) if err != nil { return nil, err } for i, allotmentItem := range source.Items { - err := s.trySendingExact(allotmentItem.From, *big.NewInt(allot[i])) + err := s.trySendingExact(allotmentItem.From, allot[i]) if err != nil { return nil, err } } - return &amount, nil + return amount, nil case *parser.SourceCapped: cap, err := evaluateLitExpecting(s, source.Cap, expectMonetaryOfAsset(s.CurrentAsset)) if err != nil { return nil, err } - cappedAmount := utils.MinBigInt(&amount, cap) - return s.trySendingUpTo(source.From, *cappedAmount) + cappedAmount := utils.MinBigInt(amount, cap) + return s.trySendingUpTo(source.From, cappedAmount) default: utils.NonExhaustiveMatchPanic[any](source) @@ -586,14 +557,14 @@ func (s *programState) receiveFrom(destination parser.Destination, amount *big.I items = append(items, i.Allotment) } - allot, err := s.makeAllotment(amount.Int64(), items) + allot, err := s.makeAllotment(amount, items) if err != nil { return err } receivedTotal := big.NewInt(0) for i, allotmentItem := range destination.Items { - amtToReceive := big.NewInt(allot[i]) + amtToReceive := allot[i] err := s.receiveFromKeptOrDest(allotmentItem.To, amtToReceive) if err != nil { return err @@ -604,15 +575,14 @@ func (s *programState) receiveFrom(destination parser.Destination, amount *big.I return nil case *parser.DestinationInorder: - var remainingAmount big.Int - remainingAmount.Set(amount) + remainingAmount := new(big.Int).Set(amount) - handler := func(keptOrDest parser.KeptOrDestination, amountToReceive big.Int) InterpreterError { - err := s.receiveFromKeptOrDest(keptOrDest, &amountToReceive) + handler := func(keptOrDest parser.KeptOrDestination, amountToReceive *big.Int) InterpreterError { + err := s.receiveFromKeptOrDest(keptOrDest, amountToReceive) if err != nil { return err } - remainingAmount.Sub(&remainingAmount, &amountToReceive) + remainingAmount.Sub(remainingAmount, amountToReceive) return err } @@ -628,16 +598,16 @@ func (s *programState) receiveFrom(destination parser.Destination, amount *big.I break } - err = handler(destinationClause.To, *utils.MinBigInt(cap, &remainingAmount)) + err = handler(destinationClause.To, utils.MinBigInt(cap, remainingAmount)) if err != nil { return err } } - var cp big.Int // if remainingAmount bad things with pointers happen.. somehow - cp.Set(&remainingAmount) - return handler(destination.Remaining, cp) + remainingAmountCopy := new(big.Int).Set(remainingAmount) + // passing "remainingAmount" directly breaks the code + return handler(destination.Remaining, remainingAmountCopy) default: utils.NonExhaustiveMatchPanic[any](destination) @@ -664,19 +634,18 @@ func (s *programState) receiveFromKeptOrDest(keptOrDest parser.KeptOrDestination } -func (s *programState) makeAllotment(monetary int64, items []parser.AllotmentValue) ([]int64, InterpreterError) { - // TODO runtime error when totalAllotment != 1? +func (s *programState) makeAllotment(monetary *big.Int, items []parser.AllotmentValue) ([]*big.Int, InterpreterError) { totalAllotment := big.NewRat(0, 1) - var allotments []big.Rat + var allotments []*big.Rat remainingAllotmentIndex := -1 for i, item := range items { switch allotment := item.(type) { case *parser.RatioLiteral: - rat := big.NewRat(int64(allotment.Numerator), int64(allotment.Denominator)) + rat := allotment.ToRatio() totalAllotment.Add(totalAllotment, rat) - allotments = append(allotments, *rat) + allotments = append(allotments, rat) case *parser.VariableLiteral: rat, err := evaluateLitExpecting(s, allotment, expectPortion) if err != nil { @@ -684,45 +653,44 @@ func (s *programState) makeAllotment(monetary int64, items []parser.AllotmentVal } totalAllotment.Add(totalAllotment, rat) - allotments = append(allotments, *rat) + allotments = append(allotments, rat) case *parser.RemainingAllotment: remainingAllotmentIndex = i - var rat big.Rat - allotments = append(allotments, rat) + allotments = append(allotments, new(big.Rat)) // TODO check there are not duplicate remaining clause } } if remainingAllotmentIndex != -1 { - var rat big.Rat - rat.Sub(big.NewRat(1, 1), totalAllotment) - allotments[remainingAllotmentIndex] = rat + allotments[remainingAllotmentIndex] = new(big.Rat).Sub(big.NewRat(1, 1), totalAllotment) } else if totalAllotment.Cmp(big.NewRat(1, 1)) != 0 { return nil, InvalidAllotmentSum{ActualSum: *totalAllotment} } - parts := make([]int64, len(allotments)) + parts := make([]*big.Int, len(allotments)) - var totalAllocated int64 + totalAllocated := big.NewInt(0) for i, allot := range allotments { - var product big.Rat - product.Mul(&allot, big.NewRat(monetary, 1)) + monetaryRat := new(big.Rat).SetInt(monetary) + product := new(big.Rat).Mul(allot, monetaryRat) - floored := product.Num().Int64() / product.Denom().Int64() + floored := new(big.Int).Div(product.Num(), product.Denom()) parts[i] = floored - totalAllocated += floored + totalAllocated.Add(totalAllocated, floored) + } for i := range parts { - if totalAllocated >= monetary { + if /* totalAllocated >= monetary */ totalAllocated.Cmp(monetary) != -1 { break } - parts[i]++ - totalAllocated++ + parts[i].Add(parts[i], big.NewInt(1)) + // totalAllocated++ + totalAllocated.Add(totalAllocated, big.NewInt(1)) } return parts, nil @@ -743,8 +711,16 @@ func meta( return "", err } + meta, fetchMetaErr := s.Store.GetAccountsMetadata(s.ctx, MetadataQuery{ + *account: []string{*key}, + }) + if fetchMetaErr != nil { + return "", QueryMetadataError{WrappedError: fetchMetaErr} + } + s.CachedAccountsMeta = meta + // body - accountMeta := s.Meta[*account] + accountMeta := s.CachedAccountsMeta[*account] value, ok := accountMeta[*key] if !ok { @@ -769,7 +745,13 @@ func balance( } // body - balance := s.getBalance(*account, *asset) + s.batchQuery(*account, *asset) + fetchBalanceErr := s.runBalancesQuery() + if fetchBalanceErr != nil { + return nil, QueryBalanceError{WrappedError: fetchBalanceErr} + } + + balance := s.getCachedBalance(*account, *asset) if balance.Cmp(big.NewInt(0)) == -1 { return nil, NegativeBalanceError{ Account: *account, @@ -777,12 +759,11 @@ func balance( } } - var balanceCopy big.Int - balanceCopy.Set(balance) + balanceCopy := new(big.Int).Set(balance) m := Monetary{ Asset: Asset(*asset), - Amount: MonetaryInt(balanceCopy), + Amount: MonetaryInt(*balanceCopy), } return &m, nil } @@ -810,21 +791,11 @@ func setAccountMeta(st *programState, r parser.Range, args []Value) InterpreterE return err } - accountMeta := defaultMapGet(st.Meta, *account, func() Metadata { - return make(Metadata) + accountMeta := defaultMapGet(st.SetAccountsMeta, *account, func() AccountMetadata { + return AccountMetadata{} }) accountMeta[*key] = (*meta).String() return nil } - -func defaultMapGet[T any](m map[string]T, key string, getDefault func() T) T { - lookup, ok := m[key] - if !ok { - default_ := getDefault() - m[key] = default_ - return default_ - } - return lookup -} diff --git a/internal/interpreter/interpreter_error.go b/internal/interpreter/interpreter_error.go index e94c69a..88811b1 100644 --- a/internal/interpreter/interpreter_error.go +++ b/internal/interpreter/interpreter_error.go @@ -156,3 +156,21 @@ type InvalidAllotmentSum struct { func (e InvalidAllotmentSum) Error() string { return fmt.Sprintf("Invalid allotment: portions sum should be 1 (got %s instead)", e.ActualSum.String()) } + +type QueryBalanceError struct { + parser.Range + WrappedError error +} + +func (e QueryBalanceError) Error() string { + return e.WrappedError.Error() +} + +type QueryMetadataError struct { + parser.Range + WrappedError error +} + +func (e QueryMetadataError) Error() string { + return e.WrappedError.Error() +} diff --git a/internal/interpreter/interpreter_errors_test.go b/internal/interpreter/interpreter_errors_test.go index 2a76710..cbf07cf 100644 --- a/internal/interpreter/interpreter_errors_test.go +++ b/internal/interpreter/interpreter_errors_test.go @@ -1,6 +1,7 @@ package interpreter_test import ( + "context" "testing" "github.com/formancehq/numscript/internal/interpreter" @@ -9,9 +10,9 @@ import ( "github.com/stretchr/testify/require" ) -func matchErrWithSnapshots(t *testing.T, src string, runOpt interpreter.RunProgramOptions) { +func matchErrWithSnapshots(t *testing.T, src string, vars map[string]string, runOpt interpreter.StaticStore) { parsed := parser.Parse(src) - _, err := interpreter.RunProgram(parsed.Value, runOpt) + _, err := interpreter.RunProgram(context.Background(), parsed.Value, vars, runOpt) require.NotNil(t, err) snaps.MatchSnapshot(t, err.GetRange().ShowOnSource(parsed.Source)) } @@ -20,7 +21,7 @@ func TestShowUnboundVar(t *testing.T) { matchErrWithSnapshots(t, `send [COIN 10] ( source = $unbound_var destination = @dest -)`, interpreter.RunProgramOptions{}) +)`, nil, interpreter.StaticStore{}) } @@ -28,7 +29,7 @@ func TestShowMissingFundsSingleAccount(t *testing.T) { matchErrWithSnapshots(t, `send [COIN 10] ( source = @a destination = @dest -)`, interpreter.RunProgramOptions{}) +)`, nil, interpreter.StaticStore{}) } @@ -39,7 +40,7 @@ func TestShowMissingFundsInorder(t *testing.T) { @b } destination = @dest -)`, interpreter.RunProgramOptions{}) +)`, nil, interpreter.StaticStore{}) } func TestShowMissingFundsAllotment(t *testing.T) { @@ -49,7 +50,7 @@ func TestShowMissingFundsAllotment(t *testing.T) { remaining from @world } destination = @dest -)`, interpreter.RunProgramOptions{}) +)`, nil, interpreter.StaticStore{}) } func TestShowMissingFundsMax(t *testing.T) { @@ -59,14 +60,14 @@ func TestShowMissingFundsMax(t *testing.T) { remaining from @world } destination = @dest -)`, interpreter.RunProgramOptions{}) +)`, nil, interpreter.StaticStore{}) } func TestShowMetadataNotFound(t *testing.T) { matchErrWithSnapshots(t, `vars { number $my_var = meta(@acc, "key") } -`, interpreter.RunProgramOptions{}) +`, nil, interpreter.StaticStore{}) } func TestShowTypeError(t *testing.T) { @@ -74,14 +75,14 @@ func TestShowTypeError(t *testing.T) { source = @a destination = @b ) -`, interpreter.RunProgramOptions{}) +`, nil, interpreter.StaticStore{}) } func TestShowInvalidTypeErr(t *testing.T) { matchErrWithSnapshots(t, `vars { invalid_t $x } -`, interpreter.RunProgramOptions{ - Vars: map[string]string{"x": "42"}, - }) +`, map[string]string{"x": "42"}, + interpreter.StaticStore{}, + ) } diff --git a/internal/interpreter/interpreter_test.go b/internal/interpreter/interpreter_test.go index e1cafe9..8b8dba5 100644 --- a/internal/interpreter/interpreter_test.go +++ b/internal/interpreter/interpreter_test.go @@ -1,6 +1,7 @@ package interpreter_test import ( + "context" "encoding/json" "math/big" @@ -18,7 +19,7 @@ type TestCase struct { source string program *parser.Program vars map[string]string - meta map[string]machine.Metadata + meta machine.AccountsMetadata balances map[string]map[string]*big.Int expected CaseResult } @@ -26,12 +27,12 @@ type TestCase struct { func NewTestCase() TestCase { return TestCase{ vars: make(map[string]string), - meta: make(map[string]machine.Metadata), + meta: machine.AccountsMetadata{}, balances: make(map[string]map[string]*big.Int), expected: CaseResult{ - Postings: []machine.Posting{}, - Metadata: make(map[string]machine.Value), - Error: nil, + Postings: []machine.Posting{}, + TxMetadata: make(map[string]machine.Value), + Error: nil, }, } } @@ -83,8 +84,7 @@ func test(t *testing.T, testCase TestCase) { require.NotNil(t, prog) - execResult, err := machine.RunProgram(*prog, machine.RunProgramOptions{ - testCase.vars, + execResult, err := machine.RunProgram(context.Background(), *prog, testCase.vars, machine.StaticStore{ testCase.balances, testCase.meta, }) @@ -102,22 +102,22 @@ func test(t *testing.T, testCase TestCase) { if expected.Postings == nil { expected.Postings = make([]Posting, 0) } - if expected.Metadata == nil { - expected.Metadata = make(map[string]machine.Value) + if expected.TxMetadata == nil { + expected.TxMetadata = make(map[string]machine.Value) } if expected.AccountMetadata == nil { - expected.AccountMetadata = make(map[string]machine.Metadata) + expected.AccountMetadata = machine.AccountsMetadata{} } assert.Equal(t, expected.Postings, execResult.Postings) - assert.Equal(t, expected.Metadata, execResult.TxMeta) - assert.Equal(t, expected.AccountMetadata, execResult.AccountsMeta) + assert.Equal(t, expected.TxMetadata, execResult.Metadata) + assert.Equal(t, expected.AccountMetadata, execResult.AccountsMetadata) } type CaseResult struct { Postings []machine.Posting - Metadata map[string]machine.Value - AccountMetadata map[string]machine.Metadata + TxMetadata map[string]machine.Value + AccountMetadata machine.AccountsMetadata Error machine.InterpreterError } @@ -155,7 +155,7 @@ func TestSetTxMeta(t *testing.T) { `) tc.expected = CaseResult{ - Metadata: map[string]machine.Value{ + TxMetadata: map[string]machine.Value{ "num": machine.NewMonetaryInt(42), "str": machine.String("abc"), "asset": machine.Asset("COIN"), @@ -179,7 +179,7 @@ func TestSetAccountMeta(t *testing.T) { `) tc.expected = CaseResult{ - AccountMetadata: map[string]machine.Metadata{ + AccountMetadata: machine.AccountsMetadata{ "acc": { "num": "42", "str": "abc", @@ -196,7 +196,7 @@ func TestSetAccountMeta(t *testing.T) { func TestOverrideAccountMeta(t *testing.T) { tc := NewTestCase() - tc.meta = map[string]machine.Metadata{ + tc.meta = machine.AccountsMetadata{ "acc": { "initial": "0", "overridden": "1", @@ -207,9 +207,8 @@ func TestOverrideAccountMeta(t *testing.T) { set_account_meta(@acc, "new", 2) `) tc.expected = CaseResult{ - AccountMetadata: map[string]machine.Metadata{ + AccountMetadata: machine.AccountsMetadata{ "acc": { - "initial": "0", "overridden": "100", "new": "2", }, @@ -251,7 +250,7 @@ func TestVariables(t *testing.T) { Destination: "users:002", }, }, - Metadata: map[string]machine.Value{ + TxMetadata: map[string]machine.Value{ "description": machine.String("midnight ride"), "ride": machine.NewMonetaryInt(1), }, @@ -296,7 +295,7 @@ func TestVariablesJSON(t *testing.T) { Destination: "users:002", }, }, - Metadata: map[string]machine.Value{ + TxMetadata: map[string]machine.Value{ "description": machine.String("midnight ride"), "ride": machine.NewMonetaryInt(1), "por": machine.Portion(*big.NewRat(42, 100)), @@ -999,7 +998,7 @@ func TestMetadata(t *testing.T) { tc.setVarsFromJSON(t, `{ "sale": "sales:042" }`) - tc.meta = map[string]machine.Metadata{ + tc.meta = machine.AccountsMetadata{ "sales:042": { "seller": "users:053", }, @@ -1024,15 +1023,6 @@ func TestMetadata(t *testing.T) { Destination: "platform", }, }, - // Keep the original metadata - AccountMetadata: map[string]machine.Metadata{ - "sales:042": { - "seller": "users:053", - }, - "users:053": { - "commission": "12.5%", - }, - }, Error: nil, } test(t, tc) @@ -2667,3 +2657,29 @@ func TestCascadingSources(t *testing.T) { } test(t, tc) } + +func TestUseBalanceTwice(t *testing.T) { + tc := NewTestCase() + tc.compile(t, ` + vars { monetary $v = balance(@src, COIN) } + + send $v ( + source = @src + destination = @dest + )`) + + tc.setBalance("src", "COIN", 50) + tc.expected = CaseResult{ + + Postings: []Posting{ + { + Asset: "COIN", + Amount: big.NewInt(50), + Source: "src", + Destination: "dest", + }, + }, + Error: nil, + } + test(t, tc) +} diff --git a/internal/interpreter/utils.go b/internal/interpreter/utils.go new file mode 100644 index 0000000..a32bf66 --- /dev/null +++ b/internal/interpreter/utils.go @@ -0,0 +1,11 @@ +package interpreter + +func defaultMapGet[T any](m map[string]T, key string, getDefault func() T) T { + lookup, ok := m[key] + if !ok { + default_ := getDefault() + m[key] = default_ + return default_ + } + return lookup +} diff --git a/main.go b/internal/numscript/numscript.go similarity index 100% rename from main.go rename to internal/numscript/numscript.go diff --git a/internal/parser/__snapshots__/parser_test.snap b/internal/parser/__snapshots__/parser_test.snap index 929cf57..f67761c 100755 --- a/internal/parser/__snapshots__/parser_test.snap +++ b/internal/parser/__snapshots__/parser_test.snap @@ -153,8 +153,14 @@ parser.Program{ Start: parser.Position{Character:13, Line:1}, End: parser.Position{Character:16, Line:1}, }, - Numerator: 0x1, - Denominator: 0x3, + Numerator: &big.Int{ + neg: false, + abs: {0x1}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x3}, + }, }, From: &parser.SourceAccount{ Literal: &parser.AccountLiteral{ @@ -229,8 +235,14 @@ parser.Program{ Start: parser.Position{Character:4, Line:2}, End: parser.Position{Character:7, Line:2}, }, - Numerator: 0x2a, - Denominator: 0x64, + Numerator: &big.Int{ + neg: false, + abs: {0x2a}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x64}, + }, }, From: &parser.SourceAccount{ Literal: &parser.AccountLiteral{ @@ -252,8 +264,14 @@ parser.Program{ Start: parser.Position{Character:1, Line:3}, End: parser.Position{Character:4, Line:3}, }, - Numerator: 0x1, - Denominator: 0x2, + Numerator: &big.Int{ + neg: false, + abs: {0x1}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x2}, + }, }, From: &parser.SourceAccount{ Literal: &parser.AccountLiteral{ @@ -349,8 +367,14 @@ parser.Program{ Start: parser.Position{Character:13, Line:1}, End: parser.Position{Character:18, Line:1}, }, - Numerator: 0xf2, - Denominator: 0x2710, + Numerator: &big.Int{ + neg: false, + abs: {0xf2}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x2710}, + }, }, From: &parser.SourceAccount{ Literal: &parser.AccountLiteral{ @@ -434,8 +458,14 @@ parser.Program{ Start: parser.Position{Character:18, Line:2}, End: parser.Position{Character:21, Line:2}, }, - Numerator: 0x1, - Denominator: 0x2, + Numerator: &big.Int{ + neg: false, + abs: {0x1}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x2}, + }, }, To: &parser.DestinationTo{ Destination: &parser.DestinationAccount{ @@ -938,8 +968,14 @@ parser.Program{ Start: parser.Position{Character:3, Line:3}, End: parser.Position{Character:6, Line:3}, }, - Numerator: 0x1, - Denominator: 0x2, + Numerator: &big.Int{ + neg: false, + abs: {0x1}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x2}, + }, }, To: &parser.DestinationTo{ Destination: &parser.DestinationAccount{ @@ -1316,8 +1352,14 @@ parser.Program{ Start: parser.Position{Character:1, Line:4}, End: parser.Position{Character:4, Line:4}, }, - Numerator: 0x1, - Denominator: 0x2, + Numerator: &big.Int{ + neg: false, + abs: {0x1}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x2}, + }, }, &parser.VariableLiteral{ Range: parser.Range{ @@ -1675,8 +1717,14 @@ parser.Program{ Start: parser.Position{Character:18, Line:2}, End: parser.Position{Character:21, Line:2}, }, - Numerator: 0x1, - Denominator: 0x2, + Numerator: &big.Int{ + neg: false, + abs: {0x1}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x2}, + }, }, To: &parser.DestinationKept{ Range: parser.Range{ @@ -1735,8 +1783,14 @@ parser.Program{ Start: parser.Position{Character:4, Line:4}, End: parser.Position{Character:9, Line:4}, }, - Numerator: 0x1, - Denominator: 0x6, + Numerator: &big.Int{ + neg: false, + abs: {0x1}, + }, + Denominator: &big.Int{ + neg: false, + abs: {0x6}, + }, }, To: &parser.DestinationTo{ Destination: &parser.DestinationAccount{ diff --git a/internal/parser/ast.go b/internal/parser/ast.go index 7ff825d..99132be 100644 --- a/internal/parser/ast.go +++ b/internal/parser/ast.go @@ -55,8 +55,8 @@ type ( RatioLiteral struct { Range Range - Numerator uint64 - Denominator uint64 + Numerator *big.Int + Denominator *big.Int } VariableLiteral struct { @@ -66,7 +66,7 @@ type ( ) func (r RatioLiteral) ToRatio() *big.Rat { - return big.NewRat(int64(r.Numerator), int64(r.Denominator)) + return new(big.Rat).SetFrac(r.Numerator, r.Denominator) } type RemainingAllotment struct { diff --git a/internal/parser/parser.go b/internal/parser/parser.go index 1b7fb79..fdc449c 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -2,6 +2,7 @@ package parser import ( "math" + "math/big" "strconv" "strings" @@ -16,9 +17,9 @@ type ParserError struct { Msg string } -type ParseResult[T any] struct { +type ParseResult struct { Source string - Value T + Value Program Errors []ParserError } @@ -43,7 +44,7 @@ func (l *ErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol }) } -func Parse(input string) ParseResult[Program] { +func Parse(input string) ParseResult { // TODO handle lexer errors listener := &ErrorListener{} @@ -60,7 +61,7 @@ func Parse(input string) ParseResult[Program] { parsed := parseProgram(parser.Program()) - return ParseResult[Program]{ + return ParseResult{ Source: input, Value: parsed, Errors: listener.Errors, @@ -230,18 +231,20 @@ func parseSource(sourceCtx parser.ISourceContext) Source { } } +func unsafeParseBigInt(s string) *big.Int { + s = strings.TrimSpace(s) + i, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("invalid int: " + s) + } + return i +} + func parseRatio(source string, range_ Range) *RatioLiteral { split := strings.Split(source, "/") - num, err := strconv.ParseUint(strings.TrimSpace(split[0]), 0, 64) - if err != nil { - panic(err) - } - - den, err := strconv.ParseUint(strings.TrimSpace(split[1]), 0, 64) - if err != nil { - panic(err) - } + num := unsafeParseBigInt(split[0]) + den := unsafeParseBigInt(split[1]) return &RatioLiteral{ Range: range_, @@ -250,11 +253,12 @@ func parseRatio(source string, range_ Range) *RatioLiteral { } } -func ParsePercentageRatio(source string) (uint64, uint64, error) { +// TODO actually handle big int +func ParsePercentageRatio(source string) (*big.Int, *big.Int, error) { str := strings.TrimSuffix(source, "%") num, err := strconv.ParseUint(strings.Replace(str, ".", "", -1), 0, 64) if err != nil { - return 0, 0, err + return nil, nil, err } var denominator uint64 @@ -267,7 +271,7 @@ func ParsePercentageRatio(source string) (uint64, uint64, error) { denominator = 100 } - return num, denominator, nil + return big.NewInt(int64(num)), big.NewInt(int64(denominator)), nil } func parsePercentageRatio(source string, range_ Range) *RatioLiteral { diff --git a/numscript.go b/numscript.go new file mode 100644 index 0000000..e1ca324 --- /dev/null +++ b/numscript.go @@ -0,0 +1,60 @@ +package numscript + +import ( + "context" + + "github.com/formancehq/numscript/internal/interpreter" + "github.com/formancehq/numscript/internal/parser" +) + +// This struct represents a parsed numscript source code +type ParseResult struct { + parseResult parser.ParseResult +} + +// ---- TODO useful for the playground +// func (*ParseResult) GetNeededVariables() map[string]ValueType {} +// func (*ParseResult) GetDiagnostics() []Diagnostic {} + +type ParserError = parser.ParserError + +func Parse(code string) ParseResult { + return ParseResult{parseResult: parser.Parse(code)} +} + +var ParseErrorsToString = parser.ParseErrorsToString + +func (p ParseResult) GetParsingErrors() []ParserError { + return p.parseResult.Errors +} + +type ( + VariablesMap = interpreter.VariablesMap + Posting = interpreter.Posting + ExecutionResult = interpreter.ExecutionResult + // For each account, list of the needed assets + BalanceQuery = interpreter.BalanceQuery + MetadataQuery = interpreter.MetadataQuery + AccountBalance = interpreter.AccountBalance + Balances = interpreter.Balances + + AccountMetadata = interpreter.AccountMetadata + + // The newly defined account metadata after the execution + AccountsMetadata = interpreter.AccountsMetadata + + // The transaction metadata, set by set_tx_meta() + Metadata = interpreter.Metadata + + Store = interpreter.Store + + Value = interpreter.Value +) + +func (p ParseResult) Run(ctx context.Context, vars VariablesMap, store Store) (ExecutionResult, error) { + res, err := interpreter.RunProgram(ctx, p.parseResult.Value, vars, store) + if err != nil { + return ExecutionResult{}, err + } + return *res, nil +} diff --git a/numscript_test.go b/numscript_test.go new file mode 100644 index 0000000..7deb4f5 --- /dev/null +++ b/numscript_test.go @@ -0,0 +1,279 @@ +package numscript_test + +import ( + "context" + "errors" + "math/big" + "testing" + + "github.com/formancehq/numscript" + "github.com/formancehq/numscript/internal/interpreter" + "github.com/stretchr/testify/require" +) + +func TestGetBalancesInorder(t *testing.T) { + parseResult := numscript.Parse(`vars { + account $s1 + account $s2 = meta(@account_that_needs_meta, "k") + number $b = balance(@account_that_needs_balance, USD/2) +} + +send [COIN 100] ( + source = { + $s1 + $s2 + @source3 + @world + } + destination = @dest +) +`) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + + store := ObservableStore{ + StaticStore: interpreter.StaticStore{ + Balances: interpreter.Balances{}, + Meta: interpreter.AccountsMetadata{"account_that_needs_meta": {"k": "source2"}}, + }, + } + _, err := parseResult.Run(context.Background(), numscript.VariablesMap{ + "s1": "source1", + }, + &store, + ) + require.Nil(t, err) + + require.Equal(t, + []numscript.MetadataQuery{ + { + "account_that_needs_meta": {"k"}, + }, + }, + store.GetMetadataCalls) + + require.Equal(t, + []numscript.BalanceQuery{ + // TODO maybe those calls can be batched together + { + // this is required by the balance() call + "account_that_needs_balance": {"USD/2"}, + }, + { + // this is defined in the variables + "source1": {"COIN"}, + + // this is defined in account metadata + "source2": {"COIN"}, + + // this appears as literal + "source3": {"COIN"}, + }, + }, + store.GetBalancesCalls) +} + +func TestGetBalancesAllotment(t *testing.T) { + parseResult := numscript.Parse(`send [COIN 100] ( + source = { + 1/2 from @a + remaining from @b + } + destination = @dest +) +`) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + + store := ObservableStore{ + StaticStore: interpreter.StaticStore{ + Balances: interpreter.Balances{ + "a": {"COIN": big.NewInt(10000)}, + "b": {"COIN": big.NewInt(10000)}, + }, + }, + } + + _, err := parseResult.Run(context.Background(), + numscript.VariablesMap{}, + &store, + ) + require.Nil(t, err) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "a": {"COIN"}, + "b": {"COIN"}, + }, + }, + store.GetBalancesCalls) +} + +func TestGetBalancesOverdraft(t *testing.T) { + parseResult := numscript.Parse(`send [COIN 100] ( + source = { + @a allowing overdraft up to [COIN 10] + @b allowing unbounded overdraft + } + destination = @dest +) +`) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + + store := ObservableStore{} + + _, err := parseResult.Run(context.Background(), interpreter.VariablesMap{}, &store) + require.Nil(t, err) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "a": {"COIN"}, + }, + }, + store.GetBalancesCalls) +} + +func TestDoNotFetchBalanceTwice(t *testing.T) { + parseResult := numscript.Parse(`vars { monetary $v = balance(@src, COIN) } + + send $v ( + source = @src + destination = @dest + )`) + + store := ObservableStore{} + parseResult.Run(context.Background(), nil, &store) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "src": {"COIN"}, + }, + }, + store.GetBalancesCalls, + ) + +} + +func TestDoNotFetchBalanceTwice2(t *testing.T) { + // same test as before, but this time the second batch is not empty + parseResult := numscript.Parse(`vars { monetary $v = balance(@src1, COIN) } + + send $v ( + source = { + @src1 + @src2 + } + destination = @dest + )`) + + store := ObservableStore{} + parseResult.Run(context.Background(), nil, &store) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "src1": {"COIN"}, + }, + { + "src2": {"COIN"}, + }, + }, + store.GetBalancesCalls, + ) + +} + +func TestDoNotFetchBalanceTwice3(t *testing.T) { + // same test as before, but this time the second batch requires a _different asset_ + parseResult := numscript.Parse(`vars { monetary $eur_m = balance(@src, EUR/2) } + + + send [USD/2 100] ( + // note here we are fetching a different currency + source = @src + destination = @dest + ) +`) + + store := ObservableStore{} + parseResult.Run(context.Background(), nil, &store) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "src": {"EUR/2"}, + }, + { + "src": {"USD/2"}, + }, + }, + store.GetBalancesCalls, + ) + +} + +func TestQueryBalanceErr(t *testing.T) { + parseResult := numscript.Parse(`send [COIN 100] ( + source = @src + destination = @dest +) +`) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + + _, err := parseResult.Run(context.Background(), interpreter.VariablesMap{}, &ErrorStore{}) + require.IsType(t, err, interpreter.QueryBalanceError{}) +} + +func TestMetadataFetchErr(t *testing.T) { + parseResult := numscript.Parse(`vars { + number $x = meta(@acc, "k") +}`) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + + _, err := parseResult.Run(context.Background(), interpreter.VariablesMap{}, &ErrorStore{}) + require.IsType(t, err, interpreter.QueryMetadataError{}) +} + +func TestBalanceFunctionErr(t *testing.T) { + parseResult := numscript.Parse(`vars { + monetary $x = balance(@acc, USD/2) +}`) + + require.Empty(t, parseResult.GetParsingErrors(), "There should not be parsing errors") + + _, err := parseResult.Run(context.Background(), interpreter.VariablesMap{}, &ErrorStore{}) + require.IsType(t, err, interpreter.QueryBalanceError{}) +} + +type ObservableStore struct { + StaticStore interpreter.StaticStore + GetBalancesCalls []numscript.BalanceQuery + GetMetadataCalls []numscript.MetadataQuery +} + +func (os *ObservableStore) GetBalances(ctx context.Context, q interpreter.BalanceQuery) (interpreter.Balances, error) { + os.GetBalancesCalls = append(os.GetBalancesCalls, q) + return os.StaticStore.GetBalances(ctx, q) + +} + +func (os *ObservableStore) GetAccountsMetadata(ctx context.Context, q interpreter.MetadataQuery) (interpreter.AccountsMetadata, error) { + os.GetMetadataCalls = append(os.GetMetadataCalls, q) + return os.StaticStore.GetAccountsMetadata(ctx, q) +} + +type ErrorStore struct{} + +func (*ErrorStore) GetBalances(ctx context.Context, q interpreter.BalanceQuery) (interpreter.Balances, error) { + return nil, errors.New("Error while fetching balances") +} + +func (*ErrorStore) GetAccountsMetadata(ctx context.Context, q interpreter.MetadataQuery) (interpreter.AccountsMetadata, error) { + return nil, errors.New("Error while fetching metadata") +}