@@ -287,8 +287,9 @@ var ErrNilExpression = errors.New("juice: nil expression")
287287// ConditionNode represents a conditional SQL fragment with its evaluation expression and child nodes.
288288// It is used to conditionally include or exclude SQL fragments based on runtime parameters.
289289type ConditionNode struct {
290- expr eval.Expression
291- Nodes NodeGroup
290+ expr eval.Expression
291+ Nodes NodeGroup
292+ BindNodes BindNodeGroup
292293}
293294
294295// Parse compiles the given expression string into an evaluable expression.
@@ -312,13 +313,16 @@ func (c *ConditionNode) Parse(test string) (err error) {
312313// Accept accepts parameters and returns query and arguments.
313314// Accept implements Node interface.
314315func (c * ConditionNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
316+ p = c .BindNodes .ConvertParameter (p )
317+
315318 matched , err := c .Match (p )
316319 if err != nil {
317320 return "" , nil , err
318321 }
319322 if ! matched {
320323 return "" , nil , nil
321324 }
325+
322326 return c .Nodes .Accept (translator , p )
323327}
324328
@@ -332,6 +336,7 @@ func (c *ConditionNode) Match(p eval.Parameter) (bool, error) {
332336 if c .expr == nil {
333337 return false , ErrNilExpression
334338 }
339+
335340 value , err := c .expr .Execute (p )
336341 if err != nil {
337342 return false , err
@@ -363,7 +368,8 @@ var _ Node = (*IfNode)(nil)
363368// WhereNode represents a SQL WHERE clause and its conditions.
364369// It manages a group of condition nodes that form the complete WHERE clause.
365370type WhereNode struct {
366- Nodes NodeGroup
371+ Nodes NodeGroup
372+ BindNodes BindNodeGroup
367373}
368374
369375// Accept processes the WHERE clause and its conditions.
@@ -379,6 +385,8 @@ type WhereNode struct {
379385// Input: "WHERE age > ?" -> Output: "WHERE age > ?"
380386// Input: "status = ?" -> Output: "WHERE status = ?"
381387func (w WhereNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
388+ p = w .BindNodes .ConvertParameter (p )
389+
382390 query , args , err = w .Nodes .Accept (translator , p )
383391 if err != nil {
384392 return "" , nil , err
@@ -441,10 +449,13 @@ type TrimNode struct {
441449 PrefixOverrides []string
442450 Suffix string
443451 SuffixOverrides []string
452+ BindNodes BindNodeGroup
444453}
445454
446455// Accept accepts parameters and returns query and arguments.
447456func (t TrimNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
457+ p = t .BindNodes .ConvertParameter (p )
458+
448459 query , args , err = t .Nodes .Accept (translator , p )
449460 if err != nil {
450461 return "" , nil , err
@@ -541,10 +552,12 @@ type ForeachNode struct {
541552 Open string
542553 Close string
543554 Separator string
555+ BindNodes BindNodeGroup
544556}
545557
546558// Accept accepts parameters and returns query and arguments.
547559func (f ForeachNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
560+ p = f .BindNodes .ConvertParameter (p )
548561
549562 // if item already exists
550563 if _ , exists := p .Get (f .Item ); exists {
@@ -752,18 +765,22 @@ var _ Node = (*ForeachNode)(nil)
752765// proper formatting of the SET clause regardless of which fields
753766// are included dynamically.
754767type SetNode struct {
755- Nodes NodeGroup
768+ Nodes NodeGroup
769+ BindNodes BindNodeGroup
756770}
757771
758772// Accept accepts parameters and returns query and arguments.
759773func (s SetNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
774+ p = s .BindNodes .ConvertParameter (p )
775+
760776 query , args , err = s .Nodes .Accept (translator , p )
761777 if err != nil {
762778 return "" , nil , err
763779 }
764780 if len (query ) == 0 {
765781 return "" , args , nil
766782 }
783+
767784 // Remove trailing comma
768785 query = strings .TrimSuffix (query , "," )
769786
@@ -814,8 +831,9 @@ var _ Node = (*SetNode)(nil)
814831// Note: The id must be unique within its mapper context to allow
815832// proper statement lookup and execution.
816833type SQLNode struct {
817- id string // Unique identifier for the SQL statement
818- nodes NodeGroup // Child nodes forming the SQL statement
834+ id string // Unique identifier for the SQL statement
835+ nodes NodeGroup // Child nodes forming the SQL statement
836+ BindNodes BindNodeGroup
819837}
820838
821839// ID returns the id of the node.
@@ -825,6 +843,8 @@ func (s SQLNode) ID() string {
825843
826844// Accept accepts parameters and returns query and arguments.
827845func (s SQLNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
846+ p = s .BindNodes .ConvertParameter (p )
847+
828848 return s .nodes .Accept (translator , p )
829849}
830850
@@ -886,6 +906,7 @@ func (i *IncludeNode) Accept(translator driver.Translator, p eval.Parameter) (qu
886906 }
887907 i .sqlNode = sqlNode
888908 }
909+
889910 return i .sqlNode .Accept (translator , p )
890911}
891912
@@ -939,10 +960,13 @@ var _ Node = (*IncludeNode)(nil)
939960type ChooseNode struct {
940961 WhenNodes []Node
941962 OtherwiseNode Node
963+ BindNodes BindNodeGroup
942964}
943965
944966// Accept accepts parameters and returns query and arguments.
945967func (c ChooseNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
968+ p = c .BindNodes .ConvertParameter (p )
969+
946970 for _ , node := range c .WhenNodes {
947971 q , a , err := node .Accept (translator , p )
948972 if err != nil {
@@ -953,6 +977,7 @@ func (c ChooseNode) Accept(translator driver.Translator, p eval.Parameter) (quer
953977 return q , a , nil
954978 }
955979 }
980+
956981 // if all when nodes are false, return otherwise node
957982 if c .OtherwiseNode != nil {
958983 return c .OtherwiseNode .Accept (translator , p )
@@ -1041,11 +1066,14 @@ var _ Node = (*WhenNode)(nil)
10411066// Note: Unlike WhenNode, OtherwiseNode doesn't evaluate any conditions.
10421067// It simply provides default SQL fragments when needed.
10431068type OtherwiseNode struct {
1044- Nodes NodeGroup
1069+ Nodes NodeGroup
1070+ BindNodes BindNodeGroup
10451071}
10461072
10471073// Accept accepts parameters and returns query and arguments.
10481074func (o OtherwiseNode ) Accept (translator driver.Translator , p eval.Parameter ) (query string , args []any , err error ) {
1075+ p = o .BindNodes .ConvertParameter (p )
1076+
10491077 return o .Nodes .Accept (translator , p )
10501078}
10511079
@@ -1100,6 +1128,33 @@ func (b *BindNode) Execute(p eval.Parameter) (reflect.Value, error) {
11001128 return value , nil
11011129}
11021130
1131+ type BindNodeGroup []* BindNode
1132+
1133+ func (b BindNodeGroup ) ConvertParameter (parameter eval.Parameter ) eval.Parameter {
1134+ if len (b ) == 0 {
1135+ return parameter
1136+ }
1137+ // decorate the parameter with boundParameterDecorator
1138+ // to provide binding scope for bind variables
1139+ boundParam := & boundParameterDecorator {
1140+ scope : & bindScope {
1141+ nodes : b ,
1142+ parameter : parameter ,
1143+ },
1144+ }
1145+
1146+ parameter = eval.ParamGroup {
1147+ boundParam ,
1148+ parameter ,
1149+ }
1150+ // another approach is to use ParamGroup to combine boundParam and parameter
1151+ // but the order matters here.
1152+ // if we put boundParam after parameter, the boundParam will have lower priority
1153+ // than the original parameter, which is not what we want.
1154+ // so we put boundParam before parameter.
1155+ return parameter
1156+ }
1157+
11031158// ErrBindVariableNotFound is returned when a bind variable lookup fails.
11041159var ErrBindVariableNotFound = errors .New ("juice: bind variable not found" )
11051160
0 commit comments