@@ -19,7 +19,8 @@ import (
19
19
)
20
20
21
21
const (
22
- FALLBACK_SQL_QUERY = "SELECT 1"
22
+ FALLBACK_SQL_QUERY = "SELECT 1"
23
+ INSPECT_SQL_COMMENT = " --INSPECT"
23
24
)
24
25
25
26
type QueryHandler struct {
@@ -199,64 +200,78 @@ func NewQueryHandler(config *Config, duckdb *Duckdb, icebergReader *IcebergReade
199
200
}
200
201
201
202
func (queryHandler * QueryHandler ) HandleQuery (originalQuery string ) ([]pgproto3.Message , error ) {
202
- query , err := queryHandler .remapQuery (originalQuery )
203
+ queryStatements , originalQueryStatements , err := queryHandler .parseAndRemapQuery (originalQuery )
203
204
if err != nil {
204
205
LogError (queryHandler .config , "Couldn't map query:" , originalQuery + "\n " + err .Error ())
205
206
return nil , err
206
207
}
207
-
208
- if query == "" {
208
+ if len (queryStatements ) == 0 {
209
209
return []pgproto3.Message {& pgproto3.EmptyQueryResponse {}}, nil
210
210
}
211
211
212
- rows , err := queryHandler .duckdb .QueryContext (context .Background (), query )
213
- if err != nil {
214
- errorMessage := err .Error ()
215
-
216
- if errorMessage == "Binder Error: UNNEST requires a single list as input" {
217
- // https://github.com/duckdb/duckdb/issues/11693
218
- LogWarn (queryHandler .config , "Couldn't handle query via DuckDB:" , query + "\n " + err .Error ())
219
- return queryHandler .HandleQuery (FALLBACK_SQL_QUERY )
220
- } else {
221
- LogError (queryHandler .config , "Couldn't handle query via DuckDB:" , query + "\n " + err .Error ())
212
+ var queriesMessages []pgproto3.Message
213
+
214
+ for i , queryStatement := range queryStatements {
215
+ rows , err := queryHandler .duckdb .QueryContext (context .Background (), queryStatement )
216
+ if err != nil {
217
+ errorMessage := err .Error ()
218
+ if errorMessage == "Binder Error: UNNEST requires a single list as input" {
219
+ // https://github.com/duckdb/duckdb/issues/11693
220
+ LogWarn (queryHandler .config , "Couldn't handle query via DuckDB:" , queryStatement + "\n " + err .Error ())
221
+ queriesMsgs , err := queryHandler .HandleQuery (FALLBACK_SQL_QUERY ) // self-recursion
222
+ if err != nil {
223
+ return nil , err
224
+ }
225
+ queriesMessages = append (queriesMessages , queriesMsgs ... )
226
+ continue
227
+ } else {
228
+ LogError (queryHandler .config , "Couldn't handle query via DuckDB:" , queryStatement + "\n " + err .Error ())
229
+ return nil , err
230
+ }
231
+ }
232
+ defer rows .Close ()
233
+
234
+ var queryMessages []pgproto3.Message
235
+ descriptionMessages , err := queryHandler .rowsToDescriptionMessages (rows , queryStatement )
236
+ if err != nil {
222
237
return nil , err
223
238
}
224
- }
225
- defer rows .Close ()
239
+ queryMessages = append (queryMessages , descriptionMessages ... )
240
+ dataMessages , err := queryHandler .rowsToDataMessages (rows , originalQueryStatements [i ])
241
+ if err != nil {
242
+ return nil , err
243
+ }
244
+ queryMessages = append (queryMessages , dataMessages ... )
226
245
227
- var messages []pgproto3.Message
228
- descriptionMessages , err := queryHandler .rowsToDescriptionMessages (rows , query )
229
- if err != nil {
230
- return nil , err
246
+ queriesMessages = append (queriesMessages , queryMessages ... )
231
247
}
232
- messages = append (messages , descriptionMessages ... )
233
- dataMessages , err := queryHandler .rowsToDataMessages (rows , originalQuery )
234
- if err != nil {
235
- return nil , err
236
- }
237
- messages = append (messages , dataMessages ... )
238
- return messages , nil
248
+
249
+ return queriesMessages , nil
239
250
}
240
251
241
252
func (queryHandler * QueryHandler ) HandleParseQuery (message * pgproto3.Parse ) ([]pgproto3.Message , * PreparedStatement , error ) {
242
253
ctx := context .Background ()
243
254
originalQuery := string (message .Query )
244
- query , err := queryHandler .remapQuery (originalQuery )
255
+ queryStatements , _ , err := queryHandler .parseAndRemapQuery (originalQuery )
245
256
if err != nil {
246
257
LogError (queryHandler .config , "Couldn't map query:" , originalQuery + "\n " + err .Error ())
247
258
return nil , nil , err
248
259
}
260
+ if len (queryStatements ) > 1 {
261
+ return nil , nil , errors .New ("multiple queries in a single parse message are not supported" )
262
+ }
249
263
250
264
preparedStatement := & PreparedStatement {
251
265
Name : message .Name ,
252
266
OriginalQuery : originalQuery ,
253
- Query : query ,
254
267
ParameterOIDs : message .ParameterOIDs ,
255
268
}
256
- if query == "" {
269
+ if len ( queryStatements ) == 0 {
257
270
return []pgproto3.Message {& pgproto3.EmptyQueryResponse {}}, preparedStatement , nil
258
271
}
259
272
273
+ query := queryStatements [0 ]
274
+ preparedStatement .Query = query
260
275
statement , err := queryHandler .duckdb .PrepareContext (ctx , query )
261
276
preparedStatement .Statement = statement
262
277
if err != nil {
@@ -398,54 +413,72 @@ func (queryHandler *QueryHandler) rowsToDescriptionMessages(rows *sql.Rows, quer
398
413
return messages , nil
399
414
}
400
415
401
- func (queryHandler * QueryHandler ) rowsToDataMessages (rows * sql.Rows , originalQuery string ) ([]pgproto3.Message , error ) {
416
+ func (queryHandler * QueryHandler ) rowsToDataMessages (rows * sql.Rows , originalQueryStatement string ) ([]pgproto3.Message , error ) {
402
417
cols , err := rows .ColumnTypes ()
403
418
if err != nil {
404
- LogError (queryHandler .config , "Couldn't get column types" , originalQuery + "\n " + err .Error ())
419
+ LogError (queryHandler .config , "Couldn't get column types" , originalQueryStatement + "\n " + err .Error ())
405
420
return nil , err
406
421
}
407
422
408
423
var messages []pgproto3.Message
409
424
for rows .Next () {
410
425
dataRow , err := queryHandler .generateDataRow (rows , cols )
411
426
if err != nil {
412
- LogError (queryHandler .config , "Couldn't get data row" , originalQuery + "\n " + err .Error ())
427
+ LogError (queryHandler .config , "Couldn't get data row" , originalQueryStatement + "\n " + err .Error ())
413
428
return nil , err
414
429
}
415
430
messages = append (messages , dataRow )
416
431
}
417
432
418
433
commandTag := FALLBACK_SQL_QUERY
419
434
switch {
420
- case strings .HasPrefix (originalQuery , "SET " ):
435
+ case strings .HasPrefix (originalQueryStatement , "SET " ):
421
436
commandTag = "SET"
422
- case strings .HasPrefix (originalQuery , "SHOW " ):
437
+ case strings .HasPrefix (originalQueryStatement , "SHOW " ):
423
438
commandTag = "SHOW"
424
- case strings .HasPrefix (originalQuery , "DISCARD ALL" ):
439
+ case strings .HasPrefix (originalQueryStatement , "DISCARD ALL" ):
425
440
commandTag = "DISCARD ALL"
426
441
}
427
442
428
443
messages = append (messages , & pgproto3.CommandComplete {CommandTag : []byte (commandTag )})
429
444
return messages , nil
430
445
}
431
446
432
- func (queryHandler * QueryHandler ) remapQuery (query string ) (string , error ) {
447
+ func (queryHandler * QueryHandler ) parseAndRemapQuery (query string ) ([] string , [] string , error ) {
433
448
queryTree , err := pgQuery .Parse (query )
434
449
if err != nil {
435
450
LogError (queryHandler .config , "Error parsing query:" , query + "\n " + err .Error ())
436
- return "" , err
451
+ return nil , nil , err
437
452
}
438
453
439
- if strings .HasSuffix (query , " --INSPECT" ) {
454
+ if strings .HasSuffix (query , INSPECT_SQL_COMMENT ) {
440
455
LogDebug (queryHandler .config , queryTree .Stmts )
441
456
}
442
457
443
- queryTree .Stmts , err = queryHandler .queryRemapper .RemapStatements (queryTree .Stmts )
458
+ var originalQueryStatements []string
459
+ for _ , stmt := range queryTree .Stmts {
460
+ originalQueryStatement , err := pgQuery .Deparse (& pgQuery.ParseResult {Stmts : []* pgQuery.RawStmt {stmt }})
461
+ if err != nil {
462
+ return nil , nil , err
463
+ }
464
+ originalQueryStatements = append (originalQueryStatements , originalQueryStatement )
465
+ }
466
+
467
+ remappedStatements , err := queryHandler .queryRemapper .RemapStatements (queryTree .Stmts )
444
468
if err != nil {
445
- return "" , err
469
+ return nil , nil , err
470
+ }
471
+
472
+ var queryStatements []string
473
+ for _ , remappedStatement := range remappedStatements {
474
+ queryStatement , err := pgQuery .Deparse (& pgQuery.ParseResult {Stmts : []* pgQuery.RawStmt {remappedStatement }})
475
+ if err != nil {
476
+ return nil , nil , err
477
+ }
478
+ queryStatements = append (queryStatements , queryStatement )
446
479
}
447
480
448
- return pgQuery . Deparse ( queryTree )
481
+ return queryStatements , originalQueryStatements , nil
449
482
}
450
483
451
484
func (queryHandler * QueryHandler ) generateRowDescription (cols []* sql.ColumnType ) * pgproto3.RowDescription {
0 commit comments