3030 ErrQueryArgTODO = errors .New ("TODO: support this type" )
3131)
3232
33+ const (
34+ sqlxLib = "github.com/jmoiron/sqlx"
35+ dbSqlLib = "database/sql"
36+ gormLib = "github.com/jinzhu/gorm"
37+ goGorpLib = "go-gorp/gorp"
38+ gorpV1Lib = "gopkg.in/gorp.v1"
39+
40+ queryArgName = "query"
41+ sqlArgName = "sql"
42+
43+ rebindMethodName = "Rebind"
44+ rebindxMethodName = "Rebindx"
45+ )
46+
3347type QuerySite struct {
3448 Called string
3549 Position token.Position
@@ -100,10 +114,17 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
100114 sqlfuncs := []MatchedSqlFunc {}
101115
102116 s .IterPackageExportedFuncs (func (fobj * types.Func ) {
117+ ssaFunc := prog .FuncValue (fobj )
118+
119+ // Skip pass-through functions that shouldn't be validated as SQL functions
120+ if isPassThroughFunc (ssaFunc ) {
121+ return
122+ }
123+
103124 for _ , rule := range s .Rules {
104125 if rule .FuncName != "" && fobj .Name () == rule .FuncName {
105126 sqlfuncs = append (sqlfuncs , MatchedSqlFunc {
106- SSA : prog . FuncValue ( fobj ) ,
127+ SSA : ssaFunc ,
107128 QueryArgPos : rule .QueryArgPos ,
108129 })
109130 // callable matched one rule, no need to go through the rest
@@ -120,7 +141,7 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
120141 continue
121142 }
122143 sqlfuncs = append (sqlfuncs , MatchedSqlFunc {
123- SSA : prog . FuncValue ( fobj ) ,
144+ SSA : ssaFunc ,
124145 QueryArgPos : rule .QueryArgPos ,
125146 })
126147 // callable matched one rule, no need to go through the rest
@@ -132,13 +153,29 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
132153 return sqlfuncs
133154}
134155
156+ // isNamedQueryFunc checks if a function name is a "named query" function
157+ // that expects named parameters (like :param) instead of positional ($1, $2)
158+ func isNamedQueryFunc (funcName string ) bool {
159+ // Check for sqlx named query functions
160+ switch funcName {
161+ case "NamedExec" , "NamedQuery" , "NamedExecContext" , "NamedQueryContext" ,
162+ "NamedQueryRow" , "NamedQueryRowContext" :
163+ return true
164+ }
165+ // Also check if the function name contains "Named" (catches custom wrappers)
166+ return strings .Contains (funcName , "Named" )
167+ }
168+
135169func handleQuery (ctx VetContext , qs * QuerySite ) {
136- // TODO: apply named query resolution based on v.X type and v.Sel.Name
137- // e.g. for sqlx, only apply to NamedExec and NamedQuery
138- qs .Query , _ , qs .Err = parseutil .CompileNamedQuery (
139- []byte (qs .Query ), parseutil .BindType ("postgres" ))
140- if qs .Err != nil {
141- return
170+ // Only apply named query resolution for named query functions
171+ // (e.g., NamedExec, NamedQuery, NamedExecContext, NamedQueryContext)
172+ // to avoid breaking PostgreSQL type casts (::) in regular queries
173+ if isNamedQueryFunc (qs .Called ) {
174+ qs .Query , _ , qs .Err = parseutil .CompileNamedQuery (
175+ []byte (qs .Query ), parseutil .BindType ("postgres" ))
176+ if qs .Err != nil {
177+ return
178+ }
142179 }
143180
144181 var queryParams []QueryParam
@@ -160,31 +197,31 @@ func handleQuery(ctx VetContext, qs *QuerySite) {
160197func getMatchers (extraMatchers []SqlFuncMatcher ) []* SqlFuncMatcher {
161198 matchers := []* SqlFuncMatcher {
162199 {
163- PkgPath : "github.com/jmoiron/sqlx" ,
200+ PkgPath : sqlxLib ,
164201 Rules : []SqlFuncMatchRule {
165- {QueryArgName : "query" },
166- {QueryArgName : "sql" },
202+ {QueryArgName : queryArgName },
203+ {QueryArgName : sqlArgName },
167204 // for methods with Context suffix
168- {QueryArgName : "query" , QueryArgPos : 1 },
169- {QueryArgName : "sql" , QueryArgPos : 1 },
170- {QueryArgName : "query" , QueryArgPos : 2 },
171- {QueryArgName : "sql" , QueryArgPos : 2 },
205+ {QueryArgName : queryArgName , QueryArgPos : 1 },
206+ {QueryArgName : sqlArgName , QueryArgPos : 1 },
207+ {QueryArgName : queryArgName , QueryArgPos : 2 },
208+ {QueryArgName : sqlArgName , QueryArgPos : 2 },
172209 },
173210 },
174211 {
175- PkgPath : "database/sql" ,
212+ PkgPath : dbSqlLib ,
176213 Rules : []SqlFuncMatchRule {
177- {QueryArgName : "query" },
178- {QueryArgName : "sql" },
214+ {QueryArgName : queryArgName },
215+ {QueryArgName : sqlArgName },
179216 // for methods with Context suffix
180- {QueryArgName : "query" , QueryArgPos : 1 },
181- {QueryArgName : "sql" , QueryArgPos : 1 },
217+ {QueryArgName : queryArgName , QueryArgPos : 1 },
218+ {QueryArgName : sqlArgName , QueryArgPos : 1 },
182219 },
183220 },
184221 {
185- PkgPath : "github.com/jinzhu/gorm" ,
222+ PkgPath : gormLib ,
186223 Rules : []SqlFuncMatchRule {
187- {QueryArgName : "sql" },
224+ {QueryArgName : sqlArgName },
188225 },
189226 },
190227 // TODO: xorm uses vararg, which is not supported yet
@@ -201,15 +238,15 @@ func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher {
201238 // },
202239 // },
203240 {
204- PkgPath : "go-gorp/gorp" ,
241+ PkgPath : goGorpLib ,
205242 Rules : []SqlFuncMatchRule {
206- {QueryArgName : "query" },
243+ {QueryArgName : queryArgName },
207244 },
208245 },
209246 {
210- PkgPath : "gopkg.in/gorp.v1" ,
247+ PkgPath : gorpV1Lib ,
211248 Rules : []SqlFuncMatchRule {
212- {QueryArgName : "query" },
249+ {QueryArgName : queryArgName },
213250 },
214251 },
215252 }
@@ -240,7 +277,7 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error)
240277 }
241278 dirAbs , err := filepath .Abs (dir )
242279 if err != nil {
243- return nil , fmt .Errorf ("Invalid path: %w" , err )
280+ return nil , fmt .Errorf ("invalid path: %w" , err )
244281 }
245282 pkgPath := dirAbs + "/..."
246283 pkgs , err := packages .Load (cfg , pkgPath )
@@ -250,12 +287,54 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error)
250287 // return early if any syntax error
251288 for _ , pkg := range pkgs {
252289 if len (pkg .Errors ) > 0 {
253- return nil , fmt .Errorf ("Failed to load package, %w" , pkg .Errors [0 ])
290+ return nil , fmt .Errorf ("failed to load package, %w" , pkg .Errors [0 ])
254291 }
255292 }
256293 return pkgs , nil
257294}
258295
296+ // isPassThroughMethodName checks if a method name is known to be a pass-through
297+ func isPassThroughMethodName (methodName string ) bool {
298+ switch methodName {
299+ case rebindMethodName , rebindxMethodName :
300+ return true
301+ }
302+ return false
303+ }
304+
305+ // isPassThroughFunc checks if a function is known to be a pass-through
306+ // that transforms query syntax without changing semantic meaning
307+ func isPassThroughFunc (fn * ssa.Function ) bool {
308+ if fn == nil {
309+ return false
310+ }
311+
312+ // Get the package path and function name
313+ if fn .Pkg != nil && fn .Pkg .Pkg != nil {
314+ pkgPath := fn .Pkg .Pkg .Path ()
315+ funcName := fn .Name ()
316+
317+ // sqlx package pass-through functions
318+ if pkgPath == sqlxLib && isPassThroughMethodName (funcName ) {
319+ return true
320+ }
321+ }
322+
323+ // Check by receiver type for methods
324+ if fn .Signature .Recv () != nil {
325+ recv := fn .Signature .Recv ()
326+ recvType := recv .Type ().String ()
327+ funcName := fn .Name ()
328+
329+ // sqlx methods that are pass-through
330+ if strings .HasPrefix (recvType , sqlxLib + "." ) && isPassThroughMethodName (funcName ) {
331+ return true
332+ }
333+ }
334+
335+ return false
336+ }
337+
259338func extractQueryStrFromSsaValue (argVal ssa.Value ) (string , error ) {
260339 queryStr := ""
261340
@@ -292,11 +371,95 @@ func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) {
292371 return "" , ErrQueryArgTODO
293372 case * ssa.Extract :
294373 // query string is from one of the multi return values
295- // need to figure out how to trace string from function returns
374+ // Try to trace the source of the multi-value return
375+ if queryArg .Tuple == nil {
376+ return "" , ErrQueryArgTODO
377+ }
378+
379+ // Check if the tuple comes from a function call
380+ if call , ok := queryArg .Tuple .(* ssa.Call ); ok {
381+ callee := call .Call .StaticCallee ()
382+ if callee == nil {
383+ return "" , ErrQueryArgTODO
384+ }
385+
386+ // Check if the function has a body
387+ if len (callee .Blocks ) == 0 {
388+ // External function, can't trace further
389+ return "" , ErrQueryArgTODO
390+ }
391+
392+ // Look for return instructions and extract the specific index
393+ for _ , block := range callee .Blocks {
394+ for _ , instr := range block .Instrs {
395+ if ret , ok := instr .(* ssa.Return ); ok {
396+ if queryArg .Index >= len (ret .Results ) {
397+ continue
398+ }
399+ // Extract the query string from the specific return value at this index
400+ return extractQueryStrFromSsaValue (ret .Results [queryArg .Index ])
401+ }
402+ }
403+ }
404+ }
405+
296406 return "" , ErrQueryArgTODO
297407 case * ssa.Call :
298408 // return value from a function call
299- // TODO: trace caller function
409+ // Try to trace the function to extract the query string
410+ callee := queryArg .Call .StaticCallee ()
411+
412+ // Check if this is a known pass-through function call
413+ // For interface calls, callee will be nil, so we check by method name
414+ if callee == nil {
415+ // Dynamic call (interface method, function value, etc.)
416+ // Check if it's a known pass-through method by name
417+ if queryArg .Call .IsInvoke () {
418+ method := queryArg .Call .Method
419+ if method != nil && isPassThroughMethodName (method .Name ()) {
420+ // Extract the query from the first argument
421+ callArgs := queryArg .Call .Args
422+ if len (callArgs ) > 0 {
423+ return extractQueryStrFromSsaValue (callArgs [0 ])
424+ }
425+ }
426+ }
427+ return "" , ErrQueryArgUnsafe
428+ }
429+
430+ // Handle known pass-through functions that just transform the query
431+ // without changing its semantic meaning (e.g., sqlx.Rebind)
432+ if isPassThroughFunc (callee ) {
433+ // Extract the query from the first argument
434+ callArgs := queryArg .Call .Args
435+ if len (callArgs ) > 0 {
436+ // For method calls, the receiver is not in Args, so Args[0] is the first parameter
437+ return extractQueryStrFromSsaValue (callArgs [0 ])
438+ }
439+ return "" , ErrQueryArgUnsafe
440+ }
441+
442+ // Check if the function has a body (not external or builtin)
443+ if len (callee .Blocks ) == 0 {
444+ return "" , ErrQueryArgUnsafe
445+ }
446+
447+ // Look for return instructions in the function
448+ // This handles simple cases where the function returns a constant or computed value
449+ for _ , block := range callee .Blocks {
450+ for _ , instr := range block .Instrs {
451+ if ret , ok := instr .(* ssa.Return ); ok {
452+ if len (ret .Results ) == 0 {
453+ continue
454+ }
455+ // Recursively extract the query string from the first return value
456+ // This handles cases like:
457+ // func getQuery() string { return "SELECT * FROM users" }
458+ return extractQueryStrFromSsaValue (ret .Results [0 ])
459+ }
460+ }
461+ }
462+
300463 return "" , ErrQueryArgUnsafe
301464 case * ssa.MakeInterface :
302465 // query function takes interface as input
@@ -346,7 +509,7 @@ func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool {
346509}
347510
348511func iterCallGraphNodeCallees (ctx VetContext , cgNode * callgraph.Node , prog * ssa.Program , sqlfunc MatchedSqlFunc , ignoreNodes []ast.Node ) []* QuerySite {
349- queries := []* QuerySite {}
512+ var queries []* QuerySite
350513
351514 for _ , inEdge := range cgNode .In {
352515 callerFunc := inEdge .Caller .Func
@@ -490,7 +653,7 @@ func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node {
490653func CheckDir (ctx VetContext , dir , buildFlags string , extraMatchers []SqlFuncMatcher ) ([]* QuerySite , error ) {
491654 _ , err := os .Stat (filepath .Join (dir , "go.mod" ))
492655 if os .IsNotExist (err ) {
493- return nil , errors .New ("sqlvet only supports projects using go modules for now. " )
656+ return nil , errors .New ("sqlvet only supports projects using go modules for now" )
494657 }
495658
496659 pkgs , err := loadGoPackages (dir , buildFlags )
@@ -520,11 +683,11 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
520683
521684 mode := ssa .InstantiateGenerics
522685 prog , ssaPkgs := ssautil .Packages (pkgs , mode )
523- log .Debug ("Performaing whole-program analysis..." )
686+ log .Debug ("Performing whole-program analysis..." )
524687 prog .Build ()
525688
526689 // find ssa.Function for matched sqlfuncs from program
527- sqlfuncs := []MatchedSqlFunc {}
690+ var sqlfuncs []MatchedSqlFunc
528691 for _ , matcher := range matchers {
529692 if ! matcher .PackageImported () {
530693 // if package is not imported, then no sqlfunc should be matched
@@ -538,7 +701,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
538701 mains := ssautil .MainPackages (ssaPkgs )
539702
540703 log .Debug ("Building call graph..." )
541- funcs := []* ssa.Function {}
704+ var funcs []* ssa.Function
542705 for _ , fn := range mains {
543706 if main := fn .Func ("main" ); main != nil {
544707 funcs = append (funcs , main )
@@ -553,7 +716,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
553716 return nil , nil
554717 }
555718
556- queries := []* QuerySite {}
719+ var queries []* QuerySite
557720 cg := rtaRes .CallGraph
558721 for _ , sqlfunc := range sqlfuncs {
559722 cgNode := cg .CreateNode (sqlfunc .SSA )
0 commit comments