From c654bf77114ff5fd1020f8d34a46b653669382ee Mon Sep 17 00:00:00 2001 From: Amirsalar Safaei Date: Fri, 30 Aug 2024 17:16:49 +0330 Subject: [PATCH] [Draft] save inlinings --- rego/rego.go | 9 +++++++-- topdown/eval.go | 45 ++++++++++++++++++++++++++++++++++++++++++--- topdown/query.go | 17 +++++++++-------- 3 files changed, 58 insertions(+), 13 deletions(-) diff --git a/rego/rego.go b/rego/rego.go index 266e6d6ab0..61bf82c757 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "os" "strings" "time" @@ -2366,6 +2367,8 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)} } + tracer := topdown.NewBufferTracer() + q := topdown.NewQuery(ectx.compiledQuery.query). WithQueryCompiler(ectx.compiledQuery.compiler). WithCompiler(r.compiler). @@ -2385,7 +2388,9 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). WithStrictBuiltinErrors(ectx.strictBuiltinErrors). WithSeed(ectx.seed). - WithPrintHook(ectx.printHook) + WithPrintHook(ectx.printHook). + WithQueryTracer(tracer) + if !ectx.time.IsZero() { q = q.WithTime(ectx.time) @@ -2467,7 +2472,7 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, support[i].SetRegoVersion(r.regoVersion) } } - + topdown.PrettyTraceWithLocation(os.Stdout, *tracer) pq := &PartialQueries{ Queries: queries, Support: support, diff --git a/topdown/eval.go b/topdown/eval.go index 6263efba64..c155164ac9 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -107,6 +107,17 @@ type eval struct { tracingOpts tracing.Options findOne bool strictObjects bool + inliningCacheList inliningCacheList +} + +type inliningCache struct { + rule *ast.Rule + term *ast.Term + termbindings *bindings +} + +type inliningCacheList struct { + list []inliningCache } func (e *eval) Run(iter evalIterator) error { @@ -160,7 +171,7 @@ func (e *eval) closure(query ast.Body) *eval { return &cpy } -func (e *eval) child(query ast.Body) *eval { +func (e *eval) child(query ast.Body, a ...bool) *eval { cpy := *e cpy.index = 0 cpy.query = query @@ -168,6 +179,9 @@ func (e *eval) child(query ast.Body) *eval { cpy.bindings = newBindings(cpy.queryID, e.instr) cpy.parent = e cpy.findOne = false + if len(a) == 0 { + cpy.inliningCacheList = inliningCacheList{} + } return &cpy } @@ -3003,7 +3017,11 @@ func (e evalVirtualComplete) eval(iter unifyIterator) error { return e.partialEvalSupport(iter) } - return e.partialEval(iter) + cpy := e + if e.e.parent != nil { + cpy.e = e.e.parent + } + return cpy.partialEval(iter) } func (e evalVirtualComplete) evalValue(iter unifyIterator, findOne bool) error { @@ -3115,13 +3133,34 @@ func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, p func (e evalVirtualComplete) partialEval(iter unifyIterator) error { for _, rule := range e.ir.Rules { + found := false + + for _ , cache := range e.e.inliningCacheList.list { + if cache.rule.Equal(rule) { + found = true + err := e.evalTerm(iter, cache.term, cache.termbindings) + if err != nil { + return err + } + } + } + + if found { + fmt.Printf("skipped %s\n", rule.Ref().String()) + continue + } + child := e.e.child(rule.Body) child.traceEnter(rule) - err := child.eval(func(child *eval) error { child.traceExit(rule) term, termbindings := child.bindings.apply(rule.Head.Value) + e.e.inliningCacheList.list = append(e.e.inliningCacheList.list, inliningCache{ + rule: rule, + term: term, + termbindings: termbindings, + }) err := e.evalTerm(iter, term, termbindings) if err != nil { return err diff --git a/topdown/query.go b/topdown/query.go index 2b540c58a5..1569798760 100644 --- a/topdown/query.go +++ b/topdown/query.go @@ -350,13 +350,14 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support [] inliningControl: &inliningControl{ shallow: q.shallowInlining, }, - genvarprefix: q.genvarprefix, - runtime: q.runtime, - indexing: q.indexing, - earlyExit: q.earlyExit, - builtinErrors: &builtinErrors{}, - printHook: q.printHook, - strictObjects: q.strictObjects, + genvarprefix: q.genvarprefix, + runtime: q.runtime, + indexing: q.indexing, + earlyExit: q.earlyExit, + builtinErrors: &builtinErrors{}, + printHook: q.printHook, + strictObjects: q.strictObjects, + inliningCacheList: inliningCacheList{}, } if len(q.disableInlining) > 0 { @@ -387,7 +388,6 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support [] p := copypropagation.New(livevars).WithCompiler(q.compiler) err = e.Run(func(e *eval) error { - // Build output from saved expressions. body := ast.NewBody() @@ -526,6 +526,7 @@ func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error { printHook: q.printHook, tracingOpts: q.tracingOpts, strictObjects: q.strictObjects, + inliningCacheList: inliningCacheList{}, } e.caller = e q.metrics.Timer(metrics.RegoQueryEval).Start()