Skip to content

Commit b92fc52

Browse files
1 parent ad704b9 commit b92fc52

File tree

5 files changed

+109
-1
lines changed

5 files changed

+109
-1
lines changed

filter/converter.go

+42-1
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,20 @@ var basicOperatorMap = map[string]string{
1919
"$regex": "~*",
2020
}
2121

22+
// DefaultPlaceholderName is the default placeholder name used in the generated SQL query.
23+
// This name should not be used in the database or any JSONB column. It can be changed using
24+
// the WithPlaceholderName option.
25+
const DefaultPlaceholderName = "__filter_placeholder"
26+
2227
type Converter struct {
2328
nestedColumn string
2429
nestedExemptions []string
2530
arrayDriver func(a any) interface {
2631
driver.Valuer
2732
sql.Scanner
2833
}
29-
emptyCondition string
34+
emptyCondition string
35+
placeholderName string
3036
}
3137

3238
// NewConverter creates a new Converter with optional nested JSONB field mapping.
@@ -41,6 +47,9 @@ func NewConverter(options ...Option) *Converter {
4147
option(converter)
4248
}
4349
}
50+
if converter.placeholderName == "" {
51+
converter.placeholderName = DefaultPlaceholderName
52+
}
4453
return converter
4554
}
4655

@@ -197,6 +206,35 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
197206
neg = "NOT "
198207
}
199208
inner = append(inner, fmt.Sprintf("(%sjsonb_path_match(%s, 'exists($.%s)'))", neg, c.nestedColumn, key))
209+
case "$elemMatch":
210+
// $elemMatch needs a different implementation depending on if the column is in JSONB or not.
211+
isNestedColumn := c.nestedColumn != ""
212+
for _, exemption := range c.nestedExemptions {
213+
if exemption == key {
214+
isNestedColumn = false
215+
break
216+
}
217+
}
218+
innerConditions, innerValues, err := c.convertFilter(map[string]any{c.placeholderName: v[operator]}, paramIndex)
219+
if err != nil {
220+
return "", nil, err
221+
}
222+
paramIndex += len(innerValues)
223+
if isNestedColumn {
224+
// This will for example become:
225+
//
226+
// EXISTS (SELECT 1 FROM jsonb_array_elements("meta"->'foo') AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1))
227+
//
228+
// We can't use c.columnName here because we need `->` to get the jsonb value instead of `->>` which gets the text value.
229+
inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM jsonb_array_elements(%q->'%s') AS %s WHERE %s)", c.nestedColumn, key, c.placeholderName, innerConditions))
230+
} else {
231+
// This will for example become:
232+
//
233+
// EXISTS (SELECT 1 FROM unnest("foo") AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1))
234+
//
235+
inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM unnest(%s) AS %s WHERE %s)", c.columnName(key), c.placeholderName, innerConditions))
236+
}
237+
values = append(values, innerValues...)
200238
default:
201239
value := v[operator]
202240
op, ok := basicOperatorMap[operator]
@@ -247,6 +285,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
247285
}
248286

249287
func (c *Converter) columnName(column string) string {
288+
if column == c.placeholderName {
289+
return fmt.Sprintf(`%q::text`, column)
290+
}
250291
if c.nestedColumn == "" {
251292
return fmt.Sprintf("%q", column)
252293
}

filter/converter_test.go

+32
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,38 @@ func TestConverter_Convert(t *testing.T) {
304304
nil,
305305
nil,
306306
},
307+
{
308+
"sql injection",
309+
nil,
310+
`{"\"bla = 1 --": 1}`,
311+
``,
312+
nil,
313+
fmt.Errorf("invalid column name: \"bla = 1 --"),
314+
},
315+
{
316+
"$elemMatch on normal column",
317+
nil,
318+
`{"name": {"$elemMatch": {"$eq": "John"}}}`,
319+
`EXISTS (SELECT 1 FROM unnest("name") AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1))`,
320+
[]any{"John"},
321+
nil,
322+
},
323+
{
324+
"$elemMatch on jsonb column",
325+
filter.WithNestedJSONB("meta"),
326+
`{"name": {"$elemMatch": {"$eq": "John"}}}`,
327+
`EXISTS (SELECT 1 FROM jsonb_array_elements("meta"->'name') AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1))`,
328+
[]any{"John"},
329+
nil,
330+
},
331+
{
332+
"$elemMatch with $gt",
333+
filter.WithPlaceholderName("__placeholder"),
334+
`{"age": {"$elemMatch": {"$gt": 18}}}`,
335+
`EXISTS (SELECT 1 FROM unnest("age") AS __placeholder WHERE ("__placeholder"::text > $1))`,
336+
[]any{float64(18)},
337+
nil,
338+
},
307339
}
308340

309341
for _, tt := range tests {

filter/options.go

+9
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,12 @@ func WithEmptyCondition(condition string) Option {
4949
c.emptyCondition = condition
5050
}
5151
}
52+
53+
// WithPlaceholderName is an option to specify the placeholder name that will be
54+
// used in the generated SQL query. This name should not be used in the database
55+
// or any JSONB column.
56+
func WithPlaceholderName(name string) Option {
57+
return func(c *Converter) {
58+
c.placeholderName = name
59+
}
60+
}

fuzz/fuzz_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ func FuzzConverter(f *testing.F) {
4242
`{"name": {"$not": {"$eq": "John"}}}`,
4343
`{"name": null}`,
4444
`{"name": {"$exists": false}}`,
45+
`{"name": {"$elemMatch": {"$eq": "John"}}}`,
46+
`{"age": {"$elemMatch": {"$gt": 18}}}`,
4547
}
4648
for _, tc := range tcs {
4749
f.Add(tc, true)

integration/postgres_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,30 @@ func TestIntegration_BasicOperators(t *testing.T) {
357357
[]int{1, 5, 6, 7, 8, 9, 10},
358358
nil,
359359
},
360+
{
361+
"$elemMatch on normal column",
362+
`{"items": {"$elemMatch": {"$regex": "a"}}}`,
363+
[]int{5, 6},
364+
nil,
365+
},
366+
{
367+
"$elemMatch on jsonb column",
368+
`{"hats": {"$elemMatch": {"$regex": "a"}}}`,
369+
[]int{6},
370+
nil,
371+
},
372+
{
373+
"$elemMatch with a numeric column",
374+
`{"parents": {"$elemMatch": {"$gt": 40, "$lt": 60}}}`,
375+
[]int{3},
376+
nil,
377+
},
378+
{
379+
"$elemMatch with numeric jsonb column",
380+
`{"keys": {"$elemMatch": {"$gt": 5}}}`,
381+
[]int{3},
382+
nil,
383+
},
360384
}
361385

362386
for _, tt := range tests {

0 commit comments

Comments
 (0)