diff --git a/internal/interpreter/batch_balances_query.go b/internal/interpreter/batch_balances_query.go index 2a07ed9..91b9ae0 100644 --- a/internal/interpreter/batch_balances_query.go +++ b/internal/interpreter/batch_balances_query.go @@ -15,26 +15,27 @@ func (st *programState) findBalancesQueriesInStatement(statement parser.Statemen case *parser.FnCall: return nil - case *parser.SendStatement: - // set the current asset - switch sentValue := statement.SentValue.(type) { - case *parser.SentValueAll: - asset, err := evaluateLitExpecting(st, sentValue.Asset, expectAsset) - if err != nil { - return err - } - st.CurrentAsset = *asset - - case *parser.SentValueLiteral: - monetary, err := evaluateLitExpecting(st, sentValue.Monetary, expectMonetary) - if err != nil { - return err - } - st.CurrentAsset = string(monetary.Asset) + case *parser.SaveStatement: + // Although we don't technically need this account's balance rn, + // having access to the balance simplifies the "save" statement implementation + // this means that we would have a needless query in the case in which the account + // which is selected in the "save" statement never actually appears as source + // + // this would mean that the "save" statement was not needed in the first place, + // so preventing this query would hardly be an useful optimization + account, err := evaluateLitExpecting(st, statement.Literal, expectAccount) + if err != nil { + return err + } + st.batchQuery(*account, st.CurrentAsset) + return nil - default: - utils.NonExhaustiveMatchPanic[any](sentValue) + case *parser.SendStatement: + asset, _, err := st.evaluateSentAmt(statement.SentValue) + if err != nil { + return err } + st.CurrentAsset = *asset // traverse source return st.findBalancesQueries(statement.Source) diff --git a/internal/interpreter/interpreter.go b/internal/interpreter/interpreter.go index 535c40a..fdc11d7 100644 --- a/internal/interpreter/interpreter.go +++ b/internal/interpreter/interpreter.go @@ -277,6 +277,10 @@ func (st *programState) runStatement(statement parser.Statement) ([]Posting, Int case *parser.SendStatement: return st.runSendStatement(*statement) + + case *parser.SaveStatement: + return st.runSaveStatement(*statement) + default: utils.NonExhaustiveMatchPanic[any](statement) return nil, nil @@ -299,6 +303,41 @@ func (st *programState) getPostings() ([]Posting, InterpreterError) { return postings, nil } +func (st *programState) runSaveStatement(saveStatement parser.SaveStatement) ([]Posting, InterpreterError) { + asset, amt, err := st.evaluateSentAmt(saveStatement.SentValue) + if err != nil { + return nil, err + } + + account, err := evaluateLitExpecting(st, saveStatement.Literal, expectAccount) + if err != nil { + return nil, err + } + + balance := st.getCachedBalance(*account, *asset) + + if amt == nil { + balance.Set(big.NewInt(0)) + } else { + // Do not allow negative saves + if amt.Cmp(big.NewInt(0)) == -1 { + return nil, NegativeAmountErr{ + Range: saveStatement.SentValue.GetRange(), + Amount: MonetaryInt(*amt), + } + } + + // we decrease the balance by "amt" + balance.Sub(balance, amt) + // without going under 0 + if balance.Cmp(big.NewInt(0)) == -1 { + balance.Set(big.NewInt(0)) + } + } + + return nil, nil +} + func (st *programState) runSendStatement(statement parser.SendStatement) ([]Posting, InterpreterError) { switch sentValue := statement.SentValue.(type) { case *parser.SentValueAll: @@ -796,3 +835,27 @@ func setAccountMeta(st *programState, r parser.Range, args []Value) InterpreterE return nil } + +func (st *programState) evaluateSentAmt(sentValue parser.SentValue) (*string, *big.Int, InterpreterError) { + switch sentValue := sentValue.(type) { + case *parser.SentValueAll: + asset, err := evaluateLitExpecting(st, sentValue.Asset, expectAsset) + if err != nil { + return nil, nil, err + } + return asset, nil, nil + + case *parser.SentValueLiteral: + monetary, err := evaluateLitExpecting(st, sentValue.Monetary, expectMonetary) + if err != nil { + return nil, nil, err + } + s := string(monetary.Asset) + bi := big.Int(monetary.Amount) + return &s, &bi, nil + + default: + utils.NonExhaustiveMatchPanic[any](sentValue) + return nil, nil, nil + } +} diff --git a/internal/interpreter/interpreter_test.go b/internal/interpreter/interpreter_test.go index fa75330..3c04f50 100644 --- a/internal/interpreter/interpreter_test.go +++ b/internal/interpreter/interpreter_test.go @@ -50,7 +50,9 @@ func removeRange(e machine.InterpreterError) machine.InterpreterError { case machine.InvalidTypeErr: e.Range = parser.Range{} return e - + case machine.NegativeAmountErr: + e.Range = parser.Range{} + return e default: return e } @@ -2805,3 +2807,278 @@ func TestBigIntMonetary(t *testing.T) { } test(t, tc) } + +// TODO +// func TestSaveFromAccountBatchQuery(t *testing.T) {} + +func TestSaveFromAccount(t *testing.T) { + + t.Run("simple", func(t *testing.T) { + script := ` + save [USD 10] from @alice + + send [USD 30] ( + source = { + @alice + @world + } + destination = @bob + )` + tc := NewTestCase() + tc.compile(t, script) + tc.setBalance("alice", "USD", 20) + tc.expected = CaseResult{ + Postings: []Posting{ + { + Asset: "USD", + Amount: big.NewInt(10), + Source: "alice", + Destination: "bob", + }, + { + Asset: "USD", + Amount: big.NewInt(20), + Source: "world", + Destination: "bob", + }, + }, + Error: nil, + } + test(t, tc) + }) + + t.Run("save all", func(t *testing.T) { + script := ` + save [USD *] from @alice + + send [USD 30] ( + source = { + @alice + @world + } + destination = @bob + )` + tc := NewTestCase() + tc.compile(t, script) + tc.setBalance("alice", "USD", 20) + tc.expected = CaseResult{ + Postings: []Posting{ + // 0-posting omitted + { + Asset: "USD", + Amount: big.NewInt(30), + Source: "world", + Destination: "bob", + }, + }, + Error: nil, + } + test(t, tc) + }) + + t.Run("save more than balance", func(t *testing.T) { + script := ` + save [USD 30] from @alice + + send [USD 30] ( + source = { + @alice + @world + } + destination = @bob + )` + tc := NewTestCase() + tc.compile(t, script) + tc.setBalance("alice", "USD", 20) + tc.expected = CaseResult{ + Postings: []Posting{ + // 0-posting omitted + { + Asset: "USD", + Amount: big.NewInt(30), + Source: "world", + Destination: "bob", + }, + }, + Error: nil, + } + test(t, tc) + }) + + t.Run("with asset var", func(t *testing.T) { + script := ` + vars { + asset $ass + } + save [$ass 10] from @alice + + send [$ass 30] ( + source = { + @alice + @world + } + destination = @bob + )` + tc := NewTestCase() + tc.compile(t, script) + tc.vars = map[string]string{ + "ass": "USD", + } + tc.setBalance("alice", "USD", 20) + tc.expected = CaseResult{ + Postings: []Posting{ + { + Asset: "USD", + Amount: big.NewInt(10), + Source: "alice", + Destination: "bob", + }, + { + Asset: "USD", + Amount: big.NewInt(20), + Source: "world", + Destination: "bob", + }, + }, + Error: nil, + } + test(t, tc) + }) + + t.Run("with monetary var", func(t *testing.T) { + script := ` + vars { + monetary $mon + } + + save $mon from @alice + + send [USD 30] ( + source = { + @alice + @world + } + destination = @bob + )` + tc := NewTestCase() + tc.compile(t, script) + tc.vars = map[string]string{ + "mon": "USD 10", + } + tc.setBalance("alice", "USD", 20) + tc.expected = CaseResult{ + Postings: []Posting{ + { + Asset: "USD", + Amount: big.NewInt(10), + Source: "alice", + Destination: "bob", + }, + { + Asset: "USD", + Amount: big.NewInt(20), + Source: "world", + Destination: "bob", + }, + }, + Error: nil, + } + test(t, tc) + }) + + t.Run("multi postings", func(t *testing.T) { + script := ` + send [USD 10] ( + source = @alice + destination = @bob + ) + + save [USD 5] from @alice + + send [USD 30] ( + source = { + @alice + @world + } + destination = @bob + )` + tc := NewTestCase() + tc.compile(t, script) + tc.setBalance("alice", "USD", 20) + tc.expected = CaseResult{ + Postings: []Posting{ + { + Asset: "USD", + Amount: big.NewInt(10), + Source: "alice", + Destination: "bob", + }, + { + Asset: "USD", + Amount: big.NewInt(5), + Source: "alice", + Destination: "bob", + }, + { + Asset: "USD", + Amount: big.NewInt(25), + Source: "world", + Destination: "bob", + }, + }, + Error: nil, + } + test(t, tc) + }) + + t.Run("save a different asset", func(t *testing.T) { + script := ` + save [COIN 100] from @alice + + send [USD 30] ( + source = { + @alice + @world + } + destination = @bob + )` + tc := NewTestCase() + tc.compile(t, script) + tc.setBalance("alice", "COIN", 100) + tc.setBalance("alice", "USD", 20) + tc.expected = CaseResult{ + Postings: []Posting{ + { + Asset: "USD", + Amount: big.NewInt(20), + Source: "alice", + Destination: "bob", + }, + { + Asset: "USD", + Amount: big.NewInt(10), + Source: "world", + Destination: "bob", + }, + }, + Error: nil, + } + test(t, tc) + }) + + t.Run("negative amount", func(t *testing.T) { + script := ` + + save [USD -100] from @A` + tc := NewTestCase() + tc.compile(t, script) + tc.setBalance("A", "USD", -100) + tc.expected = CaseResult{ + Postings: []Posting{}, + Error: machine.NegativeAmountErr{ + Amount: machine.NewMonetaryInt(-100), + }, + } + test(t, tc) + }) +}