Skip to content

Commit c311199

Browse files
committed
Extract query ID from all kill_query procedure variations
1 parent 5afabe0 commit c311199

File tree

2 files changed

+110
-14
lines changed

2 files changed

+110
-14
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
3030
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.V1_QUERY_PATH;
3131
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.V1_STATEMENT_PATH;
32+
import static java.util.regex.Pattern.CASE_INSENSITIVE;
3233

3334
public final class ProxyUtils
3435
{
@@ -50,7 +51,14 @@ public final class ProxyUtils
5051
* capitalization.
5152
*/
5253
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
53-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
54+
55+
/**
56+
* This regular expression extracts the query id from a CALL system.runtime.kill_query procedure call. It extracts the first string between quotes
57+
* following the open parentheses after system.runtime.kill_query. The pattern handles the optional named arguments, and optional message argument
58+
* as well as arbitrary whitespace.
59+
*/
60+
private static final Pattern KILL_QUERY_PROCEDURE_PATTERN
61+
= Pattern.compile(".*?kill_query\\s*\\(\\s*(query_id\\s*=>)?\\s*'([^\\\\s]+?)'(,\\s*(message\\s*=>\\s*)?('.*'))?\\)", CASE_INSENSITIVE);
5462

5563
private ProxyUtils() {}
5664

@@ -96,19 +104,10 @@ public static String extractQueryIdIfPresent(HttpServletRequest request)
96104
try {
97105
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
98106
if (!isNullOrEmpty(queryText)
99-
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
100-
// extract and return the queryId
101-
String[] parts = queryText.split(",");
102-
for (String part : parts) {
103-
if (part.contains("query_id")) {
104-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
105-
if (matcher.find()) {
106-
String queryQuoted = matcher.group();
107-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
108-
return queryQuoted.substring(1, queryQuoted.length() - 1);
109-
}
110-
}
111-
}
107+
&& queryText.toLowerCase().contains("kill_query")) {
108+
Matcher matcher = KILL_QUERY_PROCEDURE_PATTERN.matcher(queryText.toLowerCase());
109+
if (matcher.find()) {
110+
return matcher.group(2);
112111
}
113112
}
114113
}

gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,24 @@
1313
*/
1414
package io.trino.gateway.ha.handler;
1515

16+
import jakarta.servlet.ReadListener;
17+
import jakarta.servlet.ServletInputStream;
1618
import jakarta.servlet.http.HttpServletRequest;
1719
import org.junit.jupiter.api.Test;
1820
import org.junit.jupiter.api.TestInstance;
1921
import org.junit.jupiter.api.TestInstance.Lifecycle;
2022
import org.mockito.Mockito;
2123

24+
import java.io.ByteArrayInputStream;
2225
import java.io.IOException;
2326

2427
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
2528
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
2629
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.AUTHORIZATION;
2730
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
31+
import static java.nio.charset.StandardCharsets.UTF_8;
2832
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.mockito.Mockito.when;
2934

3035
@TestInstance(Lifecycle.PER_CLASS)
3136
public class TestQueryIdCachingProxyHandler
@@ -61,6 +66,98 @@ public void testExtractQueryIdFromUrl()
6166
.isNull();
6267
}
6368

69+
@Test
70+
void testQueryIdFromKill()
71+
throws IOException
72+
{
73+
assertThat(
74+
extractQueryIdIfPresent(
75+
prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')")))
76+
.isEqualTo("20200416_160256_03078_6b4yt");
77+
78+
assertThat(
79+
extractQueryIdIfPresent(
80+
prepareMockRequestWithBody("CALL system.runtime.kill_query(Query_id => '20200416_160256_03078_6b4yt', Message => 'If he dies, he dies')")))
81+
.isEqualTo("20200416_160256_03078_6b4yt");
82+
83+
assertThat(
84+
extractQueryIdIfPresent(
85+
prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')")))
86+
.isEqualTo("20200416_160256_03078_6b4yt");
87+
88+
assertThat(
89+
extractQueryIdIfPresent(
90+
prepareMockRequestWithBody("CALL runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')")))
91+
.isEqualTo("20200416_160256_03078_6b4yt");
92+
93+
assertThat(
94+
extractQueryIdIfPresent(
95+
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'kill_query(''20200416_160256_03078_7n5uy'')')")))
96+
.isEqualTo("20200416_160256_03078_6b4yt");
97+
98+
assertThat(
99+
extractQueryIdIfPresent(
100+
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')")))
101+
.isEqualTo("20200416_160256_03078_6b4yt");
102+
103+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id=>'20200416_160256_03078_6b4yt')")))
104+
.isEqualTo("20200416_160256_03078_6b4yt");
105+
106+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt')")))
107+
.isEqualTo("20200416_160256_03078_6b4yt");
108+
109+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt')")))
110+
.isEqualTo("20200416_160256_03078_6b4yt");
111+
112+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("call Kill_Query('20200416_160256_03078_6b4yt')")))
113+
.isEqualTo("20200416_160256_03078_6b4yt");
114+
115+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE '%kill_query(''20200416_160256%' ")))
116+
.isNull();
117+
118+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE '%kill_query(''20200416_160256_03078_6b4yt' ")))
119+
.isNull();
120+
121+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE 'CALL kill_query(_20200416_160256_03078_6b4yt_)' ")))
122+
.isNull();
123+
}
124+
125+
private static HttpServletRequest prepareMockRequestWithBody(String query)
126+
throws IOException
127+
{
128+
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
129+
130+
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8));
131+
when(request.getInputStream()).thenReturn(new ServletInputStream()
132+
{
133+
@Override
134+
public boolean isFinished()
135+
{
136+
return byteArrayInputStream.available() > 0;
137+
}
138+
139+
@Override
140+
public boolean isReady()
141+
{
142+
return true;
143+
}
144+
145+
@Override
146+
public void setReadListener(ReadListener readListener)
147+
{}
148+
149+
public int read()
150+
throws IOException
151+
{
152+
return byteArrayInputStream.read();
153+
}
154+
});
155+
156+
when(request.getQueryString()).thenReturn("");
157+
158+
return request;
159+
}
160+
64161
@Test
65162
public void testUserFromRequest()
66163
throws IOException

0 commit comments

Comments
 (0)