Skip to content

Commit

Permalink
Fix walk to support AST struct member walk
Browse files Browse the repository at this point in the history
  • Loading branch information
auxten committed Sep 5, 2021
1 parent ff6ce86 commit ec47987
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 40 deletions.
23 changes: 23 additions & 0 deletions pkg/util/set/dedup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package set

import (
"sort"
)

func SortDeDup(l []string) []string {
n := len(l)
if n <= 1 {
return l
}
sort.Strings(l)

j := 1
for i := 1; i < n; i++ {
if l[i] != l[i-1] {
l[j] = l[i]
j++
}
}

return l[0:j]
}
49 changes: 49 additions & 0 deletions pkg/util/set/dedup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package set

import (
"fmt"
"testing"
)

func TestSortDeDup(t *testing.T) {
{
l := []string{"c", "a", "b", "c", "e", "d"}
expected := []string{"a", "b", "c", "d", "e"}
sl := SortDeDup(l)
if fmt.Sprint(sl) != fmt.Sprint(expected) {
t.Errorf("%v should be equal %v", sl, expected)
}
}
{
l := []string{"mj.mc_trscode", "mj.rowid", "mj.tans_amt", "mj.trans_bran_code", "mj.trans_date", "mj.trans_flag", "trans_date", "mj.trans_flag", "trans_date"}
expected := []string{"mj.mc_trscode", "mj.rowid", "mj.tans_amt", "mj.trans_bran_code", "mj.trans_date", "mj.trans_flag", "trans_date"}
sl := SortDeDup(l)
if fmt.Sprint(sl) != fmt.Sprint(expected) {
t.Errorf("%v should be equal %v", sl, expected)
}
}
{
l := []string{"c"}
expected := []string{"c"}
sl := SortDeDup(l)
if fmt.Sprint(sl) != fmt.Sprint(expected) {
t.Errorf("%v should be equal %v", sl, expected)
}
}
{
l := make([]string, 0)
expected := []string{}
sl := SortDeDup(l)
if fmt.Sprint(sl) != fmt.Sprint(expected) {
t.Errorf("%v should be equal %v", sl, expected)
}
}
{
var l []string
expected := []string(nil)
sl := SortDeDup(l)
if fmt.Sprint(sl) != fmt.Sprint(expected) {
t.Errorf("%v should be equal %v", sl, expected)
}
}
}
66 changes: 46 additions & 20 deletions pkg/walk/walker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,32 @@ package walk
import (
"fmt"
"log"
"sort"
"strings"

"github.com/auxten/postgresql-parser/pkg/sql/parser"
"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
"github.com/auxten/postgresql-parser/pkg/util/set"
)

type AstWalker struct {
unknownNodes []interface{}
UnknownNodes []interface{}
Fn func(ctx interface{}, node interface{}) (stop bool)
}
type ReferredCols map[string]int

func (rc ReferredCols) ToList() []string {
cols := make([]string, len(rc))
i := 0
for k, _ := range rc {
for k := range rc {
cols[i] = k
i++
}
sort.Strings(cols)
return cols
return set.SortDeDup(cols)
}

func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err error) {

w.unknownNodes = make([]interface{}, 0)
w.UnknownNodes = make([]interface{}, 0)
asts := make([]tree.NodeFormatter, len(stmts))
for si, stmt := range stmts {
asts[si] = stmt.AST
Expand Down Expand Up @@ -69,8 +68,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
walk(node.Left, node.Right)
case *tree.CaseExpr:
walk(node.Expr, node.Else)
for _, w := range node.Whens {
walk(w.Cond, w.Val)
for _, when := range node.Whens {
walk(when.Cond, when.Val)
}
case *tree.RangeCond:
walk(node.Left, node.From, node.To)
Expand Down Expand Up @@ -98,6 +97,11 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
walk(expr)
}
case *tree.FamilyTableDef:
case *tree.From:
walk(node.AsOf)
for _, table := range node.Tables {
walk(table)
}
case *tree.FuncExpr:
if node.WindowDef != nil {
walk(node.WindowDef)
Expand All @@ -111,6 +115,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
case *tree.NumVal:
case *tree.OnJoinCond:
walk(node.Expr)
case *tree.Order:
walk(node.Expr, node.Table)
case tree.OrderBy:
for _, order := range node {
walk(order)
}
case *tree.OrExpr:
walk(node.Left, node.Right)
case *tree.ParenExpr:
Expand All @@ -126,16 +136,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
walk(node.With)
}
if node.OrderBy != nil {
for _, order := range node.OrderBy {
walk(order)
}
walk(node.OrderBy)
}
if node.Limit != nil {
walk(node.Limit)
}
walk(node.Select)
case *tree.Order:
walk(node.Expr, node.Table)
case *tree.Limit:
walk(node.Count)
case *tree.SelectClause:
Expand All @@ -156,10 +162,7 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
walk(group)
}
}
walk(node.From.AsOf)
for _, table := range node.From.Tables {
walk(table)
}
walk(&node.From)
case tree.SelectExpr:
walk(node.Expr)
case tree.SelectExprs:
Expand All @@ -173,6 +176,10 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
case *tree.StrVal:
case *tree.Subquery:
walk(node.Select)
case tree.TableExprs:
for _, expr := range node {
walk(expr)
}
case *tree.TableName, tree.TableName:
case *tree.Tuple:
for _, expr := range node.Exprs {
Expand Down Expand Up @@ -214,8 +221,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
walk(expr)
}
default:
if w.unknownNodes != nil {
w.unknownNodes = append(w.unknownNodes, node)
if w.UnknownNodes != nil {
w.UnknownNodes = append(w.UnknownNodes, node)
}
}
}
Expand Down Expand Up @@ -270,8 +277,27 @@ func ColNamesInSelect(sql string) (referredCols ReferredCols, err error) {
if err != nil {
return
}
for _, col := range w.unknownNodes {
for _, col := range w.UnknownNodes {
log.Printf("unhandled column type %T", col)
}
return
}

func AllColsContained(set ReferredCols, cols []string) bool {
if cols == nil {
if set == nil {
return true
} else {
return false
}
}
if len(set) != len(cols) {
return false
}
for _, col := range cols {
if _, exist := set[col]; !exist {
return false
}
}
return true
}
21 changes: 1 addition & 20 deletions pkg/walk/walker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,6 @@ func TestParser(t *testing.T) {
})
}

func allColsContained(set ReferredCols, cols []string) bool {
if cols == nil {
if set == nil {
return true
} else {
return false
}
}
if len(set) != len(cols) {
return false
}
for _, col := range cols {
if _, exist := set[col]; !exist {
return false
}
}
return true
}

func TestReferredVarsInSelectStatement(t *testing.T) {
testCases := []struct {
sql string
Expand Down Expand Up @@ -191,7 +172,7 @@ func TestReferredVarsInSelectStatement(t *testing.T) {
}

for _, tc := range testCases {
t.Run(tc.sql, func(t *testing.T) {
t.Run(tc.sql, func(t *testing.T) {
referredCols, err := func() (ReferredCols, error) {
return ColNamesInSelect(tc.sql)
}()
Expand Down

0 comments on commit ec47987

Please sign in to comment.