Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 31 additions & 70 deletions sql/transform/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,39 +112,29 @@ func OneNodeExprsWithNode(n sql.Node, f ExprWithNodeFunc) (sql.Node, TreeIdentit
return n, SameTree, nil
}

sameExprs := SameTree
exprs := ne.Expressions()
if len(exprs) == 0 {
return n, SameTree, nil
}

var (
newExprs []sql.Expression
err error
)

for i := range exprs {
e := exprs[i]
e, same, err := ExprWithNode(n, e, f)
for i, expr := range exprs {
newExpr, same, err := ExprWithNode(n, expr, f)
if err != nil {
return nil, SameTree, err
}
if !same {
if newExprs == nil {
newExprs = make([]sql.Expression, len(exprs))
copy(newExprs, exprs)
}
newExprs[i] = e
exprs[i] = newExpr
sameExprs = NewTree
}
}

if len(newExprs) > 0 {
n, err = ne.WithExpressions(newExprs...)
if err != nil {
return nil, SameTree, err
}
return n, NewTree, nil
if sameExprs {
return n, SameTree, nil
}
return n, SameTree, nil

var err error
n, err = ne.WithExpressions(exprs...)
if err != nil {
return nil, SameTree, err
}
return n, NewTree, nil
}

// OneNodeExpressions applies a transformation function to all expressions
Expand Down Expand Up @@ -307,44 +297,28 @@ func transformUpWithPrefixSchemaHelper(c Context, s SelectorFunc, f CtxFunc) (sq
return node, sameC && sameN, nil
}

// Node applies a transformation function to the given tree from the
// bottom up.
// Node applies a transformation function to the given tree from the bottom up.
func Node(node sql.Node, f NodeFunc) (sql.Node, TreeIdentity, error) {
_, ok := node.(sql.OpaqueNode)
if ok {
return f(node)
}

sameC := SameTree
children := node.Children()
if len(children) == 0 {
return f(node)
}

var (
newChildren []sql.Node
child sql.Node
)

for i := range children {
child = children[i]
child, same, err := Node(child, f)
for i, child := range children {
newChild, same, err := Node(child, f)
if err != nil {
return nil, SameTree, err
}
if !same {
if newChildren == nil {
newChildren = make([]sql.Node, len(children))
copy(newChildren, children)
}
newChildren[i] = child
children[i] = newChild
sameC = NewTree
}
}

var err error
sameC := SameTree
if len(newChildren) > 0 {
sameC = NewTree
node, err = node.WithChildren(newChildren...)
if !sameC {
var err error
node, err = node.WithChildren(children...)
if err != nil {
return nil, SameTree, err
}
Expand All @@ -361,35 +335,22 @@ func Node(node sql.Node, f NodeFunc) (sql.Node, TreeIdentity, error) {
// opaque nodes. This method is generally not safe to use for a transformation. Opaque nodes need to be considered in
// isolation except for very specific exceptions.
func NodeWithOpaque(node sql.Node, f NodeFunc) (sql.Node, TreeIdentity, error) {
sameC := SameTree
children := node.Children()
if len(children) == 0 {
return f(node)
}

var (
newChildren []sql.Node
err error
)

for i := range children {
c := children[i]
c, same, err := NodeWithOpaque(c, f)
for i, child := range children {
newChild, same, err := NodeWithOpaque(child, f)
if err != nil {
return nil, SameTree, err
}
if !same {
if newChildren == nil {
newChildren = make([]sql.Node, len(children))
copy(newChildren, children)
}
newChildren[i] = c
children[i] = newChild
sameC = NewTree
}
}

sameC := SameTree
if len(newChildren) > 0 {
sameC = NewTree
node, err = node.WithChildren(newChildren...)
var err error
if !sameC {
node, err = node.WithChildren(children...)
if err != nil {
return nil, SameTree, err
}
Expand Down
16 changes: 12 additions & 4 deletions sql/transform/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,21 @@ func Inspect(node sql.Node, f func(sql.Node) bool) (cont bool) {
// Avoid allocating []sql.Expression
switch n := node.(type) {
case sql.UnaryNode:
Inspect(n.Child(), f)
if !Inspect(n.Child(), f) {
return false
}
case sql.BinaryNode:
Inspect(n.Left(), f)
Inspect(n.Right(), f)
if !Inspect(n.Left(), f) {
return false
}
if !Inspect(n.Right(), f) {
return false
}
default:
for _, child := range n.Children() {
Inspect(child, f)
if !Inspect(child, f) {
return false
}
}
}
return true
Expand Down
Loading