Skip to content

Commit 7c8c140

Browse files
willmostlyebyhr
andcommitted
Extract query ID from all kill_query procedure variations
Co-authored-by: Yuya Ebihara <[email protected]>
1 parent 2e3c49d commit 7c8c140

File tree

9 files changed

+316
-83
lines changed

9 files changed

+316
-83
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515

1616
import com.google.common.io.CharStreams;
1717
import io.airlift.log.Logger;
18+
import io.trino.gateway.ha.router.TrinoQueryProperties;
1819
import jakarta.servlet.http.HttpServletRequest;
20+
import jakarta.ws.rs.HttpMethod;
1921

22+
import java.io.IOException;
2023
import java.io.InputStreamReader;
2124
import java.util.List;
2225
import java.util.Optional;
@@ -49,51 +52,50 @@ public final class ProxyUtils
4952
* capitalization.
5053
*/
5154
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
52-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
5355

5456
private ProxyUtils() {}
5557

56-
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
58+
public static Optional<String> extractQueryIdIfPresent(
59+
HttpServletRequest request,
60+
List<String> statementPaths,
61+
boolean requestAnalyserClientsUseV2Format,
62+
int requestAnalyserMaxBodySize)
5763
{
5864
String path = request.getRequestURI();
5965
String queryParams = request.getQueryString();
66+
if (!request.getMethod().equals(HttpMethod.POST)) {
67+
return extractQueryIdIfPresent(path, queryParams, statementPaths);
68+
}
69+
String queryText;
6070
try {
61-
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8));
62-
if (!isNullOrEmpty(queryText)
63-
&& queryText.toLowerCase(ENGLISH).contains("system.runtime.kill_query")) {
64-
// extract and return the queryId
65-
String[] parts = queryText.split(",");
66-
for (String part : parts) {
67-
if (part.contains("query_id")) {
68-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
69-
if (matcher.find()) {
70-
String queryQuoted = matcher.group();
71-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
72-
return queryQuoted.substring(1, queryQuoted.length() - 1);
73-
}
74-
}
75-
}
76-
}
77-
}
71+
queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8));
7872
}
79-
catch (Exception e) {
80-
log.error(e, "Error extracting query payload from request");
73+
catch (IOException e) {
74+
throw new RuntimeException("Error reading request body", e);
8175
}
82-
83-
return extractQueryIdIfPresent(path, queryParams, statementPaths);
76+
if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) {
77+
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
78+
return trinoQueryProperties.getQueryId();
79+
}
80+
return Optional.empty();
8481
}
8582

86-
public static String extractQueryIdIfPresent(String path, String queryParams, List<String> statementPaths)
83+
public static Optional<String> extractQueryIdIfPresent(String path, String queryParams, List<String> statementPaths)
8784
{
8885
if (path == null) {
89-
return null;
86+
return Optional.empty();
9087
}
91-
String queryId = null;
9288
log.debug("Trying to extract query id from path [%s] or queryString [%s]", path, queryParams);
9389
// matchingStatementPath should match paths such as /v1/statement/executing/query_id/nonce/sequence_number,
9490
// and if custom paths are supplied using the statementPaths configuration, paths such as
9591
// /custom/statement/path/executing/query_id/nonce/sequence_number
9692
Optional<String> matchingStatementPath = statementPaths.stream().filter(path::startsWith).findAny();
93+
if (!isNullOrEmpty(queryParams)) {
94+
Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams);
95+
if (matcher.matches()) {
96+
return Optional.of(matcher.group(1));
97+
}
98+
}
9799
if (matchingStatementPath.isPresent() || path.startsWith(V1_QUERY_PATH)) {
98100
path = path.replace(matchingStatementPath.orElse(V1_QUERY_PATH), "");
99101
String[] tokens = path.split("/");
@@ -102,27 +104,20 @@ public static String extractQueryIdIfPresent(String path, String queryParams, Li
102104
|| tokens[1].equals("scheduled")
103105
|| tokens[1].equals("executing")
104106
|| tokens[1].equals("partialCancel")) {
105-
queryId = tokens[2];
107+
return Optional.of(tokens[2]);
106108
}
107109
else {
108-
queryId = tokens[1];
110+
return Optional.of(tokens[1]);
109111
}
110112
}
111113
}
112114
else if (path.startsWith(TRINO_UI_PATH)) {
113115
Matcher matcher = QUERY_ID_PATH_PATTERN.matcher(path);
114116
if (matcher.matches()) {
115-
queryId = matcher.group(1);
116-
}
117-
}
118-
if (!isNullOrEmpty(queryParams)) {
119-
Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams);
120-
if (matcher.matches()) {
121-
queryId = matcher.group(1);
117+
return Optional.of(matcher.group(1));
122118
}
123119
}
124-
log.debug("Query id in URL [%s]", queryId);
125-
return queryId;
120+
return Optional.empty();
126121
}
127122

128123
public static String buildUriWithNewBackend(String backendHost, HttpServletRequest request)

gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public class RoutingTargetHandler
4747
private final RoutingGroupSelector routingGroupSelector;
4848
private final List<String> statementPaths;
4949
private final List<Pattern> extraWhitelistPaths;
50+
private final boolean requestAnalyserClientsUseV2Format;
51+
private final int requestAnalyserMaxBodySize;
5052
private final boolean cookiesEnabled;
5153

5254
@Inject
@@ -57,8 +59,10 @@ public RoutingTargetHandler(
5759
{
5860
this.routingManager = requireNonNull(routingManager);
5961
this.routingGroupSelector = requireNonNull(routingGroupSelector);
60-
this.statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
61-
this.extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
62+
statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
63+
extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
64+
requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format();
65+
requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize();
6266
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
6367
}
6468

@@ -96,9 +100,9 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)
96100

97101
private Optional<String> getPreviousBackend(HttpServletRequest request)
98102
{
99-
String queryId = extractQueryIdIfPresent(request, statementPaths);
100-
if (!isNullOrEmpty(queryId)) {
101-
return Optional.of(routingManager.findBackendForQueryId(queryId));
103+
Optional<String> queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
104+
if (queryId.isPresent()) {
105+
return queryId.map(routingManager::findBackendForQueryId);
102106
}
103107
if (cookiesEnabled && request.getCookies() != null) {
104108
List<GatewayCookie> cookies = Arrays.stream(request.getCookies())

gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
118118
TrinoQueryProperties trinoQueryProperties = null;
119119
TrinoRequestUser trinoRequestUser = null;
120120
if (requestAnalyzerConfig.isAnalyzeRequest()) {
121-
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig);
121+
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize());
122122
trinoRequestUser = trinoRequestUserProvider.getInstance(request);
123123
}
124124

gateway-ha/src/main/java/io/trino/gateway/ha/router/RuleReloadingRoutingGroupSelector.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ public String findRoutingGroup(HttpServletRequest request)
9696

9797
facts.put("request", request);
9898
if (requestAnalyzerConfig.isAnalyzeRequest()) {
99-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig);
99+
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(
100+
request,
101+
requestAnalyzerConfig.isClientsUseV2Format(),
102+
requestAnalyzerConfig.getMaxBodySize());
100103
TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request);
101104
facts.put("trinoQueryProperties", trinoQueryProperties);
102105
facts.put("trinoRequestUser", trinoRequestUser);

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,23 @@
1414
package io.trino.gateway.ha.router;
1515

1616
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonIgnore;
1718
import com.fasterxml.jackson.annotation.JsonProperty;
1819
import com.fasterxml.jackson.core.JsonGenerator;
1920
import com.fasterxml.jackson.databind.SerializerProvider;
2021
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
2122
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
23+
import com.google.common.collect.ImmutableList;
2224
import com.google.common.collect.ImmutableMap;
2325
import com.google.common.collect.ImmutableSet;
2426
import io.airlift.compress.zstd.ZstdDecompressor;
2527
import io.airlift.json.JsonCodec;
2628
import io.airlift.log.Logger;
27-
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
2829
import io.trino.sql.parser.ParsingException;
2930
import io.trino.sql.parser.SqlParser;
3031
import io.trino.sql.tree.AddColumn;
3132
import io.trino.sql.tree.Analyze;
33+
import io.trino.sql.tree.Call;
3234
import io.trino.sql.tree.CreateCatalog;
3335
import io.trino.sql.tree.CreateMaterializedView;
3436
import io.trino.sql.tree.CreateSchema;
@@ -40,6 +42,7 @@
4042
import io.trino.sql.tree.DropTable;
4143
import io.trino.sql.tree.Execute;
4244
import io.trino.sql.tree.ExecuteImmediate;
45+
import io.trino.sql.tree.Expression;
4346
import io.trino.sql.tree.Identifier;
4447
import io.trino.sql.tree.Node;
4548
import io.trino.sql.tree.NodeLocation;
@@ -57,6 +60,7 @@
5760
import io.trino.sql.tree.ShowSchemas;
5861
import io.trino.sql.tree.ShowTables;
5962
import io.trino.sql.tree.Statement;
63+
import io.trino.sql.tree.StringLiteral;
6064
import io.trino.sql.tree.Table;
6165
import io.trino.sql.tree.TableFunctionInvocation;
6266
import jakarta.servlet.http.HttpServletRequest;
@@ -73,6 +77,7 @@
7377
import java.util.Set;
7478
import java.util.stream.Collectors;
7579

80+
import static com.google.common.base.Preconditions.checkArgument;
7681
import static com.google.common.io.BaseEncoding.base64Url;
7782
import static io.airlift.json.JsonCodec.jsonCodec;
7883
import static java.lang.Math.toIntExact;
@@ -85,6 +90,7 @@ public class TrinoQueryProperties
8590
{
8691
private final Logger log = Logger.get(TrinoQueryProperties.class);
8792
private final boolean isClientsUseV2Format;
93+
private final int maxBodySize;
8894
private String body = "";
8995
private String queryType = "";
9096
private String resourceGroupQueryType = "";
@@ -96,6 +102,7 @@ public class TrinoQueryProperties
96102
private Set<String> catalogSchemas = ImmutableSet.of();
97103
private boolean isNewQuerySubmission;
98104
private Optional<String> errorMessage = Optional.empty();
105+
private Optional<String> queryId = Optional.empty();
99106

100107
public static final String TRINO_CATALOG_HEADER_NAME = "X-Trino-Catalog";
101108
public static final String TRINO_SCHEMA_HEADER_NAME = "X-Trino-Schema";
@@ -128,21 +135,24 @@ public TrinoQueryProperties(
128135
this.isNewQuerySubmission = isNewQuerySubmission;
129136
this.errorMessage = requireNonNullElse(errorMessage, Optional.empty());
130137
isClientsUseV2Format = false;
138+
maxBodySize = -1;
131139
}
132140

133-
public TrinoQueryProperties(HttpServletRequest request, RequestAnalyzerConfig config)
141+
public TrinoQueryProperties(HttpServletRequest request, boolean isClientsUseV2Format, int maxBodySize)
134142
{
135-
isClientsUseV2Format = config.isClientsUseV2Format();
143+
requireNonNull(request, "request is null");
144+
this.isClientsUseV2Format = isClientsUseV2Format;
145+
this.maxBodySize = maxBodySize;
136146

137147
defaultCatalog = Optional.ofNullable(request.getHeader(TRINO_CATALOG_HEADER_NAME));
138148
defaultSchema = Optional.ofNullable(request.getHeader(TRINO_SCHEMA_HEADER_NAME));
139149
if (request.getMethod().equals(HttpMethod.POST)) {
140150
isNewQuerySubmission = true;
141-
processRequestBody(request, config);
151+
processRequestBody(request);
142152
}
143153
}
144154

145-
private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfig config)
155+
private void processRequestBody(HttpServletRequest request)
146156
{
147157
try (BufferedReader reader = request.getReader()) {
148158
if (reader == null) {
@@ -153,11 +163,11 @@ private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfi
153163

154164
Map<String, String> preparedStatements = getPreparedStatements(request);
155165
SqlParser parser = new SqlParser();
156-
reader.mark(config.getMaxBodySize());
157-
char[] buffer = new char[config.getMaxBodySize()];
158-
int nChars = reader.read(buffer, 0, config.getMaxBodySize());
166+
reader.mark(maxBodySize);
167+
char[] buffer = new char[maxBodySize];
168+
int nChars = reader.read(buffer, 0, maxBodySize);
159169
reader.reset();
160-
if (nChars == config.getMaxBodySize()) {
170+
if (nChars == maxBodySize) {
161171
log.warn("Query length greater or equal to requestAnalyzerConfig.maxBodySize detected");
162172
return;
163173
//The body is truncated - there is a chance that it could still be syntactically valid SQL, for example if truncated on
@@ -199,7 +209,7 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
199209
ImmutableSet.Builder<String> schemaBuilder = ImmutableSet.builder();
200210
ImmutableSet.Builder<String> catalogSchemaBuilder = ImmutableSet.builder();
201211

202-
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
212+
visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
203213
tables = tableBuilder.build();
204214
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
205215
catalogs = catalogBuilder.build();
@@ -260,7 +270,7 @@ private String decodePreparedStatementFromHeader(String headerValue)
260270
return new String(preparedStatement, UTF_8);
261271
}
262272

263-
private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
273+
private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
264274
ImmutableSet.Builder<String> catalogBuilder,
265275
ImmutableSet.Builder<String> schemaBuilder,
266276
ImmutableSet.Builder<String> catalogSchemaBuilder)
@@ -269,6 +279,7 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
269279
switch (node) {
270280
case AddColumn s -> tableBuilder.add(qualifyName(s.getName()));
271281
case Analyze s -> tableBuilder.add(qualifyName(s.getTableName()));
282+
case Call call -> queryId = extractQueryIdFromCall(call);
272283
case CreateCatalog s -> catalogBuilder.add(s.getCatalogName().getValue());
273284
case CreateMaterializedView s -> tableBuilder.add(qualifyName(s.getName()));
274285
case CreateSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
@@ -342,10 +353,22 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
342353
}
343354

344355
for (Node child : node.getChildren()) {
345-
getNames(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
356+
visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
346357
}
347358
}
348359

360+
private Optional<String> extractQueryIdFromCall(Call call)
361+
throws RequestParsingException
362+
{
363+
QualifiedName callName = qualifyName(call.getName());
364+
if (callName.equals(QualifiedName.of("system", "runtime", "kill_query"))) {
365+
Expression argument = call.getArguments().getFirst().getValue();
366+
checkArgument(argument instanceof StringLiteral, "Unable to route kill_query procedures where the first argument is not a String Literal");
367+
return Optional.of(((StringLiteral) argument).getValue());
368+
}
369+
return Optional.empty();
370+
}
371+
349372
private void setCatalogAndSchemaNameFromSchemaQualifiedName(
350373
Optional<QualifiedName> schemaOptional,
351374
ImmutableSet.Builder<String> catalogBuilder,
@@ -381,15 +404,16 @@ private RequestParsingException unsetDefaultExceptionSupplier()
381404
return new RequestParsingException("Name not fully qualified");
382405
}
383406

384-
private QualifiedName qualifyName(QualifiedName table)
407+
private QualifiedName qualifyName(QualifiedName name)
385408
throws RequestParsingException
386409
{
387-
List<String> tableParts = table.getParts();
388-
return switch (tableParts.size()) {
389-
case 1 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst());
390-
case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst(), tableParts.get(1));
391-
case 3 -> QualifiedName.of(tableParts.getFirst(), tableParts.get(1), tableParts.get(2));
392-
default -> throw new RequestParsingException("Unexpected table name: " + table.getParts());
410+
List<String> nameParts = name.getParts();
411+
return switch (nameParts.size()) {
412+
case 1 ->
413+
QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), nameParts.getFirst());
414+
case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), nameParts.getFirst(), nameParts.get(1));
415+
case 3 -> QualifiedName.of(nameParts.getFirst(), nameParts.get(1), nameParts.get(2));
416+
default -> throw new RequestParsingException("Unexpected qualified name: " + name.getParts());
393417
};
394418
}
395419

@@ -520,6 +544,12 @@ public Optional<String> getErrorMessage()
520544
return errorMessage;
521545
}
522546

547+
@JsonIgnore
548+
public Optional<String> getQueryId()
549+
{
550+
return queryId;
551+
}
552+
523553
public static class AlternateStatementRequestBodyFormat
524554
{
525555
// Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session

0 commit comments

Comments
 (0)