Skip to content

Commit e9226b6

Browse files
committed
Extract query ID from all kill_query procedure variations
1 parent 23e8320 commit e9226b6

File tree

2 files changed

+119
-14
lines changed

2 files changed

+119
-14
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.TRINO_UI_PATH;
3030
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
3131
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.V1_QUERY_PATH;
32+
import static java.util.regex.Pattern.CASE_INSENSITIVE;
3233

3334
public final class ProxyUtils
3435
{
@@ -50,7 +51,14 @@ public final class ProxyUtils
5051
* capitalization.
5152
*/
5253
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
53-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
54+
55+
/**
56+
* This regular expression extracts the query id from a CALL system.runtime.kill_query procedure call. It extracts the first string between quotes
57+
* following the open parentheses after system.runtime.kill_query. The pattern handles the optional named arguments, and optional message argument
58+
* as well as arbitrary whitespace.
59+
*/
60+
private static final Pattern KILL_QUERY_PROCEDURE_PATTERN
61+
= Pattern.compile(".*?kill_query\\s*\\(\\s*(query_id\\s*=>)?\\s*'([^\\\\s]+?)'(,\\s*(message\\s*=>\\s*)?('.*'))?\\)", CASE_INSENSITIVE);
5462

5563
private ProxyUtils() {}
5664

@@ -96,19 +104,10 @@ public static String extractQueryIdIfPresent(HttpServletRequest request, List<St
96104
try {
97105
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
98106
if (!isNullOrEmpty(queryText)
99-
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
100-
// extract and return the queryId
101-
String[] parts = queryText.split(",");
102-
for (String part : parts) {
103-
if (part.contains("query_id")) {
104-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
105-
if (matcher.find()) {
106-
String queryQuoted = matcher.group();
107-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
108-
return queryQuoted.substring(1, queryQuoted.length() - 1);
109-
}
110-
}
111-
}
107+
&& queryText.toLowerCase().contains("kill_query")) {
108+
Matcher matcher = KILL_QUERY_PROCEDURE_PATTERN.matcher(queryText.toLowerCase());
109+
if (matcher.find()) {
110+
return matcher.group(2);
112111
}
113112
}
114113
}

gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,25 @@
1414
package io.trino.gateway.ha.handler;
1515

1616
import com.google.common.collect.ImmutableList;
17+
import jakarta.servlet.ReadListener;
18+
import jakarta.servlet.ServletInputStream;
1719
import jakarta.servlet.http.HttpServletRequest;
1820
import org.junit.jupiter.api.Test;
1921
import org.junit.jupiter.api.TestInstance;
2022
import org.junit.jupiter.api.TestInstance.Lifecycle;
2123
import org.mockito.Mockito;
2224

25+
import java.io.ByteArrayInputStream;
2326
import java.io.IOException;
2427
import java.util.List;
2528

2629
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
2730
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
2831
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.AUTHORIZATION;
2932
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
33+
import static java.nio.charset.StandardCharsets.UTF_8;
3034
import static org.assertj.core.api.Assertions.assertThat;
35+
import static org.mockito.Mockito.when;
3136

3237
@TestInstance(Lifecycle.PER_CLASS)
3338
public class TestQueryIdCachingProxyHandler
@@ -66,6 +71,107 @@ public void testExtractQueryIdFromUrl()
6671
.isNull();
6772
}
6873

74+
@Test
75+
void testQueryIdFromKill()
76+
throws IOException
77+
{
78+
assertThat(
79+
extractQueryIdIfPresent(
80+
prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')"),
81+
ImmutableList.of()))
82+
.isEqualTo("20200416_160256_03078_6b4yt");
83+
84+
assertThat(
85+
extractQueryIdIfPresent(
86+
prepareMockRequestWithBody("CALL system.runtime.kill_query(Query_id => '20200416_160256_03078_6b4yt', Message => 'If he dies, he dies')"),
87+
ImmutableList.of()))
88+
.isEqualTo("20200416_160256_03078_6b4yt");
89+
90+
assertThat(
91+
extractQueryIdIfPresent(
92+
prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')"),
93+
ImmutableList.of()))
94+
.isEqualTo("20200416_160256_03078_6b4yt");
95+
96+
assertThat(
97+
extractQueryIdIfPresent(
98+
prepareMockRequestWithBody("CALL runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')"), ImmutableList.of()))
99+
.isEqualTo("20200416_160256_03078_6b4yt");
100+
101+
assertThat(
102+
extractQueryIdIfPresent(
103+
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'kill_query(''20200416_160256_03078_7n5uy'')')"),
104+
ImmutableList.of()))
105+
.isEqualTo("20200416_160256_03078_6b4yt");
106+
107+
assertThat(
108+
extractQueryIdIfPresent(
109+
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')"), ImmutableList.of()))
110+
.isEqualTo("20200416_160256_03078_6b4yt");
111+
112+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id=>'20200416_160256_03078_6b4yt')"), ImmutableList.of()))
113+
.isEqualTo("20200416_160256_03078_6b4yt");
114+
115+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt')"), ImmutableList.of()))
116+
.isEqualTo("20200416_160256_03078_6b4yt");
117+
118+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL kill_query('20200416_160256_03078_6b4yt')"), ImmutableList.of()))
119+
.isEqualTo("20200416_160256_03078_6b4yt");
120+
121+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("call Kill_Query('20200416_160256_03078_6b4yt')"), ImmutableList.of()))
122+
.isEqualTo("20200416_160256_03078_6b4yt");
123+
124+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE '%kill_query(''20200416_160256%' "),
125+
ImmutableList.of()))
126+
.isNull();
127+
128+
assertThat(extractQueryIdIfPresent(
129+
prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE '%kill_query(''20200416_160256_03078_6b4yt' "),
130+
ImmutableList.of()))
131+
.isNull();
132+
133+
assertThat(extractQueryIdIfPresent(
134+
prepareMockRequestWithBody("select * from postgres.query_logs.queries where sql LIKE 'CALL kill_query(_20200416_160256_03078_6b4yt_)' "),
135+
ImmutableList.of()))
136+
.isNull();
137+
}
138+
139+
private static HttpServletRequest prepareMockRequestWithBody(String query)
140+
throws IOException
141+
{
142+
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
143+
144+
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8));
145+
when(request.getInputStream()).thenReturn(new ServletInputStream()
146+
{
147+
@Override
148+
public boolean isFinished()
149+
{
150+
return byteArrayInputStream.available() > 0;
151+
}
152+
153+
@Override
154+
public boolean isReady()
155+
{
156+
return true;
157+
}
158+
159+
@Override
160+
public void setReadListener(ReadListener readListener)
161+
{}
162+
163+
public int read()
164+
throws IOException
165+
{
166+
return byteArrayInputStream.read();
167+
}
168+
});
169+
170+
when(request.getQueryString()).thenReturn("");
171+
172+
return request;
173+
}
174+
69175
@Test
70176
public void testUserFromRequest()
71177
throws IOException

0 commit comments

Comments
 (0)