Skip to content

Commit 76f2c68

Browse files
committed
Extract query ID from all kill_query procedure variations
1 parent 23e8320 commit 76f2c68

File tree

6 files changed

+291
-37
lines changed

6 files changed

+291
-37
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/config/RequestAnalyzerConfig.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
public class RequestAnalyzerConfig
1919
{
20-
private Integer maxBodySize = 1_000_000;
20+
private int maxBodySize = 1_000_000;
2121

2222
private boolean isClientsUseV2Format;
2323
private String tokenUserField = "email";
@@ -26,13 +26,13 @@ public class RequestAnalyzerConfig
2626

2727
public RequestAnalyzerConfig() {}
2828

29-
public Integer getMaxBodySize()
29+
public int getMaxBodySize()
3030
{
3131
return maxBodySize;
3232
}
3333

3434
@Max(Integer.MAX_VALUE)
35-
public void setMaxBodySize(Integer maxBodySize)
35+
public void setMaxBodySize(int maxBodySize)
3636
{
3737
this.maxBodySize = maxBodySize;
3838
}

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
import com.google.common.base.Splitter;
1717
import com.google.common.io.CharStreams;
1818
import io.airlift.log.Logger;
19+
import io.trino.gateway.ha.router.TrinoQueryProperties;
20+
import io.trino.sql.tree.StringLiteral;
1921
import jakarta.servlet.http.HttpServletRequest;
2022

23+
import java.io.IOException;
2124
import java.io.InputStreamReader;
2225
import java.util.Base64;
2326
import java.util.List;
2427
import java.util.Optional;
2528
import java.util.regex.Matcher;
2629
import java.util.regex.Pattern;
2730

31+
import static com.google.common.base.Preconditions.checkArgument;
2832
import static com.google.common.base.Strings.isNullOrEmpty;
2933
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.TRINO_UI_PATH;
3034
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
@@ -50,7 +54,6 @@ public final class ProxyUtils
5054
* capitalization.
5155
*/
5256
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
53-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
5457

5558
private ProxyUtils() {}
5659

@@ -89,31 +92,28 @@ public static String getQueryUser(String userHeader, String authorization)
8992
return parts.get(0);
9093
}
9194

92-
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
95+
public static String extractQueryIdIfPresent(
96+
HttpServletRequest request,
97+
List<String> statementPaths,
98+
boolean requestAnalyserClientsUseV2Format,
99+
int requestAnalyserMaxBodySize)
93100
{
94101
String path = request.getRequestURI();
95102
String queryParams = request.getQueryString();
103+
String queryText = null;
96104
try {
97-
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
98-
if (!isNullOrEmpty(queryText)
99-
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
100-
// extract and return the queryId
101-
String[] parts = queryText.split(",");
102-
for (String part : parts) {
103-
if (part.contains("query_id")) {
104-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
105-
if (matcher.find()) {
106-
String queryQuoted = matcher.group();
107-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
108-
return queryQuoted.substring(1, queryQuoted.length() - 1);
109-
}
110-
}
111-
}
112-
}
113-
}
105+
queryText = CharStreams.toString(new InputStreamReader(request.getInputStream()));
114106
}
115-
catch (Exception e) {
116-
log.error(e, "Error extracting query payload from request");
107+
catch (IOException e) {
108+
log.error(e, "Error reading request body");
109+
}
110+
if (!isNullOrEmpty(queryText) && queryText.toLowerCase().contains("kill_query")) {
111+
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
112+
if (trinoQueryProperties.getProcedure().filter(p -> p.getName().getSuffix().equals("kill_query")).isPresent()) {
113+
checkArgument(trinoQueryProperties.getProcedureArguments().getFirst().getValue() instanceof StringLiteral,
114+
"Unable to route kill_query procedures where the first argument is not a String Literal");
115+
return ((StringLiteral) trinoQueryProperties.getProcedureArguments().getFirst().getValue()).getValue();
116+
}
117117
}
118118

119119
return extractQueryIdIfPresent(path, queryParams, statementPaths);

gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import io.airlift.log.Logger;
1717
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
18+
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
1819
import io.trino.gateway.ha.router.GatewayCookie;
1920
import io.trino.gateway.ha.router.RoutingGroupSelector;
2021
import io.trino.gateway.ha.router.RoutingManager;
@@ -45,18 +46,24 @@ public class RoutingTargetHandler
4546
private final RoutingGroupSelector routingGroupSelector;
4647
private final List<String> statementPaths;
4748
private final List<Pattern> extraWhitelistPaths;
49+
private final boolean requestAnalyserClientsUseV2Format;
50+
private final int requestAnalyserMaxBodySize;
4851
private final boolean cookiesEnabled;
4952

5053
public RoutingTargetHandler(
5154
RoutingManager routingManager,
5255
RoutingGroupSelector routingGroupSelector,
5356
List<String> statementPaths,
54-
List<String> extraWhitelistPaths)
57+
List<String> extraWhitelistPaths,
58+
RequestAnalyzerConfig requestAnalyzerConfig)
5559
{
5660
this.routingManager = requireNonNull(routingManager);
5761
this.routingGroupSelector = requireNonNull(routingGroupSelector);
5862
this.statementPaths = requireNonNull(statementPaths);
5963
this.extraWhitelistPaths = extraWhitelistPaths.stream().map(Pattern::compile).collect(toImmutableList());
64+
requestAnalyserClientsUseV2Format = requestAnalyzerConfig.isClientsUseV2Format();
65+
requestAnalyserMaxBodySize = requestAnalyzerConfig.getMaxBodySize();
66+
6067
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
6168
}
6269

@@ -94,7 +101,7 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)
94101

95102
private Optional<String> getPreviousBackend(HttpServletRequest request)
96103
{
97-
String queryId = extractQueryIdIfPresent(request, statementPaths);
104+
String queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
98105
if (!isNullOrEmpty(queryId)) {
99106
return Optional.of(routingManager.findBackendForQueryId(queryId));
100107
}

gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ public RoutingTargetHandler getRoutingTargetHandler(
211211
routingManager,
212212
routingGroupSelector,
213213
configuration.getStatementPaths(),
214-
configuration.getExtraWhitelistPaths());
214+
configuration.getExtraWhitelistPaths(),
215+
configuration.getRequestAnalyzerConfig());
215216
}
216217
}

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.fasterxml.jackson.databind.SerializerProvider;
2020
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
2121
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
22+
import com.google.common.collect.ImmutableList;
2223
import com.google.common.collect.ImmutableMap;
2324
import com.google.common.collect.ImmutableSet;
2425
import io.airlift.compress.zstd.ZstdDecompressor;
@@ -29,6 +30,8 @@
2930
import io.trino.sql.parser.SqlParser;
3031
import io.trino.sql.tree.AddColumn;
3132
import io.trino.sql.tree.Analyze;
33+
import io.trino.sql.tree.Call;
34+
import io.trino.sql.tree.CallArgument;
3235
import io.trino.sql.tree.CreateCatalog;
3336
import io.trino.sql.tree.CreateMaterializedView;
3437
import io.trino.sql.tree.CreateSchema;
@@ -85,6 +88,7 @@ public class TrinoQueryProperties
8588
{
8689
private final Logger log = Logger.get(TrinoQueryProperties.class);
8790
private final boolean isClientsUseV2Format;
91+
private final int maxBodySize;
8892
private String body = "";
8993
private String queryType = "";
9094
private String resourceGroupQueryType = "";
@@ -94,6 +98,8 @@ public class TrinoQueryProperties
9498
private Set<String> catalogs = ImmutableSet.of();
9599
private Set<String> schemas = ImmutableSet.of();
96100
private Set<String> catalogSchemas = ImmutableSet.of();
101+
private Optional<Call> procedure = Optional.empty();
102+
private List<CallArgument> procedureArguments = ImmutableList.of();
97103
private boolean isNewQuerySubmission;
98104
private boolean isQueryParsingSuccessful;
99105

@@ -127,21 +133,28 @@ public TrinoQueryProperties(
127133
this.isNewQuerySubmission = isNewQuerySubmission;
128134
this.isQueryParsingSuccessful = isQueryParsingSuccessful;
129135
isClientsUseV2Format = false;
136+
maxBodySize = -1;
130137
}
131138

132139
public TrinoQueryProperties(HttpServletRequest request, RequestAnalyzerConfig config)
133140
{
134-
isClientsUseV2Format = config.isClientsUseV2Format();
141+
this(request, config.isClientsUseV2Format(), config.getMaxBodySize());
142+
}
143+
144+
public TrinoQueryProperties(HttpServletRequest request, boolean isClientsUseV2Format, int maxBodySize)
145+
{
146+
this.isClientsUseV2Format = isClientsUseV2Format;
147+
this.maxBodySize = maxBodySize;
135148

136149
defaultCatalog = Optional.ofNullable(request.getHeader(TRINO_CATALOG_HEADER_NAME));
137150
defaultSchema = Optional.ofNullable(request.getHeader(TRINO_SCHEMA_HEADER_NAME));
138151
if (request.getMethod().equals(HttpMethod.POST)) {
139152
isNewQuerySubmission = true;
140-
processRequestBody(request, config);
153+
processRequestBody(request);
141154
}
142155
}
143156

144-
private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfig config)
157+
private void processRequestBody(HttpServletRequest request)
145158
{
146159
try (BufferedReader reader = request.getReader()) {
147160
if (reader == null) {
@@ -152,11 +165,11 @@ private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfi
152165

153166
Map<String, String> preparedStatements = getPreparedStatements(request);
154167
SqlParser parser = new SqlParser();
155-
reader.mark(config.getMaxBodySize());
156-
char[] buffer = new char[config.getMaxBodySize()];
157-
int nChars = reader.read(buffer, 0, config.getMaxBodySize());
168+
reader.mark(maxBodySize);
169+
char[] buffer = new char[maxBodySize];
170+
int nChars = reader.read(buffer, 0, maxBodySize);
158171
reader.reset();
159-
if (nChars == config.getMaxBodySize()) {
172+
if (nChars == maxBodySize) {
160173
log.warn("Query length greater or equal to requestAnalyzerConfig.maxBodySize detected");
161174
return;
162175
//The body is truncated - there is a chance that it could still be syntactically valid SQL, for example if truncated on
@@ -194,7 +207,7 @@ private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfi
194207
ImmutableSet.Builder<String> schemaBuilder = ImmutableSet.builder();
195208
ImmutableSet.Builder<String> catalogSchemaBuilder = ImmutableSet.builder();
196209

197-
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
210+
visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
198211
tables = tableBuilder.build();
199212
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
200213
catalogs = catalogBuilder.build();
@@ -256,7 +269,7 @@ private String decodePreparedStatementFromHeader(String headerValue)
256269
return new String(preparedStatement, UTF_8);
257270
}
258271

259-
private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
272+
private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
260273
ImmutableSet.Builder<String> catalogBuilder,
261274
ImmutableSet.Builder<String> schemaBuilder,
262275
ImmutableSet.Builder<String> catalogSchemaBuilder)
@@ -265,6 +278,11 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
265278
switch (node) {
266279
case AddColumn s -> tableBuilder.add(qualifyName(s.getName()));
267280
case Analyze s -> tableBuilder.add(qualifyName(s.getTableName()));
281+
case Call call -> {
282+
procedure = Optional.of(call);
283+
procedureArguments = call.getArguments();
284+
return;
285+
}
268286
case CreateCatalog s -> catalogBuilder.add(s.getCatalogName().getValue());
269287
case CreateMaterializedView s -> tableBuilder.add(qualifyName(s.getName()));
270288
case CreateSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
@@ -338,7 +356,7 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
338356
}
339357

340358
for (Node child : node.getChildren()) {
341-
getNames(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
359+
visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
342360
}
343361
}
344362

@@ -513,6 +531,16 @@ public boolean isQueryParsingSuccessful()
513531
return isQueryParsingSuccessful;
514532
}
515533

534+
public Optional<Call> getProcedure()
535+
{
536+
return procedure;
537+
}
538+
539+
public List<CallArgument> getProcedureArguments()
540+
{
541+
return procedureArguments;
542+
}
543+
516544
public static class AlternateStatementRequestBodyFormat
517545
{
518546
// Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session

0 commit comments

Comments
 (0)