1414package io .trino .gateway .ha .router ;
1515
1616import com .fasterxml .jackson .annotation .JsonCreator ;
17+ import com .fasterxml .jackson .annotation .JsonIgnore ;
1718import com .fasterxml .jackson .annotation .JsonProperty ;
1819import com .fasterxml .jackson .core .JsonGenerator ;
1920import com .fasterxml .jackson .databind .SerializerProvider ;
2021import com .fasterxml .jackson .databind .annotation .JsonSerialize ;
2122import com .fasterxml .jackson .databind .ser .std .StdSerializer ;
23+ import com .google .common .collect .ImmutableList ;
2224import com .google .common .collect .ImmutableMap ;
2325import com .google .common .collect .ImmutableSet ;
2426import io .airlift .compress .zstd .ZstdDecompressor ;
2527import io .airlift .json .JsonCodec ;
2628import io .airlift .log .Logger ;
27- import io .trino .gateway .ha .config .RequestAnalyzerConfig ;
2829import io .trino .sql .parser .ParsingException ;
2930import io .trino .sql .parser .SqlParser ;
3031import io .trino .sql .tree .AddColumn ;
3132import io .trino .sql .tree .Analyze ;
33+ import io .trino .sql .tree .Call ;
3234import io .trino .sql .tree .CreateCatalog ;
3335import io .trino .sql .tree .CreateMaterializedView ;
3436import io .trino .sql .tree .CreateSchema ;
4042import io .trino .sql .tree .DropTable ;
4143import io .trino .sql .tree .Execute ;
4244import io .trino .sql .tree .ExecuteImmediate ;
45+ import io .trino .sql .tree .Expression ;
4346import io .trino .sql .tree .Identifier ;
4447import io .trino .sql .tree .Node ;
4548import io .trino .sql .tree .NodeLocation ;
5760import io .trino .sql .tree .ShowSchemas ;
5861import io .trino .sql .tree .ShowTables ;
5962import io .trino .sql .tree .Statement ;
63+ import io .trino .sql .tree .StringLiteral ;
6064import io .trino .sql .tree .Table ;
6165import io .trino .sql .tree .TableFunctionInvocation ;
6266import jakarta .servlet .http .HttpServletRequest ;
7377import java .util .Set ;
7478import java .util .stream .Collectors ;
7579
80+ import static com .google .common .base .Preconditions .checkArgument ;
7681import static com .google .common .io .BaseEncoding .base64Url ;
7782import static io .airlift .json .JsonCodec .jsonCodec ;
7883import static java .lang .Math .toIntExact ;
@@ -85,6 +90,7 @@ public class TrinoQueryProperties
8590{
8691 private final Logger log = Logger .get (TrinoQueryProperties .class );
8792 private final boolean isClientsUseV2Format ;
93+ private final int maxBodySize ;
8894 private String body = "" ;
8995 private String queryType = "" ;
9096 private String resourceGroupQueryType = "" ;
@@ -96,6 +102,7 @@ public class TrinoQueryProperties
96102 private Set <String > catalogSchemas = ImmutableSet .of ();
97103 private boolean isNewQuerySubmission ;
98104 private Optional <String > errorMessage = Optional .empty ();
105+ private Optional <String > queryId = Optional .empty ();
99106
100107 public static final String TRINO_CATALOG_HEADER_NAME = "X-Trino-Catalog" ;
101108 public static final String TRINO_SCHEMA_HEADER_NAME = "X-Trino-Schema" ;
@@ -128,21 +135,24 @@ public TrinoQueryProperties(
128135 this .isNewQuerySubmission = isNewQuerySubmission ;
129136 this .errorMessage = requireNonNullElse (errorMessage , Optional .empty ());
130137 isClientsUseV2Format = false ;
138+ maxBodySize = -1 ;
131139 }
132140
133- public TrinoQueryProperties (HttpServletRequest request , RequestAnalyzerConfig config )
141+ public TrinoQueryProperties (HttpServletRequest request , boolean isClientsUseV2Format , int maxBodySize )
134142 {
135- isClientsUseV2Format = config .isClientsUseV2Format ();
143+ requireNonNull (request , "request is null" );
144+ this .isClientsUseV2Format = isClientsUseV2Format ;
145+ this .maxBodySize = maxBodySize ;
136146
137147 defaultCatalog = Optional .ofNullable (request .getHeader (TRINO_CATALOG_HEADER_NAME ));
138148 defaultSchema = Optional .ofNullable (request .getHeader (TRINO_SCHEMA_HEADER_NAME ));
139149 if (request .getMethod ().equals (HttpMethod .POST )) {
140150 isNewQuerySubmission = true ;
141- processRequestBody (request , config );
151+ processRequestBody (request );
142152 }
143153 }
144154
145- private void processRequestBody (HttpServletRequest request , RequestAnalyzerConfig config )
155+ private void processRequestBody (HttpServletRequest request )
146156 {
147157 try (BufferedReader reader = request .getReader ()) {
148158 if (reader == null ) {
@@ -153,11 +163,11 @@ private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfi
153163
154164 Map <String , String > preparedStatements = getPreparedStatements (request );
155165 SqlParser parser = new SqlParser ();
156- reader .mark (config . getMaxBodySize () );
157- char [] buffer = new char [config . getMaxBodySize () ];
158- int nChars = reader .read (buffer , 0 , config . getMaxBodySize () );
166+ reader .mark (maxBodySize );
167+ char [] buffer = new char [maxBodySize ];
168+ int nChars = reader .read (buffer , 0 , maxBodySize );
159169 reader .reset ();
160- if (nChars == config . getMaxBodySize () ) {
170+ if (nChars == maxBodySize ) {
161171 log .warn ("Query length greater or equal to requestAnalyzerConfig.maxBodySize detected" );
162172 return ;
163173 //The body is truncated - there is a chance that it could still be syntactically valid SQL, for example if truncated on
@@ -199,7 +209,7 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
199209 ImmutableSet .Builder <String > schemaBuilder = ImmutableSet .builder ();
200210 ImmutableSet .Builder <String > catalogSchemaBuilder = ImmutableSet .builder ();
201211
202- getNames (statement , tableBuilder , catalogBuilder , schemaBuilder , catalogSchemaBuilder );
212+ visitNode (statement , tableBuilder , catalogBuilder , schemaBuilder , catalogSchemaBuilder );
203213 tables = tableBuilder .build ();
204214 catalogBuilder .addAll (tables .stream ().map (q -> q .getParts ().getFirst ()).iterator ());
205215 catalogs = catalogBuilder .build ();
@@ -260,7 +270,7 @@ private String decodePreparedStatementFromHeader(String headerValue)
260270 return new String (preparedStatement , UTF_8 );
261271 }
262272
263- private void getNames (Node node , ImmutableSet .Builder <QualifiedName > tableBuilder ,
273+ private void visitNode (Node node , ImmutableSet .Builder <QualifiedName > tableBuilder ,
264274 ImmutableSet .Builder <String > catalogBuilder ,
265275 ImmutableSet .Builder <String > schemaBuilder ,
266276 ImmutableSet .Builder <String > catalogSchemaBuilder )
@@ -269,6 +279,7 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
269279 switch (node ) {
270280 case AddColumn s -> tableBuilder .add (qualifyName (s .getName ()));
271281 case Analyze s -> tableBuilder .add (qualifyName (s .getTableName ()));
282+ case Call call -> queryId = extractQueryIdFromCall (call );
272283 case CreateCatalog s -> catalogBuilder .add (s .getCatalogName ().getValue ());
273284 case CreateMaterializedView s -> tableBuilder .add (qualifyName (s .getName ()));
274285 case CreateSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName (Optional .of (s .getSchemaName ()), catalogBuilder , schemaBuilder , catalogSchemaBuilder );
@@ -342,10 +353,22 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
342353 }
343354
344355 for (Node child : node .getChildren ()) {
345- getNames (child , tableBuilder , catalogBuilder , schemaBuilder , catalogSchemaBuilder );
356+ visitNode (child , tableBuilder , catalogBuilder , schemaBuilder , catalogSchemaBuilder );
346357 }
347358 }
348359
360+ private Optional <String > extractQueryIdFromCall (Call call )
361+ throws RequestParsingException
362+ {
363+ QualifiedName callName = qualifyName (call .getName ());
364+ if (callName .equals (QualifiedName .of ("system" , "runtime" , "kill_query" ))) {
365+ Expression argument = call .getArguments ().getFirst ().getValue ();
366+ checkArgument (argument instanceof StringLiteral , "Unable to route kill_query procedures where the first argument is not a String Literal" );
367+ return Optional .of (((StringLiteral ) argument ).getValue ());
368+ }
369+ return Optional .empty ();
370+ }
371+
349372 private void setCatalogAndSchemaNameFromSchemaQualifiedName (
350373 Optional <QualifiedName > schemaOptional ,
351374 ImmutableSet .Builder <String > catalogBuilder ,
@@ -381,15 +404,16 @@ private RequestParsingException unsetDefaultExceptionSupplier()
381404 return new RequestParsingException ("Name not fully qualified" );
382405 }
383406
384- private QualifiedName qualifyName (QualifiedName table )
407+ private QualifiedName qualifyName (QualifiedName name )
385408 throws RequestParsingException
386409 {
387- List <String > tableParts = table .getParts ();
388- return switch (tableParts .size ()) {
389- case 1 -> QualifiedName .of (defaultCatalog .orElseThrow (this ::unsetDefaultExceptionSupplier ), defaultSchema .orElseThrow (this ::unsetDefaultExceptionSupplier ), tableParts .getFirst ());
390- case 2 -> QualifiedName .of (defaultCatalog .orElseThrow (this ::unsetDefaultExceptionSupplier ), tableParts .getFirst (), tableParts .get (1 ));
391- case 3 -> QualifiedName .of (tableParts .getFirst (), tableParts .get (1 ), tableParts .get (2 ));
392- default -> throw new RequestParsingException ("Unexpected table name: " + table .getParts ());
410+ List <String > nameParts = name .getParts ();
411+ return switch (nameParts .size ()) {
412+ case 1 ->
413+ QualifiedName .of (defaultCatalog .orElseThrow (this ::unsetDefaultExceptionSupplier ), defaultSchema .orElseThrow (this ::unsetDefaultExceptionSupplier ), nameParts .getFirst ());
414+ case 2 -> QualifiedName .of (defaultCatalog .orElseThrow (this ::unsetDefaultExceptionSupplier ), nameParts .getFirst (), nameParts .get (1 ));
415+ case 3 -> QualifiedName .of (nameParts .getFirst (), nameParts .get (1 ), nameParts .get (2 ));
416+ default -> throw new RequestParsingException ("Unexpected qualified name: " + name .getParts ());
393417 };
394418 }
395419
@@ -520,6 +544,12 @@ public Optional<String> getErrorMessage()
520544 return errorMessage ;
521545 }
522546
547+ @ JsonIgnore
548+ public Optional <String > getQueryId ()
549+ {
550+ return queryId ;
551+ }
552+
523553 public static class AlternateStatementRequestBodyFormat
524554 {
525555 // Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session
0 commit comments