Skip to content

Commit c708d01

Browse files
daniil-pankratovDaniil Pankratov
andauthored
Add Postgres case, cast operator and rebind support (#47)
Co-authored-by: Daniil Pankratov <daniil.pankratov@finteqhub.com>
1 parent 0969590 commit c708d01

File tree

5 files changed

+401
-45
lines changed

5 files changed

+401
-45
lines changed

.github/workflows/build.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ on:
88

99
jobs:
1010
test:
11-
runs-on: ubuntu-latest
11+
strategy:
12+
matrix:
13+
os: [ubuntu-latest, macos-latest]
14+
runs-on: ${{ matrix.os }}
1215
steps:
1316
- uses: actions/checkout@v2
1417
- uses: actions/setup-go@v2

main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"github.com/houqp/sqlvet/pkg/vet"
1616
)
1717

18-
const version = "1.1.10"
18+
const version = "1.1.11"
1919

2020
var (
2121
gitCommit = "?"

pkg/vet/gosource.go

Lines changed: 199 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ var (
3030
ErrQueryArgTODO = errors.New("TODO: support this type")
3131
)
3232

33+
const (
34+
sqlxLib = "github.com/jmoiron/sqlx"
35+
dbSqlLib = "database/sql"
36+
gormLib = "github.com/jinzhu/gorm"
37+
goGorpLib = "go-gorp/gorp"
38+
gorpV1Lib = "gopkg.in/gorp.v1"
39+
40+
queryArgName = "query"
41+
sqlArgName = "sql"
42+
43+
rebindMethodName = "Rebind"
44+
rebindxMethodName = "Rebindx"
45+
)
46+
3347
type QuerySite struct {
3448
Called string
3549
Position token.Position
@@ -100,10 +114,17 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
100114
sqlfuncs := []MatchedSqlFunc{}
101115

102116
s.IterPackageExportedFuncs(func(fobj *types.Func) {
117+
ssaFunc := prog.FuncValue(fobj)
118+
119+
// Skip pass-through functions that shouldn't be validated as SQL functions
120+
if isPassThroughFunc(ssaFunc) {
121+
return
122+
}
123+
103124
for _, rule := range s.Rules {
104125
if rule.FuncName != "" && fobj.Name() == rule.FuncName {
105126
sqlfuncs = append(sqlfuncs, MatchedSqlFunc{
106-
SSA: prog.FuncValue(fobj),
127+
SSA: ssaFunc,
107128
QueryArgPos: rule.QueryArgPos,
108129
})
109130
// callable matched one rule, no need to go through the rest
@@ -120,7 +141,7 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
120141
continue
121142
}
122143
sqlfuncs = append(sqlfuncs, MatchedSqlFunc{
123-
SSA: prog.FuncValue(fobj),
144+
SSA: ssaFunc,
124145
QueryArgPos: rule.QueryArgPos,
125146
})
126147
// callable matched one rule, no need to go through the rest
@@ -132,13 +153,29 @@ func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
132153
return sqlfuncs
133154
}
134155

156+
// isNamedQueryFunc checks if a function name is a "named query" function
157+
// that expects named parameters (like :param) instead of positional ($1, $2)
158+
func isNamedQueryFunc(funcName string) bool {
159+
// Check for sqlx named query functions
160+
switch funcName {
161+
case "NamedExec", "NamedQuery", "NamedExecContext", "NamedQueryContext",
162+
"NamedQueryRow", "NamedQueryRowContext":
163+
return true
164+
}
165+
// Also check if the function name contains "Named" (catches custom wrappers)
166+
return strings.Contains(funcName, "Named")
167+
}
168+
135169
func handleQuery(ctx VetContext, qs *QuerySite) {
136-
// TODO: apply named query resolution based on v.X type and v.Sel.Name
137-
// e.g. for sqlx, only apply to NamedExec and NamedQuery
138-
qs.Query, _, qs.Err = parseutil.CompileNamedQuery(
139-
[]byte(qs.Query), parseutil.BindType("postgres"))
140-
if qs.Err != nil {
141-
return
170+
// Only apply named query resolution for named query functions
171+
// (e.g., NamedExec, NamedQuery, NamedExecContext, NamedQueryContext)
172+
// to avoid breaking PostgreSQL type casts (::) in regular queries
173+
if isNamedQueryFunc(qs.Called) {
174+
qs.Query, _, qs.Err = parseutil.CompileNamedQuery(
175+
[]byte(qs.Query), parseutil.BindType("postgres"))
176+
if qs.Err != nil {
177+
return
178+
}
142179
}
143180

144181
var queryParams []QueryParam
@@ -160,31 +197,31 @@ func handleQuery(ctx VetContext, qs *QuerySite) {
160197
func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher {
161198
matchers := []*SqlFuncMatcher{
162199
{
163-
PkgPath: "github.com/jmoiron/sqlx",
200+
PkgPath: sqlxLib,
164201
Rules: []SqlFuncMatchRule{
165-
{QueryArgName: "query"},
166-
{QueryArgName: "sql"},
202+
{QueryArgName: queryArgName},
203+
{QueryArgName: sqlArgName},
167204
// for methods with Context suffix
168-
{QueryArgName: "query", QueryArgPos: 1},
169-
{QueryArgName: "sql", QueryArgPos: 1},
170-
{QueryArgName: "query", QueryArgPos: 2},
171-
{QueryArgName: "sql", QueryArgPos: 2},
205+
{QueryArgName: queryArgName, QueryArgPos: 1},
206+
{QueryArgName: sqlArgName, QueryArgPos: 1},
207+
{QueryArgName: queryArgName, QueryArgPos: 2},
208+
{QueryArgName: sqlArgName, QueryArgPos: 2},
172209
},
173210
},
174211
{
175-
PkgPath: "database/sql",
212+
PkgPath: dbSqlLib,
176213
Rules: []SqlFuncMatchRule{
177-
{QueryArgName: "query"},
178-
{QueryArgName: "sql"},
214+
{QueryArgName: queryArgName},
215+
{QueryArgName: sqlArgName},
179216
// for methods with Context suffix
180-
{QueryArgName: "query", QueryArgPos: 1},
181-
{QueryArgName: "sql", QueryArgPos: 1},
217+
{QueryArgName: queryArgName, QueryArgPos: 1},
218+
{QueryArgName: sqlArgName, QueryArgPos: 1},
182219
},
183220
},
184221
{
185-
PkgPath: "github.com/jinzhu/gorm",
222+
PkgPath: gormLib,
186223
Rules: []SqlFuncMatchRule{
187-
{QueryArgName: "sql"},
224+
{QueryArgName: sqlArgName},
188225
},
189226
},
190227
// TODO: xorm uses vararg, which is not supported yet
@@ -201,15 +238,15 @@ func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher {
201238
// },
202239
// },
203240
{
204-
PkgPath: "go-gorp/gorp",
241+
PkgPath: goGorpLib,
205242
Rules: []SqlFuncMatchRule{
206-
{QueryArgName: "query"},
243+
{QueryArgName: queryArgName},
207244
},
208245
},
209246
{
210-
PkgPath: "gopkg.in/gorp.v1",
247+
PkgPath: gorpV1Lib,
211248
Rules: []SqlFuncMatchRule{
212-
{QueryArgName: "query"},
249+
{QueryArgName: queryArgName},
213250
},
214251
},
215252
}
@@ -240,7 +277,7 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error)
240277
}
241278
dirAbs, err := filepath.Abs(dir)
242279
if err != nil {
243-
return nil, fmt.Errorf("Invalid path: %w", err)
280+
return nil, fmt.Errorf("invalid path: %w", err)
244281
}
245282
pkgPath := dirAbs + "/..."
246283
pkgs, err := packages.Load(cfg, pkgPath)
@@ -250,12 +287,54 @@ func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error)
250287
// return early if any syntax error
251288
for _, pkg := range pkgs {
252289
if len(pkg.Errors) > 0 {
253-
return nil, fmt.Errorf("Failed to load package, %w", pkg.Errors[0])
290+
return nil, fmt.Errorf("failed to load package, %w", pkg.Errors[0])
254291
}
255292
}
256293
return pkgs, nil
257294
}
258295

296+
// isPassThroughMethodName checks if a method name is known to be a pass-through
297+
func isPassThroughMethodName(methodName string) bool {
298+
switch methodName {
299+
case rebindMethodName, rebindxMethodName:
300+
return true
301+
}
302+
return false
303+
}
304+
305+
// isPassThroughFunc checks if a function is known to be a pass-through
306+
// that transforms query syntax without changing semantic meaning
307+
func isPassThroughFunc(fn *ssa.Function) bool {
308+
if fn == nil {
309+
return false
310+
}
311+
312+
// Get the package path and function name
313+
if fn.Pkg != nil && fn.Pkg.Pkg != nil {
314+
pkgPath := fn.Pkg.Pkg.Path()
315+
funcName := fn.Name()
316+
317+
// sqlx package pass-through functions
318+
if pkgPath == sqlxLib && isPassThroughMethodName(funcName) {
319+
return true
320+
}
321+
}
322+
323+
// Check by receiver type for methods
324+
if fn.Signature.Recv() != nil {
325+
recv := fn.Signature.Recv()
326+
recvType := recv.Type().String()
327+
funcName := fn.Name()
328+
329+
// sqlx methods that are pass-through
330+
if strings.HasPrefix(recvType, sqlxLib+".") && isPassThroughMethodName(funcName) {
331+
return true
332+
}
333+
}
334+
335+
return false
336+
}
337+
259338
func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) {
260339
queryStr := ""
261340

@@ -292,11 +371,95 @@ func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) {
292371
return "", ErrQueryArgTODO
293372
case *ssa.Extract:
294373
// query string is from one of the multi return values
295-
// need to figure out how to trace string from function returns
374+
// Try to trace the source of the multi-value return
375+
if queryArg.Tuple == nil {
376+
return "", ErrQueryArgTODO
377+
}
378+
379+
// Check if the tuple comes from a function call
380+
if call, ok := queryArg.Tuple.(*ssa.Call); ok {
381+
callee := call.Call.StaticCallee()
382+
if callee == nil {
383+
return "", ErrQueryArgTODO
384+
}
385+
386+
// Check if the function has a body
387+
if len(callee.Blocks) == 0 {
388+
// External function, can't trace further
389+
return "", ErrQueryArgTODO
390+
}
391+
392+
// Look for return instructions and extract the specific index
393+
for _, block := range callee.Blocks {
394+
for _, instr := range block.Instrs {
395+
if ret, ok := instr.(*ssa.Return); ok {
396+
if queryArg.Index >= len(ret.Results) {
397+
continue
398+
}
399+
// Extract the query string from the specific return value at this index
400+
return extractQueryStrFromSsaValue(ret.Results[queryArg.Index])
401+
}
402+
}
403+
}
404+
}
405+
296406
return "", ErrQueryArgTODO
297407
case *ssa.Call:
298408
// return value from a function call
299-
// TODO: trace caller function
409+
// Try to trace the function to extract the query string
410+
callee := queryArg.Call.StaticCallee()
411+
412+
// Check if this is a known pass-through function call
413+
// For interface calls, callee will be nil, so we check by method name
414+
if callee == nil {
415+
// Dynamic call (interface method, function value, etc.)
416+
// Check if it's a known pass-through method by name
417+
if queryArg.Call.IsInvoke() {
418+
method := queryArg.Call.Method
419+
if method != nil && isPassThroughMethodName(method.Name()) {
420+
// Extract the query from the first argument
421+
callArgs := queryArg.Call.Args
422+
if len(callArgs) > 0 {
423+
return extractQueryStrFromSsaValue(callArgs[0])
424+
}
425+
}
426+
}
427+
return "", ErrQueryArgUnsafe
428+
}
429+
430+
// Handle known pass-through functions that just transform the query
431+
// without changing its semantic meaning (e.g., sqlx.Rebind)
432+
if isPassThroughFunc(callee) {
433+
// Extract the query from the first argument
434+
callArgs := queryArg.Call.Args
435+
if len(callArgs) > 0 {
436+
// For method calls, the receiver is not in Args, so Args[0] is the first parameter
437+
return extractQueryStrFromSsaValue(callArgs[0])
438+
}
439+
return "", ErrQueryArgUnsafe
440+
}
441+
442+
// Check if the function has a body (not external or builtin)
443+
if len(callee.Blocks) == 0 {
444+
return "", ErrQueryArgUnsafe
445+
}
446+
447+
// Look for return instructions in the function
448+
// This handles simple cases where the function returns a constant or computed value
449+
for _, block := range callee.Blocks {
450+
for _, instr := range block.Instrs {
451+
if ret, ok := instr.(*ssa.Return); ok {
452+
if len(ret.Results) == 0 {
453+
continue
454+
}
455+
// Recursively extract the query string from the first return value
456+
// This handles cases like:
457+
// func getQuery() string { return "SELECT * FROM users" }
458+
return extractQueryStrFromSsaValue(ret.Results[0])
459+
}
460+
}
461+
}
462+
300463
return "", ErrQueryArgUnsafe
301464
case *ssa.MakeInterface:
302465
// query function takes interface as input
@@ -346,7 +509,7 @@ func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool {
346509
}
347510

348511
func iterCallGraphNodeCallees(ctx VetContext, cgNode *callgraph.Node, prog *ssa.Program, sqlfunc MatchedSqlFunc, ignoreNodes []ast.Node) []*QuerySite {
349-
queries := []*QuerySite{}
512+
var queries []*QuerySite
350513

351514
for _, inEdge := range cgNode.In {
352515
callerFunc := inEdge.Caller.Func
@@ -490,7 +653,7 @@ func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node {
490653
func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMatcher) ([]*QuerySite, error) {
491654
_, err := os.Stat(filepath.Join(dir, "go.mod"))
492655
if os.IsNotExist(err) {
493-
return nil, errors.New("sqlvet only supports projects using go modules for now.")
656+
return nil, errors.New("sqlvet only supports projects using go modules for now")
494657
}
495658

496659
pkgs, err := loadGoPackages(dir, buildFlags)
@@ -520,11 +683,11 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
520683

521684
mode := ssa.InstantiateGenerics
522685
prog, ssaPkgs := ssautil.Packages(pkgs, mode)
523-
log.Debug("Performaing whole-program analysis...")
686+
log.Debug("Performing whole-program analysis...")
524687
prog.Build()
525688

526689
// find ssa.Function for matched sqlfuncs from program
527-
sqlfuncs := []MatchedSqlFunc{}
690+
var sqlfuncs []MatchedSqlFunc
528691
for _, matcher := range matchers {
529692
if !matcher.PackageImported() {
530693
// if package is not imported, then no sqlfunc should be matched
@@ -538,7 +701,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
538701
mains := ssautil.MainPackages(ssaPkgs)
539702

540703
log.Debug("Building call graph...")
541-
funcs := []*ssa.Function{}
704+
var funcs []*ssa.Function
542705
for _, fn := range mains {
543706
if main := fn.Func("main"); main != nil {
544707
funcs = append(funcs, main)
@@ -553,7 +716,7 @@ func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMat
553716
return nil, nil
554717
}
555718

556-
queries := []*QuerySite{}
719+
var queries []*QuerySite
557720
cg := rtaRes.CallGraph
558721
for _, sqlfunc := range sqlfuncs {
559722
cgNode := cg.CreateNode(sqlfunc.SSA)

0 commit comments

Comments
 (0)