1616import com .google .common .base .Splitter ;
1717import com .google .common .io .CharStreams ;
1818import io .airlift .log .Logger ;
19+ import io .trino .gateway .ha .router .SerializableExpression ;
20+ import io .trino .gateway .ha .router .TrinoQueryProperties ;
21+ import io .trino .sql .tree .Expression ;
22+ import io .trino .sql .tree .StringLiteral ;
1923import jakarta .servlet .http .HttpServletRequest ;
24+ import jakarta .ws .rs .HttpMethod ;
2025
26+ import java .io .IOException ;
2127import java .io .InputStreamReader ;
2228import java .util .Base64 ;
2329import java .util .List ;
2430import java .util .Optional ;
2531import java .util .regex .Matcher ;
2632import java .util .regex .Pattern ;
2733
34+ import static com .google .common .base .Preconditions .checkArgument ;
2835import static com .google .common .base .Strings .isNullOrEmpty ;
2936import static io .trino .gateway .ha .handler .QueryIdCachingProxyHandler .TRINO_UI_PATH ;
3037import static io .trino .gateway .ha .handler .QueryIdCachingProxyHandler .USER_HEADER ;
@@ -52,7 +59,6 @@ public final class ProxyUtils
5259 * capitalization.
5360 */
5461 private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern .compile (".*(?:%2F|(?i)query_?id(?-i)=|^)(\\ d+_\\ d+_\\ d+_\\ w+).*" );
55- private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern .compile ("'([^\\ s']+)'" );
5662
5763 private ProxyUtils () {}
5864
@@ -91,47 +97,51 @@ public static String getQueryUser(String userHeader, String authorization)
9197 return parts .get (0 );
9298 }
9399
94- public static String extractQueryIdIfPresent (HttpServletRequest request , List <String > statementPaths )
100+ public static Optional <String > extractQueryIdIfPresent (
101+ HttpServletRequest request ,
102+ List <String > statementPaths ,
103+ boolean requestAnalyserClientsUseV2Format ,
104+ int requestAnalyserMaxBodySize )
95105 {
96106 String path = request .getRequestURI ();
97107 String queryParams = request .getQueryString ();
108+ if (!request .getMethod ().equals (HttpMethod .POST )) {
109+ return extractQueryIdIfPresent (path , queryParams , statementPaths );
110+ }
111+ String queryText ;
98112 try {
99- String queryText = CharStreams .toString (new InputStreamReader (request .getInputStream (), UTF_8 ));
100- if (!isNullOrEmpty (queryText )
101- && queryText .toLowerCase (ENGLISH ).contains ("system.runtime.kill_query" )) {
102- // extract and return the queryId
103- String [] parts = queryText .split ("," );
104- for (String part : parts ) {
105- if (part .contains ("query_id" )) {
106- Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES .matcher (part );
107- if (matcher .find ()) {
108- String queryQuoted = matcher .group ();
109- if (!isNullOrEmpty (queryQuoted ) && queryQuoted .length () > 0 ) {
110- return queryQuoted .substring (1 , queryQuoted .length () - 1 );
111- }
112- }
113- }
114- }
115- }
113+ queryText = CharStreams .toString (new InputStreamReader (request .getInputStream (), UTF_8 ));
116114 }
117- catch (Exception e ) {
118- log . error ( e , "Error extracting query payload from request" );
115+ catch (IOException e ) {
116+ throw new RuntimeException ( "Error reading request body" , e );
119117 }
120-
121- return extractQueryIdIfPresent (path , queryParams , statementPaths );
118+ if (!isNullOrEmpty (queryText ) && queryText .toLowerCase (ENGLISH ).contains ("kill_query" )) {
119+ TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties (request , requestAnalyserClientsUseV2Format , requestAnalyserMaxBodySize );
120+ if (trinoQueryProperties .procedureNameEquals ("system.runtime.kill_query" )) {
121+ SerializableExpression argument = trinoQueryProperties .getProcedureArguments ().getFirst ().getValue ();
122+ checkArgument (argument .getOriginalClass ().equals (StringLiteral .class ), "Unable to route kill_query procedures where the first argument is not a String Literal" );
123+ return Optional .of (argument .getValue ());
124+ }
125+ }
126+ return Optional .empty ();
122127 }
123128
124- public static String extractQueryIdIfPresent (String path , String queryParams , List <String > statementPaths )
129+ public static Optional < String > extractQueryIdIfPresent (String path , String queryParams , List <String > statementPaths )
125130 {
126131 if (path == null ) {
127- return null ;
132+ return Optional . empty () ;
128133 }
129- String queryId = null ;
130134 log .debug ("Trying to extract query id from path [%s] or queryString [%s]" , path , queryParams );
131135 // matchingStatementPath should match paths such as /v1/statement/executing/query_id/nonce/sequence_number,
132136 // and if custom paths are supplied using the statementPaths configuration, paths such as
133137 // /custom/statement/path/executing/query_id/nonce/sequence_number
134138 Optional <String > matchingStatementPath = statementPaths .stream ().filter (path ::startsWith ).findAny ();
139+ if (!isNullOrEmpty (queryParams )) {
140+ Matcher matcher = QUERY_ID_PARAM_PATTERN .matcher (queryParams );
141+ if (matcher .matches ()) {
142+ return Optional .of (matcher .group (1 ));
143+ }
144+ }
135145 if (matchingStatementPath .isPresent () || path .startsWith (V1_QUERY_PATH )) {
136146 path = path .replace (matchingStatementPath .orElse (V1_QUERY_PATH ), "" );
137147 String [] tokens = path .split ("/" );
@@ -140,27 +150,20 @@ public static String extractQueryIdIfPresent(String path, String queryParams, Li
140150 || tokens [1 ].equals ("scheduled" )
141151 || tokens [1 ].equals ("executing" )
142152 || tokens [1 ].equals ("partialCancel" )) {
143- queryId = tokens [2 ];
153+ return Optional . of ( tokens [2 ]) ;
144154 }
145155 else {
146- queryId = tokens [1 ];
156+ return Optional . of ( tokens [1 ]) ;
147157 }
148158 }
149159 }
150160 else if (path .startsWith (TRINO_UI_PATH )) {
151161 Matcher matcher = QUERY_ID_PATH_PATTERN .matcher (path );
152162 if (matcher .matches ()) {
153- queryId = matcher .group (1 );
154- }
155- }
156- if (!isNullOrEmpty (queryParams )) {
157- Matcher matcher = QUERY_ID_PARAM_PATTERN .matcher (queryParams );
158- if (matcher .matches ()) {
159- queryId = matcher .group (1 );
163+ return Optional .of (matcher .group (1 ));
160164 }
161165 }
162- log .debug ("Query id in URL [%s]" , queryId );
163- return queryId ;
166+ return Optional .empty ();
164167 }
165168
166169 public static String buildUriWithNewBackend (String backendHost , HttpServletRequest request )
0 commit comments