Skip to content

Commit 83760bd

Browse files
committed
Test extraction of query ID from kill_query procedure
1 parent 5afabe0 commit 83760bd

File tree

2 files changed

+72
-13
lines changed

2 files changed

+72
-13
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ public final class ProxyUtils
5050
* capitalization.
5151
*/
5252
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
53-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
53+
private static final Pattern KILL_QUERY_PROCEDURE_PATTERN
54+
= Pattern.compile(".*system\\.runtime\\.kill_query\\s*\\(\\s*(query_id\\s*=>)?\\s*'([^\\\\s]+)'(,\\s*(message\\s*=>\\s*)?('.*'))?\\)");
55+
/**
56+
* This regular expression extracts the query id from a CALL system.runtime.kill_query procedure call. It extracts the first string between quotes
57+
* following the open parentheses after system.runtime.kill_query. The pattern handles the optional named arguments, and optional message argument
58+
* as well as arbitrary whitespace.
59+
*/
5460

5561
private ProxyUtils() {}
5662

@@ -97,18 +103,9 @@ public static String extractQueryIdIfPresent(HttpServletRequest request)
97103
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
98104
if (!isNullOrEmpty(queryText)
99105
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
100-
// extract and return the queryId
101-
String[] parts = queryText.split(",");
102-
for (String part : parts) {
103-
if (part.contains("query_id")) {
104-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
105-
if (matcher.find()) {
106-
String queryQuoted = matcher.group();
107-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
108-
return queryQuoted.substring(1, queryQuoted.length() - 1);
109-
}
110-
}
111-
}
106+
Matcher matcher = KILL_QUERY_PROCEDURE_PATTERN.matcher(queryText.toLowerCase());
107+
if (matcher.find()) {
108+
return matcher.group(2);
112109
}
113110
}
114111
}

gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,24 @@
1313
*/
1414
package io.trino.gateway.ha.handler;
1515

16+
import jakarta.servlet.ReadListener;
17+
import jakarta.servlet.ServletInputStream;
1618
import jakarta.servlet.http.HttpServletRequest;
1719
import org.junit.jupiter.api.Test;
1820
import org.junit.jupiter.api.TestInstance;
1921
import org.junit.jupiter.api.TestInstance.Lifecycle;
2022
import org.mockito.Mockito;
2123

24+
import java.io.ByteArrayInputStream;
2225
import java.io.IOException;
2326

2427
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
2528
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
2629
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.AUTHORIZATION;
2730
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
31+
import static java.nio.charset.StandardCharsets.UTF_8;
2832
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.mockito.Mockito.when;
2934

3035
@TestInstance(Lifecycle.PER_CLASS)
3136
public class TestQueryIdCachingProxyHandler
@@ -61,6 +66,63 @@ public void testExtractQueryIdFromUrl()
6166
.isNull();
6267
}
6368

69+
@Test
70+
void testQueryIdFromKill()
71+
throws IOException
72+
{
73+
assertThat(
74+
extractQueryIdIfPresent(
75+
prepareMockRequestWithBody("CALL system.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')")))
76+
.isEqualTo("20200416_160256_03078_6b4yt");
77+
78+
assertThat(
79+
extractQueryIdIfPresent(
80+
prepareMockRequestWithBody("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')")))
81+
.isEqualTo("20200416_160256_03078_6b4yt");
82+
83+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query( query_id=>'20200416_160256_03078_6b4yt')")))
84+
.isEqualTo("20200416_160256_03078_6b4yt");
85+
86+
assertThat(extractQueryIdIfPresent(prepareMockRequestWithBody("CALL system.runtime.kill_query( '20200416_160256_03078_6b4yt')")))
87+
.isEqualTo("20200416_160256_03078_6b4yt");
88+
}
89+
90+
HttpServletRequest prepareMockRequestWithBody(String query)
91+
throws IOException
92+
{
93+
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
94+
95+
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8));
96+
when(request.getInputStream()).thenReturn(new ServletInputStream()
97+
{
98+
@Override
99+
public boolean isFinished()
100+
{
101+
return byteArrayInputStream.available() > 0;
102+
}
103+
104+
@Override
105+
public boolean isReady()
106+
{
107+
return true;
108+
}
109+
110+
@Override
111+
public void setReadListener(ReadListener readListener)
112+
{}
113+
114+
public int read()
115+
throws IOException
116+
{
117+
return byteArrayInputStream.read();
118+
}
119+
});
120+
121+
when(request.getQueryString()).thenReturn("");
122+
123+
return request;
124+
}
125+
64126
@Test
65127
public void testUserFromRequest()
66128
throws IOException

0 commit comments

Comments
 (0)