Skip to content

Commit e805660

Browse files
committed
Extract query ID from all kill_query procedure variations
1 parent 23e8320 commit e805660

File tree

5 files changed

+227
-18
lines changed

5 files changed

+227
-18
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import com.google.common.base.Splitter;
1717
import com.google.common.io.CharStreams;
1818
import io.airlift.log.Logger;
19+
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
20+
import io.trino.gateway.ha.router.TrinoQueryProperties;
1921
import jakarta.servlet.http.HttpServletRequest;
2022

2123
import java.io.InputStreamReader;
@@ -50,7 +52,6 @@ public final class ProxyUtils
5052
* capitalization.
5153
*/
5254
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
53-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
5455

5556
private ProxyUtils() {}
5657

@@ -89,26 +90,19 @@ public static String getQueryUser(String userHeader, String authorization)
8990
return parts.get(0);
9091
}
9192

92-
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
93+
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths, RequestAnalyzerConfig requestAnalyzerConfig)
9394
{
9495
String path = request.getRequestURI();
9596
String queryParams = request.getQueryString();
9697
try {
9798
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
9899
if (!isNullOrEmpty(queryText)
99-
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
100-
// extract and return the queryId
101-
String[] parts = queryText.split(",");
102-
for (String part : parts) {
103-
if (part.contains("query_id")) {
104-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
105-
if (matcher.find()) {
106-
String queryQuoted = matcher.group();
107-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
108-
return queryQuoted.substring(1, queryQuoted.length() - 1);
109-
}
110-
}
111-
}
100+
&& queryText.toLowerCase().contains("kill_query")) {
101+
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig);
102+
if (trinoQueryProperties.getProcedure().isPresent()
103+
&& trinoQueryProperties.getProcedure().orElseThrow().getName().getParts().getLast().equalsIgnoreCase("kill_query")) {
104+
return trinoQueryProperties.getProcedureArgs().getFirst().getValue().toString()
105+
.replaceFirst("^'", "").replaceFirst("'$", "");
112106
}
113107
}
114108
}

gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import io.airlift.log.Logger;
1717
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
18+
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
1819
import io.trino.gateway.ha.router.GatewayCookie;
1920
import io.trino.gateway.ha.router.RoutingGroupSelector;
2021
import io.trino.gateway.ha.router.RoutingManager;
@@ -45,18 +46,22 @@ public class RoutingTargetHandler
4546
private final RoutingGroupSelector routingGroupSelector;
4647
private final List<String> statementPaths;
4748
private final List<Pattern> extraWhitelistPaths;
49+
private final RequestAnalyzerConfig requestAnalyzerConfig;
4850
private final boolean cookiesEnabled;
4951

5052
public RoutingTargetHandler(
5153
RoutingManager routingManager,
5254
RoutingGroupSelector routingGroupSelector,
5355
List<String> statementPaths,
54-
List<String> extraWhitelistPaths)
56+
List<String> extraWhitelistPaths,
57+
RequestAnalyzerConfig requestAnalyzerConfig)
5558
{
5659
this.routingManager = requireNonNull(routingManager);
60+
5761
this.routingGroupSelector = requireNonNull(routingGroupSelector);
5862
this.statementPaths = requireNonNull(statementPaths);
5963
this.extraWhitelistPaths = extraWhitelistPaths.stream().map(Pattern::compile).collect(toImmutableList());
64+
this.requestAnalyzerConfig = requestAnalyzerConfig;
6065
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
6166
}
6267

@@ -94,7 +99,7 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)
9499

95100
private Optional<String> getPreviousBackend(HttpServletRequest request)
96101
{
97-
String queryId = extractQueryIdIfPresent(request, statementPaths);
102+
String queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyzerConfig);
98103
if (!isNullOrEmpty(queryId)) {
99104
return Optional.of(routingManager.findBackendForQueryId(queryId));
100105
}

gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ public RoutingTargetHandler getRoutingTargetHandler(
211211
routingManager,
212212
routingGroupSelector,
213213
configuration.getStatementPaths(),
214-
configuration.getExtraWhitelistPaths());
214+
configuration.getExtraWhitelistPaths(),
215+
configuration.getRequestAnalyzerConfig());
215216
}
216217
}

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.fasterxml.jackson.databind.SerializerProvider;
2020
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
2121
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
22+
import com.google.common.collect.ImmutableList;
2223
import com.google.common.collect.ImmutableMap;
2324
import com.google.common.collect.ImmutableSet;
2425
import io.airlift.compress.zstd.ZstdDecompressor;
@@ -29,6 +30,8 @@
2930
import io.trino.sql.parser.SqlParser;
3031
import io.trino.sql.tree.AddColumn;
3132
import io.trino.sql.tree.Analyze;
33+
import io.trino.sql.tree.Call;
34+
import io.trino.sql.tree.CallArgument;
3235
import io.trino.sql.tree.CreateCatalog;
3336
import io.trino.sql.tree.CreateMaterializedView;
3437
import io.trino.sql.tree.CreateSchema;
@@ -94,6 +97,8 @@ public class TrinoQueryProperties
9497
private Set<String> catalogs = ImmutableSet.of();
9598
private Set<String> schemas = ImmutableSet.of();
9699
private Set<String> catalogSchemas = ImmutableSet.of();
100+
private Optional<Call> procedure = Optional.empty();
101+
private List<CallArgument> procedureArgs = ImmutableList.of();
97102
private boolean isNewQuerySubmission;
98103
private boolean isQueryParsingSuccessful;
99104

@@ -262,6 +267,11 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
262267
ImmutableSet.Builder<String> catalogSchemaBuilder)
263268
throws RequestParsingException
264269
{
270+
if (node instanceof Call) {
271+
procedure = Optional.of((Call) node);
272+
procedureArgs = ((Call) node).getArguments();
273+
return;
274+
}
265275
switch (node) {
266276
case AddColumn s -> tableBuilder.add(qualifyName(s.getName()));
267277
case Analyze s -> tableBuilder.add(qualifyName(s.getTableName()));
@@ -513,6 +523,16 @@ public boolean isQueryParsingSuccessful()
513523
return isQueryParsingSuccessful;
514524
}
515525

526+
public Optional<Call> getProcedure()
527+
{
528+
return procedure;
529+
}
530+
531+
public List<CallArgument> getProcedureArgs()
532+
{
533+
return procedureArgs;
534+
}
535+
516536
public static class AlternateStatementRequestBodyFormat
517537
{
518538
// Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session

gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,29 @@
1414
package io.trino.gateway.ha.handler;
1515

1616
import com.google.common.collect.ImmutableList;
17+
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
18+
import jakarta.servlet.ReadListener;
19+
import jakarta.servlet.ServletInputStream;
1720
import jakarta.servlet.http.HttpServletRequest;
21+
import jakarta.ws.rs.HttpMethod;
1822
import org.junit.jupiter.api.Test;
1923
import org.junit.jupiter.api.TestInstance;
2024
import org.junit.jupiter.api.TestInstance.Lifecycle;
2125
import org.mockito.Mockito;
2226

27+
import java.io.BufferedReader;
28+
import java.io.ByteArrayInputStream;
2329
import java.io.IOException;
30+
import java.io.StringReader;
2431
import java.util.List;
2532

2633
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
2734
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
2835
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.AUTHORIZATION;
2936
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
37+
import static java.nio.charset.StandardCharsets.UTF_8;
3038
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.mockito.Mockito.when;
3140

3241
@TestInstance(Lifecycle.PER_CLASS)
3342
public class TestQueryIdCachingProxyHandler
@@ -66,6 +75,186 @@ public void testExtractQueryIdFromUrl()
6675
.isNull();
6776
}
6877

78+
@Test
79+
void testQueryIdFromKill()
80+
throws IOException
81+
{
82+
RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
83+
requestAnalyzerConfig.setAnalyzeRequest(true);
84+
assertThat(
85+
extractQueryIdIfPresent(
86+
prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')"),
87+
ImmutableList.of(), requestAnalyzerConfig))
88+
.isEqualTo("20200416_160256_03078_6b4yt");
89+
90+
assertThat(
91+
extractQueryIdIfPresent(
92+
prepareMockRequestWithBody("CALL system.runtime.kill_query(Query_id => '20200416_160256_03078_6b4yt', Message => 'If he dies, he dies')"),
93+
ImmutableList.of(),
94+
requestAnalyzerConfig))
95+
.isEqualTo("20200416_160256_03078_6b4yt");
96+
97+
assertThat(
98+
extractQueryIdIfPresent(
99+
prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')"),
100+
ImmutableList.of(),
101+
requestAnalyzerConfig))
102+
.isEqualTo("20200416_160256_03078_6b4yt");
103+
104+
assertThat(
105+
extractQueryIdIfPresent(
106+
prepareMockRequestWithBody("CALL runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')"),
107+
ImmutableList.of(),
108+
requestAnalyzerConfig))
109+
.isEqualTo("20200416_160256_03078_6b4yt");
110+
111+
assertThat(
112+
extractQueryIdIfPresent(
113+
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'kill_query(''20200416_160256_03078_7n5uy'')')"),
114+
ImmutableList.of(),
115+
requestAnalyzerConfig))
116+
.isEqualTo("20200416_160256_03078_6b4yt");
117+
118+
assertThat(
119+
extractQueryIdIfPresent(
120+
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')"),
121+
ImmutableList.of(),
122+
requestAnalyzerConfig))
123+
.isEqualTo("20200416_160256_03078_6b4yt");
124+
125+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id=>'20200416_160256_03078_6b4yt')"),
126+
ImmutableList.of(),
127+
requestAnalyzerConfig))
128+
.isEqualTo("20200416_160256_03078_6b4yt");
129+
130+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt')"),
131+
ImmutableList.of(),
132+
requestAnalyzerConfig))
133+
.isEqualTo("20200416_160256_03078_6b4yt");
134+
135+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt')"), ImmutableList.of(), requestAnalyzerConfig))
136+
.isEqualTo("20200416_160256_03078_6b4yt");
137+
138+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("call Kill_Query('20200416_160256_03078_6b4yt')"), ImmutableList.of(), requestAnalyzerConfig))
139+
.isEqualTo("20200416_160256_03078_6b4yt");
140+
141+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("SELECT * FROM postgres.query_logs.queries WHERE sql LIKE '%kill_query(''20200416_160256%' "),
142+
ImmutableList.of(),
143+
requestAnalyzerConfig))
144+
.isNull();
145+
146+
assertThat(extractQueryIdIfPresent(
147+
prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql like '%kill_query(''20200416_160256_03078_6b4yt' "),
148+
ImmutableList.of(),
149+
requestAnalyzerConfig))
150+
.isNull();
151+
152+
assertThat(extractQueryIdIfPresent(
153+
prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE 'CALL kill_query(_20200416_160256_03078_6b4yt_)' "),
154+
ImmutableList.of(),
155+
requestAnalyzerConfig))
156+
.isNull();
157+
158+
assertThat(
159+
extractQueryIdIfPresent(
160+
prepareMockRequestWithBody("""
161+
--CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')
162+
SELECT 1
163+
"""),
164+
ImmutableList.of(),
165+
requestAnalyzerConfig))
166+
.isNull();
167+
168+
assertThat(
169+
extractQueryIdIfPresent(
170+
prepareMockRequestWithBody("""
171+
/*
172+
CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')
173+
*/
174+
SELECT 1
175+
"""),
176+
ImmutableList.of(),
177+
requestAnalyzerConfig))
178+
.isNull();
179+
180+
assertThat(
181+
extractQueryIdIfPresent(
182+
prepareMockRequestWithBody("""
183+
CALL KILL_QUERY('20200416_160256_03078_6b4yt', 'If he dies, he dies')
184+
"""),
185+
ImmutableList.of(),
186+
requestAnalyzerConfig))
187+
.isEqualTo("20200416_160256_03078_6b4yt");
188+
189+
assertThat(
190+
extractQueryIdIfPresent(
191+
prepareMockRequestWithBody("""
192+
CALL KILL_QUERY ('20200416_160256_03078_6b4yt', 'If he dies, he dies')
193+
"""),
194+
ImmutableList.of(),
195+
requestAnalyzerConfig))
196+
.isEqualTo("20200416_160256_03078_6b4yt");
197+
198+
assertThat(
199+
extractQueryIdIfPresent(
200+
prepareMockRequestWithBody("""
201+
CALL
202+
KILL_QUERY
203+
(
204+
-- this is a comment
205+
'20200416_160256_03078_6b4yt' --this is a trailing comment
206+
,
207+
/*
208+
this is
209+
a multiline comment
210+
*/
211+
'If he dies, he dies
212+
')
213+
"""),
214+
ImmutableList.of(),
215+
requestAnalyzerConfig))
216+
.isEqualTo("20200416_160256_03078_6b4yt");
217+
}
218+
219+
private static HttpServletRequest prepareMockRequestWithBody(String query)
220+
throws IOException
221+
{
222+
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
223+
224+
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8));
225+
when(request.getMethod()).thenReturn(HttpMethod.POST);
226+
when(request.getInputStream()).thenReturn(new ServletInputStream()
227+
{
228+
@Override
229+
public boolean isFinished()
230+
{
231+
return byteArrayInputStream.available() > 0;
232+
}
233+
234+
@Override
235+
public boolean isReady()
236+
{
237+
return true;
238+
}
239+
240+
@Override
241+
public void setReadListener(ReadListener readListener)
242+
{}
243+
244+
public int read()
245+
throws IOException
246+
{
247+
return byteArrayInputStream.read();
248+
}
249+
});
250+
251+
when(request.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
252+
253+
when(request.getQueryString()).thenReturn("");
254+
255+
return request;
256+
}
257+
69258
@Test
70259
public void testUserFromRequest()
71260
throws IOException

0 commit comments

Comments
 (0)