From 5ca790c888d46d3d9fcb51997e1c7460270a088e Mon Sep 17 00:00:00 2001 From: auxten Date: Wed, 28 Jul 2021 11:52:02 +0800 Subject: [PATCH] Optimize ast walker --- example/walk.go | 8 +++++++- pkg/walk/walker.go | 36 ++++++++++++++++++++++++------------ pkg/walk/walker_test.go | 2 +- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/example/walk.go b/example/walk.go index 64e6f08..7e0fa8a 100644 --- a/example/walk.go +++ b/example/walk.go @@ -3,6 +3,7 @@ package main import ( "log" + "github.com/auxten/postgresql-parser/pkg/sql/parser" "github.com/auxten/postgresql-parser/pkg/walk" ) @@ -19,6 +20,11 @@ func main() { return false }, } - _, _ = w.Walk(sql, nil) + stmts, err := parser.Parse(sql) + if err != nil { + return + } + + _, _ = w.Walk(stmts, nil) return } diff --git a/pkg/walk/walker.go b/pkg/walk/walker.go index 6c506b6..9c8e151 100644 --- a/pkg/walk/walker.go +++ b/pkg/walk/walker.go @@ -27,11 +27,7 @@ func (rc ReferredCols) ToList() []string { return cols } -func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { - stmts, err := parser.Parse(sql) - if err != nil { - return false, err - } +func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err error) { w.unknownNodes = make([]interface{}, 0) asts := make([]tree.NodeFormatter, len(stmts)) @@ -67,6 +63,8 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { walk(node.Expr) case *tree.Array: walk(node.Exprs) + case tree.AsOfClause: + walk(node.Expr) case *tree.BinaryExpr: walk(node.Left, node.Right) case *tree.CaseExpr: @@ -127,7 +125,6 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { if node.With != nil { walk(node.With) } - walk(node.Select) if node.OrderBy != nil { for _, order := range node.OrderBy { walk(order) @@ -136,6 +133,7 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { if node.Limit != nil { walk(node.Limit) } + walk(node.Select) case *tree.Order: walk(node.Expr, node.Table) case *tree.Limit: @@ -148,9 +146,6 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { if node.Having != nil { walk(node.Having) } - for _, table := range node.From.Tables { - walk(table) - } if node.DistinctOn != nil { for _, distinct := range node.DistinctOn { walk(distinct) @@ -161,6 +156,10 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { walk(group) } } + walk(node.From.AsOf) + for _, table := range node.From.Tables { + walk(table) + } case tree.SelectExpr: walk(node.Expr) case tree.SelectExprs: @@ -192,6 +191,10 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { } case *tree.Where: walk(node.Expr) + case tree.Window: + for _, windowDef := range node { + walk(windowDef) + } case *tree.WindowDef: walk(node.Partitions) if node.Frame != nil { @@ -206,13 +209,14 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) { } case *tree.WindowFrameBound: walk(node.OffsetExpr) - case *tree.Window: case *tree.With: for _, expr := range node.CTEList { walk(expr) } default: - w.unknownNodes = append(w.unknownNodes, node) + if w.unknownNodes != nil { + w.unknownNodes = append(w.unknownNodes, node) + } } } } @@ -257,7 +261,15 @@ func ColNamesInSelect(sql string) (referredCols ReferredCols, err error) { return false }, } - _, err = w.Walk(sql, referredCols) + stmts, err := parser.Parse(sql) + if err != nil { + return + } + + _, err = w.Walk(stmts, referredCols) + if err != nil { + return + } for _, col := range w.unknownNodes { log.Printf("unhandled column type %T", col) } diff --git a/pkg/walk/walker_test.go b/pkg/walk/walker_test.go index 8c236ad..5561bde 100644 --- a/pkg/walk/walker_test.go +++ b/pkg/walk/walker_test.go @@ -195,7 +195,7 @@ func TestReferredVarsInSelectStatement(t *testing.T) { referredCols, err := func() (ReferredCols, error) { return ColNamesInSelect(tc.sql) }() - if err.Error() != tc.err.Error() { + if err != nil && err.Error() != tc.err.Error() { t.Errorf("Expect %s, got %s", tc.err, err) } cols := referredCols.ToList()