Skip to content

Commit cf2e6ad

Browse files
willmostlyebyhr
andcommitted
Extract query ID from all kill_query procedure variations
Co-authored-by: Yuya Ebihara <[email protected]>
1 parent 2e3c49d commit cf2e6ad

File tree

9 files changed

+316
-82
lines changed

9 files changed

+316
-82
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515

1616
import com.google.common.io.CharStreams;
1717
import io.airlift.log.Logger;
18+
import io.trino.gateway.ha.router.TrinoQueryProperties;
1819
import jakarta.servlet.http.HttpServletRequest;
20+
import jakarta.ws.rs.HttpMethod;
1921

22+
import java.io.IOException;
2023
import java.io.InputStreamReader;
2124
import java.util.List;
2225
import java.util.Optional;
@@ -49,51 +52,50 @@ public final class ProxyUtils
4952
* capitalization.
5053
*/
5154
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
52-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
5355

5456
private ProxyUtils() {}
5557

56-
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
58+
public static Optional<String> extractQueryIdIfPresent(
59+
HttpServletRequest request,
60+
List<String> statementPaths,
61+
boolean requestAnalyserClientsUseV2Format,
62+
int requestAnalyserMaxBodySize)
5763
{
5864
String path = request.getRequestURI();
5965
String queryParams = request.getQueryString();
66+
if (!request.getMethod().equals(HttpMethod.POST)) {
67+
return extractQueryIdIfPresent(path, queryParams, statementPaths);
68+
}
69+
String queryText;
6070
try {
61-
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8));
62-
if (!isNullOrEmpty(queryText)
63-
&& queryText.toLowerCase(ENGLISH).contains("system.runtime.kill_query")) {
64-
// extract and return the queryId
65-
String[] parts = queryText.split(",");
66-
for (String part : parts) {
67-
if (part.contains("query_id")) {
68-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
69-
if (matcher.find()) {
70-
String queryQuoted = matcher.group();
71-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
72-
return queryQuoted.substring(1, queryQuoted.length() - 1);
73-
}
74-
}
75-
}
76-
}
77-
}
71+
queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8));
7872
}
79-
catch (Exception e) {
80-
log.error(e, "Error extracting query payload from request");
73+
catch (IOException e) {
74+
throw new RuntimeException("Error reading request body", e);
8175
}
82-
83-
return extractQueryIdIfPresent(path, queryParams, statementPaths);
76+
if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) {
77+
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
78+
return trinoQueryProperties.getQueryId();
79+
}
80+
return Optional.empty();
8481
}
8582

86-
public static String extractQueryIdIfPresent(String path, String queryParams, List<String> statementPaths)
83+
public static Optional<String> extractQueryIdIfPresent(String path, String queryParams, List<String> statementPaths)
8784
{
8885
if (path == null) {
89-
return null;
86+
return Optional.empty();
9087
}
91-
String queryId = null;
9288
log.debug("Trying to extract query id from path [%s] or queryString [%s]", path, queryParams);
9389
// matchingStatementPath should match paths such as /v1/statement/executing/query_id/nonce/sequence_number,
9490
// and if custom paths are supplied using the statementPaths configuration, paths such as
9591
// /custom/statement/path/executing/query_id/nonce/sequence_number
9692
Optional<String> matchingStatementPath = statementPaths.stream().filter(path::startsWith).findAny();
93+
if (!isNullOrEmpty(queryParams)) {
94+
Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams);
95+
if (matcher.matches()) {
96+
return Optional.of(matcher.group(1));
97+
}
98+
}
9799
if (matchingStatementPath.isPresent() || path.startsWith(V1_QUERY_PATH)) {
98100
path = path.replace(matchingStatementPath.orElse(V1_QUERY_PATH), "");
99101
String[] tokens = path.split("/");
@@ -102,27 +104,20 @@ public static String extractQueryIdIfPresent(String path, String queryParams, Li
102104
|| tokens[1].equals("scheduled")
103105
|| tokens[1].equals("executing")
104106
|| tokens[1].equals("partialCancel")) {
105-
queryId = tokens[2];
107+
return Optional.of(tokens[2]);
106108
}
107109
else {
108-
queryId = tokens[1];
110+
return Optional.of(tokens[1]);
109111
}
110112
}
111113
}
112114
else if (path.startsWith(TRINO_UI_PATH)) {
113115
Matcher matcher = QUERY_ID_PATH_PATTERN.matcher(path);
114116
if (matcher.matches()) {
115-
queryId = matcher.group(1);
116-
}
117-
}
118-
if (!isNullOrEmpty(queryParams)) {
119-
Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams);
120-
if (matcher.matches()) {
121-
queryId = matcher.group(1);
117+
return Optional.of(matcher.group(1));
122118
}
123119
}
124-
log.debug("Query id in URL [%s]", queryId);
125-
return queryId;
120+
return Optional.empty();
126121
}
127122

128123
public static String buildUriWithNewBackend(String backendHost, HttpServletRequest request)

gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public class RoutingTargetHandler
4747
private final RoutingGroupSelector routingGroupSelector;
4848
private final List<String> statementPaths;
4949
private final List<Pattern> extraWhitelistPaths;
50+
private final boolean requestAnalyserClientsUseV2Format;
51+
private final int requestAnalyserMaxBodySize;
5052
private final boolean cookiesEnabled;
5153

5254
@Inject
@@ -57,8 +59,10 @@ public RoutingTargetHandler(
5759
{
5860
this.routingManager = requireNonNull(routingManager);
5961
this.routingGroupSelector = requireNonNull(routingGroupSelector);
60-
this.statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
61-
this.extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
62+
statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
63+
extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
64+
requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format();
65+
requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize();
6266
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
6367
}
6468

@@ -96,9 +100,9 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)
96100

97101
private Optional<String> getPreviousBackend(HttpServletRequest request)
98102
{
99-
String queryId = extractQueryIdIfPresent(request, statementPaths);
100-
if (!isNullOrEmpty(queryId)) {
101-
return Optional.of(routingManager.findBackendForQueryId(queryId));
103+
Optional<String> queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
104+
if (queryId.isPresent()) {
105+
return queryId.map(routingManager::findBackendForQueryId);
102106
}
103107
if (cookiesEnabled && request.getCookies() != null) {
104108
List<GatewayCookie> cookies = Arrays.stream(request.getCookies())

gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request)
118118
TrinoQueryProperties trinoQueryProperties = null;
119119
TrinoRequestUser trinoRequestUser = null;
120120
if (requestAnalyzerConfig.isAnalyzeRequest()) {
121-
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig);
121+
trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize());
122122
trinoRequestUser = trinoRequestUserProvider.getInstance(request);
123123
}
124124

gateway-ha/src/main/java/io/trino/gateway/ha/router/RuleReloadingRoutingGroupSelector.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ public String findRoutingGroup(HttpServletRequest request)
9696

9797
facts.put("request", request);
9898
if (requestAnalyzerConfig.isAnalyzeRequest()) {
99-
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig);
99+
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(
100+
request,
101+
requestAnalyzerConfig.isClientsUseV2Format(),
102+
requestAnalyzerConfig.getMaxBodySize());
100103
TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request);
101104
facts.put("trinoQueryProperties", trinoQueryProperties);
102105
facts.put("trinoRequestUser", trinoRequestUser);

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
package io.trino.gateway.ha.router;
1515

1616
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonIgnore;
1718
import com.fasterxml.jackson.annotation.JsonProperty;
1819
import com.fasterxml.jackson.core.JsonGenerator;
1920
import com.fasterxml.jackson.databind.SerializerProvider;
2021
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
2122
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
23+
import com.google.common.collect.ImmutableList;
2224
import com.google.common.collect.ImmutableMap;
2325
import com.google.common.collect.ImmutableSet;
2426
import io.airlift.compress.zstd.ZstdDecompressor;
@@ -29,6 +31,7 @@
2931
import io.trino.sql.parser.SqlParser;
3032
import io.trino.sql.tree.AddColumn;
3133
import io.trino.sql.tree.Analyze;
34+
import io.trino.sql.tree.Call;
3235
import io.trino.sql.tree.CreateCatalog;
3336
import io.trino.sql.tree.CreateMaterializedView;
3437
import io.trino.sql.tree.CreateSchema;
@@ -40,6 +43,7 @@
4043
import io.trino.sql.tree.DropTable;
4144
import io.trino.sql.tree.Execute;
4245
import io.trino.sql.tree.ExecuteImmediate;
46+
import io.trino.sql.tree.Expression;
4347
import io.trino.sql.tree.Identifier;
4448
import io.trino.sql.tree.Node;
4549
import io.trino.sql.tree.NodeLocation;
@@ -57,6 +61,7 @@
5761
import io.trino.sql.tree.ShowSchemas;
5862
import io.trino.sql.tree.ShowTables;
5963
import io.trino.sql.tree.Statement;
64+
import io.trino.sql.tree.StringLiteral;
6065
import io.trino.sql.tree.Table;
6166
import io.trino.sql.tree.TableFunctionInvocation;
6267
import jakarta.servlet.http.HttpServletRequest;
@@ -73,6 +78,7 @@
7378
import java.util.Set;
7479
import java.util.stream.Collectors;
7580

81+
import static com.google.common.base.Preconditions.checkArgument;
7682
import static com.google.common.io.BaseEncoding.base64Url;
7783
import static io.airlift.json.JsonCodec.jsonCodec;
7884
import static java.lang.Math.toIntExact;
@@ -85,6 +91,7 @@ public class TrinoQueryProperties
8591
{
8692
private final Logger log = Logger.get(TrinoQueryProperties.class);
8793
private final boolean isClientsUseV2Format;
94+
private final int maxBodySize;
8895
private String body = "";
8996
private String queryType = "";
9097
private String resourceGroupQueryType = "";
@@ -96,6 +103,7 @@ public class TrinoQueryProperties
96103
private Set<String> catalogSchemas = ImmutableSet.of();
97104
private boolean isNewQuerySubmission;
98105
private Optional<String> errorMessage = Optional.empty();
106+
private Optional<String> queryId = Optional.empty();
99107

100108
public static final String TRINO_CATALOG_HEADER_NAME = "X-Trino-Catalog";
101109
public static final String TRINO_SCHEMA_HEADER_NAME = "X-Trino-Schema";
@@ -128,21 +136,24 @@ public TrinoQueryProperties(
128136
this.isNewQuerySubmission = isNewQuerySubmission;
129137
this.errorMessage = requireNonNullElse(errorMessage, Optional.empty());
130138
isClientsUseV2Format = false;
139+
maxBodySize = -1;
131140
}
132141

133-
public TrinoQueryProperties(HttpServletRequest request, RequestAnalyzerConfig config)
142+
public TrinoQueryProperties(HttpServletRequest request, boolean isClientsUseV2Format, int maxBodySize)
134143
{
135-
isClientsUseV2Format = config.isClientsUseV2Format();
144+
requireNonNull(request, "request is null");
145+
this.isClientsUseV2Format = isClientsUseV2Format;
146+
this.maxBodySize = maxBodySize;
136147

137148
defaultCatalog = Optional.ofNullable(request.getHeader(TRINO_CATALOG_HEADER_NAME));
138149
defaultSchema = Optional.ofNullable(request.getHeader(TRINO_SCHEMA_HEADER_NAME));
139150
if (request.getMethod().equals(HttpMethod.POST)) {
140151
isNewQuerySubmission = true;
141-
processRequestBody(request, config);
152+
processRequestBody(request);
142153
}
143154
}
144155

145-
private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfig config)
156+
private void processRequestBody(HttpServletRequest request)
146157
{
147158
try (BufferedReader reader = request.getReader()) {
148159
if (reader == null) {
@@ -153,11 +164,11 @@ private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfi
153164

154165
Map<String, String> preparedStatements = getPreparedStatements(request);
155166
SqlParser parser = new SqlParser();
156-
reader.mark(config.getMaxBodySize());
157-
char[] buffer = new char[config.getMaxBodySize()];
158-
int nChars = reader.read(buffer, 0, config.getMaxBodySize());
167+
reader.mark(maxBodySize);
168+
char[] buffer = new char[maxBodySize];
169+
int nChars = reader.read(buffer, 0, maxBodySize);
159170
reader.reset();
160-
if (nChars == config.getMaxBodySize()) {
171+
if (nChars == maxBodySize) {
161172
log.warn("Query length greater or equal to requestAnalyzerConfig.maxBodySize detected");
162173
return;
163174
//The body is truncated - there is a chance that it could still be syntactically valid SQL, for example if truncated on
@@ -199,7 +210,7 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
199210
ImmutableSet.Builder<String> schemaBuilder = ImmutableSet.builder();
200211
ImmutableSet.Builder<String> catalogSchemaBuilder = ImmutableSet.builder();
201212

202-
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
213+
visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
203214
tables = tableBuilder.build();
204215
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
205216
catalogs = catalogBuilder.build();
@@ -260,7 +271,7 @@ private String decodePreparedStatementFromHeader(String headerValue)
260271
return new String(preparedStatement, UTF_8);
261272
}
262273

263-
private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
274+
private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
264275
ImmutableSet.Builder<String> catalogBuilder,
265276
ImmutableSet.Builder<String> schemaBuilder,
266277
ImmutableSet.Builder<String> catalogSchemaBuilder)
@@ -269,6 +280,7 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
269280
switch (node) {
270281
case AddColumn s -> tableBuilder.add(qualifyName(s.getName()));
271282
case Analyze s -> tableBuilder.add(qualifyName(s.getTableName()));
283+
case Call call -> queryId = extractQueryIdFromCall(call);
272284
case CreateCatalog s -> catalogBuilder.add(s.getCatalogName().getValue());
273285
case CreateMaterializedView s -> tableBuilder.add(qualifyName(s.getName()));
274286
case CreateSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
@@ -342,10 +354,22 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
342354
}
343355

344356
for (Node child : node.getChildren()) {
345-
getNames(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
357+
visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
346358
}
347359
}
348360

361+
private Optional<String> extractQueryIdFromCall(Call call)
362+
throws RequestParsingException
363+
{
364+
QualifiedName callName = qualifyName(call.getName());
365+
if (callName.equals(QualifiedName.of("system", "runtime", "kill_query"))) {
366+
Expression argument = call.getArguments().getFirst().getValue();
367+
checkArgument(argument instanceof StringLiteral, "Unable to route kill_query procedures where the first argument is not a String Literal");
368+
return Optional.of(((StringLiteral) argument).getValue());
369+
}
370+
return Optional.empty();
371+
}
372+
349373
private void setCatalogAndSchemaNameFromSchemaQualifiedName(
350374
Optional<QualifiedName> schemaOptional,
351375
ImmutableSet.Builder<String> catalogBuilder,
@@ -381,15 +405,16 @@ private RequestParsingException unsetDefaultExceptionSupplier()
381405
return new RequestParsingException("Name not fully qualified");
382406
}
383407

384-
private QualifiedName qualifyName(QualifiedName table)
408+
private QualifiedName qualifyName(QualifiedName name)
385409
throws RequestParsingException
386410
{
387-
List<String> tableParts = table.getParts();
388-
return switch (tableParts.size()) {
389-
case 1 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst());
390-
case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst(), tableParts.get(1));
391-
case 3 -> QualifiedName.of(tableParts.getFirst(), tableParts.get(1), tableParts.get(2));
392-
default -> throw new RequestParsingException("Unexpected table name: " + table.getParts());
411+
List<String> nameParts = name.getParts();
412+
return switch (nameParts.size()) {
413+
case 1 ->
414+
QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), nameParts.getFirst());
415+
case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), nameParts.getFirst(), nameParts.get(1));
416+
case 3 -> QualifiedName.of(nameParts.getFirst(), nameParts.get(1), nameParts.get(2));
417+
default -> throw new RequestParsingException("Unexpected qualified name: " + name.getParts());
393418
};
394419
}
395420

@@ -520,6 +545,12 @@ public Optional<String> getErrorMessage()
520545
return errorMessage;
521546
}
522547

548+
@JsonIgnore
549+
public Optional<String> getQueryId()
550+
{
551+
return queryId;
552+
}
553+
523554
public static class AlternateStatementRequestBodyFormat
524555
{
525556
// Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session

0 commit comments

Comments
 (0)