Skip to content

Commit 9a306a7

Browse files
committedFeb 7, 2025··
Fix handling multiple statements in a single SQL query
1 parent fe7e7b1 commit 9a306a7

File tree

5 files changed

+95
-46
lines changed

5 files changed

+95
-46
lines changed
 

‎img/architecture.png

-498 Bytes
Loading

‎src/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"time"
77
)
88

9-
const VERSION = "0.30.3"
9+
const VERSION = "0.31.4"
1010

1111
func main() {
1212
config := LoadConfig()

‎src/query_handler.go

+75-42
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ import (
1919
)
2020

2121
const (
22-
FALLBACK_SQL_QUERY = "SELECT 1"
22+
FALLBACK_SQL_QUERY = "SELECT 1"
23+
INSPECT_SQL_COMMENT = " --INSPECT"
2324
)
2425

2526
type QueryHandler struct {
@@ -199,64 +200,78 @@ func NewQueryHandler(config *Config, duckdb *Duckdb, icebergReader *IcebergReade
199200
}
200201

201202
func (queryHandler *QueryHandler) HandleQuery(originalQuery string) ([]pgproto3.Message, error) {
202-
query, err := queryHandler.remapQuery(originalQuery)
203+
queryStatements, originalQueryStatements, err := queryHandler.parseAndRemapQuery(originalQuery)
203204
if err != nil {
204205
LogError(queryHandler.config, "Couldn't map query:", originalQuery+"\n"+err.Error())
205206
return nil, err
206207
}
207-
208-
if query == "" {
208+
if len(queryStatements) == 0 {
209209
return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, nil
210210
}
211211

212-
rows, err := queryHandler.duckdb.QueryContext(context.Background(), query)
213-
if err != nil {
214-
errorMessage := err.Error()
215-
216-
if errorMessage == "Binder Error: UNNEST requires a single list as input" {
217-
// https://github.com/duckdb/duckdb/issues/11693
218-
LogWarn(queryHandler.config, "Couldn't handle query via DuckDB:", query+"\n"+err.Error())
219-
return queryHandler.HandleQuery(FALLBACK_SQL_QUERY)
220-
} else {
221-
LogError(queryHandler.config, "Couldn't handle query via DuckDB:", query+"\n"+err.Error())
212+
var queriesMessages []pgproto3.Message
213+
214+
for i, queryStatement := range queryStatements {
215+
rows, err := queryHandler.duckdb.QueryContext(context.Background(), queryStatement)
216+
if err != nil {
217+
errorMessage := err.Error()
218+
if errorMessage == "Binder Error: UNNEST requires a single list as input" {
219+
// https://github.com/duckdb/duckdb/issues/11693
220+
LogWarn(queryHandler.config, "Couldn't handle query via DuckDB:", queryStatement+"\n"+err.Error())
221+
queriesMsgs, err := queryHandler.HandleQuery(FALLBACK_SQL_QUERY) // self-recursion
222+
if err != nil {
223+
return nil, err
224+
}
225+
queriesMessages = append(queriesMessages, queriesMsgs...)
226+
continue
227+
} else {
228+
LogError(queryHandler.config, "Couldn't handle query via DuckDB:", queryStatement+"\n"+err.Error())
229+
return nil, err
230+
}
231+
}
232+
defer rows.Close()
233+
234+
var queryMessages []pgproto3.Message
235+
descriptionMessages, err := queryHandler.rowsToDescriptionMessages(rows, queryStatement)
236+
if err != nil {
222237
return nil, err
223238
}
224-
}
225-
defer rows.Close()
239+
queryMessages = append(queryMessages, descriptionMessages...)
240+
dataMessages, err := queryHandler.rowsToDataMessages(rows, originalQueryStatements[i])
241+
if err != nil {
242+
return nil, err
243+
}
244+
queryMessages = append(queryMessages, dataMessages...)
226245

227-
var messages []pgproto3.Message
228-
descriptionMessages, err := queryHandler.rowsToDescriptionMessages(rows, query)
229-
if err != nil {
230-
return nil, err
246+
queriesMessages = append(queriesMessages, queryMessages...)
231247
}
232-
messages = append(messages, descriptionMessages...)
233-
dataMessages, err := queryHandler.rowsToDataMessages(rows, originalQuery)
234-
if err != nil {
235-
return nil, err
236-
}
237-
messages = append(messages, dataMessages...)
238-
return messages, nil
248+
249+
return queriesMessages, nil
239250
}
240251

241252
func (queryHandler *QueryHandler) HandleParseQuery(message *pgproto3.Parse) ([]pgproto3.Message, *PreparedStatement, error) {
242253
ctx := context.Background()
243254
originalQuery := string(message.Query)
244-
query, err := queryHandler.remapQuery(originalQuery)
255+
queryStatements, _, err := queryHandler.parseAndRemapQuery(originalQuery)
245256
if err != nil {
246257
LogError(queryHandler.config, "Couldn't map query:", originalQuery+"\n"+err.Error())
247258
return nil, nil, err
248259
}
260+
if len(queryStatements) > 1 {
261+
return nil, nil, errors.New("multiple queries in a single parse message are not supported")
262+
}
249263

250264
preparedStatement := &PreparedStatement{
251265
Name: message.Name,
252266
OriginalQuery: originalQuery,
253-
Query: query,
254267
ParameterOIDs: message.ParameterOIDs,
255268
}
256-
if query == "" {
269+
if len(queryStatements) == 0 {
257270
return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, preparedStatement, nil
258271
}
259272

273+
query := queryStatements[0]
274+
preparedStatement.Query = query
260275
statement, err := queryHandler.duckdb.PrepareContext(ctx, query)
261276
preparedStatement.Statement = statement
262277
if err != nil {
@@ -398,54 +413,72 @@ func (queryHandler *QueryHandler) rowsToDescriptionMessages(rows *sql.Rows, quer
398413
return messages, nil
399414
}
400415

401-
func (queryHandler *QueryHandler) rowsToDataMessages(rows *sql.Rows, originalQuery string) ([]pgproto3.Message, error) {
416+
func (queryHandler *QueryHandler) rowsToDataMessages(rows *sql.Rows, originalQueryStatement string) ([]pgproto3.Message, error) {
402417
cols, err := rows.ColumnTypes()
403418
if err != nil {
404-
LogError(queryHandler.config, "Couldn't get column types", originalQuery+"\n"+err.Error())
419+
LogError(queryHandler.config, "Couldn't get column types", originalQueryStatement+"\n"+err.Error())
405420
return nil, err
406421
}
407422

408423
var messages []pgproto3.Message
409424
for rows.Next() {
410425
dataRow, err := queryHandler.generateDataRow(rows, cols)
411426
if err != nil {
412-
LogError(queryHandler.config, "Couldn't get data row", originalQuery+"\n"+err.Error())
427+
LogError(queryHandler.config, "Couldn't get data row", originalQueryStatement+"\n"+err.Error())
413428
return nil, err
414429
}
415430
messages = append(messages, dataRow)
416431
}
417432

418433
commandTag := FALLBACK_SQL_QUERY
419434
switch {
420-
case strings.HasPrefix(originalQuery, "SET "):
435+
case strings.HasPrefix(originalQueryStatement, "SET "):
421436
commandTag = "SET"
422-
case strings.HasPrefix(originalQuery, "SHOW "):
437+
case strings.HasPrefix(originalQueryStatement, "SHOW "):
423438
commandTag = "SHOW"
424-
case strings.HasPrefix(originalQuery, "DISCARD ALL"):
439+
case strings.HasPrefix(originalQueryStatement, "DISCARD ALL"):
425440
commandTag = "DISCARD ALL"
426441
}
427442

428443
messages = append(messages, &pgproto3.CommandComplete{CommandTag: []byte(commandTag)})
429444
return messages, nil
430445
}
431446

432-
func (queryHandler *QueryHandler) remapQuery(query string) (string, error) {
447+
func (queryHandler *QueryHandler) parseAndRemapQuery(query string) ([]string, []string, error) {
433448
queryTree, err := pgQuery.Parse(query)
434449
if err != nil {
435450
LogError(queryHandler.config, "Error parsing query:", query+"\n"+err.Error())
436-
return "", err
451+
return nil, nil, err
437452
}
438453

439-
if strings.HasSuffix(query, " --INSPECT") {
454+
if strings.HasSuffix(query, INSPECT_SQL_COMMENT) {
440455
LogDebug(queryHandler.config, queryTree.Stmts)
441456
}
442457

443-
queryTree.Stmts, err = queryHandler.queryRemapper.RemapStatements(queryTree.Stmts)
458+
var originalQueryStatements []string
459+
for _, stmt := range queryTree.Stmts {
460+
originalQueryStatement, err := pgQuery.Deparse(&pgQuery.ParseResult{Stmts: []*pgQuery.RawStmt{stmt}})
461+
if err != nil {
462+
return nil, nil, err
463+
}
464+
originalQueryStatements = append(originalQueryStatements, originalQueryStatement)
465+
}
466+
467+
remappedStatements, err := queryHandler.queryRemapper.RemapStatements(queryTree.Stmts)
444468
if err != nil {
445-
return "", err
469+
return nil, nil, err
470+
}
471+
472+
var queryStatements []string
473+
for _, remappedStatement := range remappedStatements {
474+
queryStatement, err := pgQuery.Deparse(&pgQuery.ParseResult{Stmts: []*pgQuery.RawStmt{remappedStatement}})
475+
if err != nil {
476+
return nil, nil, err
477+
}
478+
queryStatements = append(queryStatements, queryStatement)
446479
}
447480

448-
return pgQuery.Deparse(queryTree)
481+
return queryStatements, originalQueryStatements, nil
449482
}
450483

451484
func (queryHandler *QueryHandler) generateRowDescription(cols []*sql.ColumnType) *pgproto3.RowDescription {

‎src/query_handler_test.go

+17-3
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,12 @@ SET standard_conforming_strings = on;`
11411141
testNoError(t, err)
11421142
testMessageTypes(t, messages, []pgproto3.Message{
11431143
&pgproto3.CommandComplete{},
1144+
&pgproto3.CommandComplete{},
1145+
&pgproto3.CommandComplete{},
11441146
})
1147+
testCommandCompleteTag(t, messages[0], "SET")
1148+
testCommandCompleteTag(t, messages[1], "SET")
1149+
testCommandCompleteTag(t, messages[2], "SET")
11451150
})
11461151

11471152
t.Run("Handles mixed SET and SELECT statements", func(t *testing.T) {
@@ -1153,15 +1158,18 @@ SELECT passwd FROM pg_shadow WHERE usename='bemidb';`
11531158

11541159
testNoError(t, err)
11551160
testMessageTypes(t, messages, []pgproto3.Message{
1161+
&pgproto3.CommandComplete{},
11561162
&pgproto3.RowDescription{},
11571163
&pgproto3.DataRow{},
11581164
&pgproto3.CommandComplete{},
11591165
})
1160-
testDataRowValues(t, messages[1], []string{"bemidb-encrypted"})
1166+
testCommandCompleteTag(t, messages[0], "SET")
1167+
testDataRowValues(t, messages[2], []string{"bemidb-encrypted"})
1168+
testCommandCompleteTag(t, messages[3], "SELECT 1")
11611169
})
11621170

11631171
t.Run("Handles multiple SELECT statements", func(t *testing.T) {
1164-
query := `SELECT passwd FROM pg_shadow WHERE usename='bemidb';
1172+
query := `SELECT 1;
11651173
SELECT passwd FROM pg_shadow WHERE usename='bemidb';`
11661174
queryHandler := initQueryHandler()
11671175

@@ -1172,8 +1180,14 @@ SELECT passwd FROM pg_shadow WHERE usename='bemidb';`
11721180
&pgproto3.RowDescription{},
11731181
&pgproto3.DataRow{},
11741182
&pgproto3.CommandComplete{},
1183+
&pgproto3.RowDescription{},
1184+
&pgproto3.DataRow{},
1185+
&pgproto3.CommandComplete{},
11751186
})
1176-
testDataRowValues(t, messages[1], []string{"bemidb-encrypted"})
1187+
testDataRowValues(t, messages[1], []string{"1"})
1188+
testCommandCompleteTag(t, messages[2], "SELECT 1")
1189+
testDataRowValues(t, messages[4], []string{"bemidb-encrypted"})
1190+
testCommandCompleteTag(t, messages[5], "SELECT 1")
11771191
})
11781192

11791193
t.Run("Handles error in any of multiple statements", func(t *testing.T) {

‎src/query_remapper.go

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ func (remapper *QueryRemapper) RemapStatements(statements []*pgQuery.RawStmt) ([
6060
}
6161

6262
for i, stmt := range statements {
63+
LogTrace(remapper.config, "Remapping statement", i+1)
64+
6365
node := stmt.Stmt
6466

6567
switch {

0 commit comments

Comments
 (0)
Please sign in to comment.