Skip to content

Commit 1f53611

Browse files
authored
feat: support scoped <bind> tags in dynamic SQL nodes (#526)
- Enable <bind> tags to be nested within dynamic SQL elements (if, where, trim, foreach, choose, when, otherwise, sql). - Implement BindNodeGroup to handle local variable injection for dynamic nodes. - Refactor Node.Accept methods to apply local binds to the parameter context. - Remove BindNodes() from Statement interface as binding is now handled internally.
1 parent de9dcec commit 1f53611

File tree

4 files changed

+134
-50
lines changed

4 files changed

+134
-50
lines changed

node.go

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,9 @@ var ErrNilExpression = errors.New("juice: nil expression")
287287
// ConditionNode represents a conditional SQL fragment with its evaluation expression and child nodes.
288288
// It is used to conditionally include or exclude SQL fragments based on runtime parameters.
289289
type ConditionNode struct {
290-
expr eval.Expression
291-
Nodes NodeGroup
290+
expr eval.Expression
291+
Nodes NodeGroup
292+
BindNodes BindNodeGroup
292293
}
293294

294295
// Parse compiles the given expression string into an evaluable expression.
@@ -312,13 +313,16 @@ func (c *ConditionNode) Parse(test string) (err error) {
312313
// Accept accepts parameters and returns query and arguments.
313314
// Accept implements Node interface.
314315
func (c *ConditionNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
316+
p = c.BindNodes.ConvertParameter(p)
317+
315318
matched, err := c.Match(p)
316319
if err != nil {
317320
return "", nil, err
318321
}
319322
if !matched {
320323
return "", nil, nil
321324
}
325+
322326
return c.Nodes.Accept(translator, p)
323327
}
324328

@@ -332,6 +336,7 @@ func (c *ConditionNode) Match(p eval.Parameter) (bool, error) {
332336
if c.expr == nil {
333337
return false, ErrNilExpression
334338
}
339+
335340
value, err := c.expr.Execute(p)
336341
if err != nil {
337342
return false, err
@@ -363,7 +368,8 @@ var _ Node = (*IfNode)(nil)
363368
// WhereNode represents a SQL WHERE clause and its conditions.
364369
// It manages a group of condition nodes that form the complete WHERE clause.
365370
type WhereNode struct {
366-
Nodes NodeGroup
371+
Nodes NodeGroup
372+
BindNodes BindNodeGroup
367373
}
368374

369375
// Accept processes the WHERE clause and its conditions.
@@ -379,6 +385,8 @@ type WhereNode struct {
379385
// Input: "WHERE age > ?" -> Output: "WHERE age > ?"
380386
// Input: "status = ?" -> Output: "WHERE status = ?"
381387
func (w WhereNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
388+
p = w.BindNodes.ConvertParameter(p)
389+
382390
query, args, err = w.Nodes.Accept(translator, p)
383391
if err != nil {
384392
return "", nil, err
@@ -441,10 +449,13 @@ type TrimNode struct {
441449
PrefixOverrides []string
442450
Suffix string
443451
SuffixOverrides []string
452+
BindNodes BindNodeGroup
444453
}
445454

446455
// Accept accepts parameters and returns query and arguments.
447456
func (t TrimNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
457+
p = t.BindNodes.ConvertParameter(p)
458+
448459
query, args, err = t.Nodes.Accept(translator, p)
449460
if err != nil {
450461
return "", nil, err
@@ -541,10 +552,12 @@ type ForeachNode struct {
541552
Open string
542553
Close string
543554
Separator string
555+
BindNodes BindNodeGroup
544556
}
545557

546558
// Accept accepts parameters and returns query and arguments.
547559
func (f ForeachNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
560+
p = f.BindNodes.ConvertParameter(p)
548561

549562
// if item already exists
550563
if _, exists := p.Get(f.Item); exists {
@@ -752,18 +765,22 @@ var _ Node = (*ForeachNode)(nil)
752765
// proper formatting of the SET clause regardless of which fields
753766
// are included dynamically.
754767
type SetNode struct {
755-
Nodes NodeGroup
768+
Nodes NodeGroup
769+
BindNodes BindNodeGroup
756770
}
757771

758772
// Accept accepts parameters and returns query and arguments.
759773
func (s SetNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
774+
p = s.BindNodes.ConvertParameter(p)
775+
760776
query, args, err = s.Nodes.Accept(translator, p)
761777
if err != nil {
762778
return "", nil, err
763779
}
764780
if len(query) == 0 {
765781
return "", args, nil
766782
}
783+
767784
// Remove trailing comma
768785
query = strings.TrimSuffix(query, ",")
769786

@@ -814,8 +831,9 @@ var _ Node = (*SetNode)(nil)
814831
// Note: The id must be unique within its mapper context to allow
815832
// proper statement lookup and execution.
816833
type SQLNode struct {
817-
id string // Unique identifier for the SQL statement
818-
nodes NodeGroup // Child nodes forming the SQL statement
834+
id string // Unique identifier for the SQL statement
835+
nodes NodeGroup // Child nodes forming the SQL statement
836+
BindNodes BindNodeGroup
819837
}
820838

821839
// ID returns the id of the node.
@@ -825,6 +843,8 @@ func (s SQLNode) ID() string {
825843

826844
// Accept accepts parameters and returns query and arguments.
827845
func (s SQLNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
846+
p = s.BindNodes.ConvertParameter(p)
847+
828848
return s.nodes.Accept(translator, p)
829849
}
830850

@@ -886,6 +906,7 @@ func (i *IncludeNode) Accept(translator driver.Translator, p eval.Parameter) (qu
886906
}
887907
i.sqlNode = sqlNode
888908
}
909+
889910
return i.sqlNode.Accept(translator, p)
890911
}
891912

@@ -939,10 +960,13 @@ var _ Node = (*IncludeNode)(nil)
939960
type ChooseNode struct {
940961
WhenNodes []Node
941962
OtherwiseNode Node
963+
BindNodes BindNodeGroup
942964
}
943965

944966
// Accept accepts parameters and returns query and arguments.
945967
func (c ChooseNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
968+
p = c.BindNodes.ConvertParameter(p)
969+
946970
for _, node := range c.WhenNodes {
947971
q, a, err := node.Accept(translator, p)
948972
if err != nil {
@@ -953,6 +977,7 @@ func (c ChooseNode) Accept(translator driver.Translator, p eval.Parameter) (quer
953977
return q, a, nil
954978
}
955979
}
980+
956981
// if all when nodes are false, return otherwise node
957982
if c.OtherwiseNode != nil {
958983
return c.OtherwiseNode.Accept(translator, p)
@@ -1041,11 +1066,14 @@ var _ Node = (*WhenNode)(nil)
10411066
// Note: Unlike WhenNode, OtherwiseNode doesn't evaluate any conditions.
10421067
// It simply provides default SQL fragments when needed.
10431068
type OtherwiseNode struct {
1044-
Nodes NodeGroup
1069+
Nodes NodeGroup
1070+
BindNodes BindNodeGroup
10451071
}
10461072

10471073
// Accept accepts parameters and returns query and arguments.
10481074
func (o OtherwiseNode) Accept(translator driver.Translator, p eval.Parameter) (query string, args []any, err error) {
1075+
p = o.BindNodes.ConvertParameter(p)
1076+
10491077
return o.Nodes.Accept(translator, p)
10501078
}
10511079

@@ -1100,6 +1128,33 @@ func (b *BindNode) Execute(p eval.Parameter) (reflect.Value, error) {
11001128
return value, nil
11011129
}
11021130

1131+
type BindNodeGroup []*BindNode
1132+
1133+
func (b BindNodeGroup) ConvertParameter(parameter eval.Parameter) eval.Parameter {
1134+
if len(b) == 0 {
1135+
return parameter
1136+
}
1137+
// decorate the parameter with boundParameterDecorator
1138+
// to provide binding scope for bind variables
1139+
boundParam := &boundParameterDecorator{
1140+
scope: &bindScope{
1141+
nodes: b,
1142+
parameter: parameter,
1143+
},
1144+
}
1145+
1146+
parameter = eval.ParamGroup{
1147+
boundParam,
1148+
parameter,
1149+
}
1150+
// another approach is to use ParamGroup to combine boundParam and parameter
1151+
// but the order matters here.
1152+
// if we put boundParam after parameter, the boundParam will have lower priority
1153+
// than the original parameter, which is not what we want.
1154+
// so we put boundParam before parameter.
1155+
return parameter
1156+
}
1157+
11031158
// ErrBindVariableNotFound is returned when a bind variable lookup fails.
11041159
var ErrBindVariableNotFound = errors.New("juice: bind variable not found")
11051160

param.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,6 @@ func buildStatementParameters(param any, statement Statement, driverName string,
2929
eval.PrefixPatternParameter("_parameter", param),
3030
}
3131

32-
if bindNodes := statement.BindNodes(); len(bindNodes) > 0 {
33-
// decorate the parameter with boundParameterDecorator
34-
// to provide binding scope for bind variables
35-
boundParam := &boundParameterDecorator{
36-
scope: &bindScope{
37-
nodes: bindNodes,
38-
parameter: parameter,
39-
},
40-
}
41-
42-
boundParameter := make(eval.ParamGroup, 0, len(parameter)+1)
43-
boundParameter = append(boundParameter, boundParam)
44-
parameter = append(boundParameter, parameter...)
45-
46-
// another approach is to use ParamGroup to combine boundParam and parameter
47-
// but the order matters here.
48-
// if we put boundParam after parameter, the boundParam will have lower priority
49-
// than the original parameter, which is not what we want.
50-
// so we put boundParam before parameter.
51-
}
52-
5332
return parameter
5433
}
5534

0 commit comments

Comments
 (0)