diff --git a/internal/interpreter/batch_balances_query.go b/internal/interpreter/batch_balances_query.go index b5c4da3..ff0e7e5 100644 --- a/internal/interpreter/batch_balances_query.go +++ b/internal/interpreter/batch_balances_query.go @@ -5,6 +5,7 @@ import ( "github.com/formancehq/numscript/internal/parser" "github.com/formancehq/numscript/internal/utils" + "golang.org/x/exp/maps" ) // traverse the script to batch in advance required balance queries @@ -56,7 +57,28 @@ func (st *programState) batchQuery(account string, asset string) { } func (st *programState) runBalancesQuery() error { - balances, err := st.Store.GetBalances(st.ctx, st.CurrentBalanceQuery) + filteredQuery := BalanceQuery{} + for accountName, queriedCurrencies := range st.CurrentBalanceQuery { + + cachedCurrenciesForAccount := defaultMapGet(st.CachedBalances, accountName, func() AccountBalance { + return AccountBalance{} + }) + + for _, queriedCurrency := range queriedCurrencies { + isAlreadyCached := slices.Contains(maps.Keys(cachedCurrenciesForAccount), queriedCurrency) + if !isAlreadyCached { + filteredQuery[accountName] = queriedCurrencies + } + } + + } + + // avoid updating balances if we don't need to fetch new data + if len(filteredQuery) == 0 { + return nil + } + + balances, err := st.Store.GetBalances(st.ctx, filteredQuery) if err != nil { return err } diff --git a/internal/interpreter/interpreter_test.go b/internal/interpreter/interpreter_test.go index 9423aa3..8b8dba5 100644 --- a/internal/interpreter/interpreter_test.go +++ b/internal/interpreter/interpreter_test.go @@ -2657,3 +2657,29 @@ func TestCascadingSources(t *testing.T) { } test(t, tc) } + +func TestUseBalanceTwice(t *testing.T) { + tc := NewTestCase() + tc.compile(t, ` + vars { monetary $v = balance(@src, COIN) } + + send $v ( + source = @src + destination = @dest + )`) + + tc.setBalance("src", "COIN", 50) + tc.expected = CaseResult{ + + Postings: []Posting{ + { + Asset: "COIN", + Amount: big.NewInt(50), + Source: "src", + Destination: "dest", + }, + }, + Error: nil, + } + test(t, tc) +} diff --git a/numscript_test.go b/numscript_test.go index 9ff1524..7deb4f5 100644 --- a/numscript_test.go +++ b/numscript_test.go @@ -136,6 +136,86 @@ func TestGetBalancesOverdraft(t *testing.T) { store.GetBalancesCalls) } +func TestDoNotFetchBalanceTwice(t *testing.T) { + parseResult := numscript.Parse(`vars { monetary $v = balance(@src, COIN) } + + send $v ( + source = @src + destination = @dest + )`) + + store := ObservableStore{} + parseResult.Run(context.Background(), nil, &store) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "src": {"COIN"}, + }, + }, + store.GetBalancesCalls, + ) + +} + +func TestDoNotFetchBalanceTwice2(t *testing.T) { + // same test as before, but this time the second batch is not empty + parseResult := numscript.Parse(`vars { monetary $v = balance(@src1, COIN) } + + send $v ( + source = { + @src1 + @src2 + } + destination = @dest + )`) + + store := ObservableStore{} + parseResult.Run(context.Background(), nil, &store) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "src1": {"COIN"}, + }, + { + "src2": {"COIN"}, + }, + }, + store.GetBalancesCalls, + ) + +} + +func TestDoNotFetchBalanceTwice3(t *testing.T) { + // same test as before, but this time the second batch requires a _different asset_ + parseResult := numscript.Parse(`vars { monetary $eur_m = balance(@src, EUR/2) } + + + send [USD/2 100] ( + // note here we are fetching a different currency + source = @src + destination = @dest + ) +`) + + store := ObservableStore{} + parseResult.Run(context.Background(), nil, &store) + + require.Equal(t, + []numscript.BalanceQuery{ + { + "src": {"EUR/2"}, + }, + { + "src": {"USD/2"}, + }, + }, + store.GetBalancesCalls, + ) + +} + func TestQueryBalanceErr(t *testing.T) { parseResult := numscript.Parse(`send [COIN 100] ( source = @src