Skip to content

Commit 6c24589

Browse files
committed
Extract query ID from all kill_query procedure variations
1 parent a45d249 commit 6c24589

File tree

8 files changed

+519
-76
lines changed

8 files changed

+519
-76
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,22 @@
1616
import com.google.common.base.Splitter;
1717
import com.google.common.io.CharStreams;
1818
import io.airlift.log.Logger;
19+
import io.trino.gateway.ha.router.SerializableExpression;
20+
import io.trino.gateway.ha.router.TrinoQueryProperties;
21+
import io.trino.sql.tree.Expression;
22+
import io.trino.sql.tree.StringLiteral;
1923
import jakarta.servlet.http.HttpServletRequest;
24+
import jakarta.ws.rs.HttpMethod;
2025

26+
import java.io.IOException;
2127
import java.io.InputStreamReader;
2228
import java.util.Base64;
2329
import java.util.List;
2430
import java.util.Optional;
2531
import java.util.regex.Matcher;
2632
import java.util.regex.Pattern;
2733

34+
import static com.google.common.base.Preconditions.checkArgument;
2835
import static com.google.common.base.Strings.isNullOrEmpty;
2936
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.TRINO_UI_PATH;
3037
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
@@ -52,7 +59,6 @@ public final class ProxyUtils
5259
* capitalization.
5360
*/
5461
private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
55-
private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");
5662

5763
private ProxyUtils() {}
5864

@@ -91,47 +97,51 @@ public static String getQueryUser(String userHeader, String authorization)
9197
return parts.get(0);
9298
}
9399

94-
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
100+
public static Optional<String> extractQueryIdIfPresent(
101+
HttpServletRequest request,
102+
List<String> statementPaths,
103+
boolean requestAnalyserClientsUseV2Format,
104+
int requestAnalyserMaxBodySize)
95105
{
96106
String path = request.getRequestURI();
97107
String queryParams = request.getQueryString();
108+
if (!request.getMethod().equals(HttpMethod.POST)) {
109+
return extractQueryIdIfPresent(path, queryParams, statementPaths);
110+
}
111+
String queryText;
98112
try {
99-
String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8));
100-
if (!isNullOrEmpty(queryText)
101-
&& queryText.toLowerCase(ENGLISH).contains("system.runtime.kill_query")) {
102-
// extract and return the queryId
103-
String[] parts = queryText.split(",");
104-
for (String part : parts) {
105-
if (part.contains("query_id")) {
106-
Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
107-
if (matcher.find()) {
108-
String queryQuoted = matcher.group();
109-
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
110-
return queryQuoted.substring(1, queryQuoted.length() - 1);
111-
}
112-
}
113-
}
114-
}
115-
}
113+
queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8));
116114
}
117-
catch (Exception e) {
118-
log.error(e, "Error extracting query payload from request");
115+
catch (IOException e) {
116+
throw new RuntimeException("Error reading request body", e);
119117
}
120-
121-
return extractQueryIdIfPresent(path, queryParams, statementPaths);
118+
if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) {
119+
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
120+
if (trinoQueryProperties.procedureNameEquals("system.runtime.kill_query")) {
121+
SerializableExpression argument = trinoQueryProperties.getProcedureArguments().getFirst().getValue();
122+
checkArgument(argument.getOriginalClass().equals(StringLiteral.class), "Unable to route kill_query procedures where the first argument is not a String Literal");
123+
return Optional.of(argument.getValue());
124+
}
125+
}
126+
return Optional.empty();
122127
}
123128

124-
public static String extractQueryIdIfPresent(String path, String queryParams, List<String> statementPaths)
129+
public static Optional<String> extractQueryIdIfPresent(String path, String queryParams, List<String> statementPaths)
125130
{
126131
if (path == null) {
127-
return null;
132+
return Optional.empty();
128133
}
129-
String queryId = null;
130134
log.debug("Trying to extract query id from path [%s] or queryString [%s]", path, queryParams);
131135
// matchingStatementPath should match paths such as /v1/statement/executing/query_id/nonce/sequence_number,
132136
// and if custom paths are supplied using the statementPaths configuration, paths such as
133137
// /custom/statement/path/executing/query_id/nonce/sequence_number
134138
Optional<String> matchingStatementPath = statementPaths.stream().filter(path::startsWith).findAny();
139+
if (!isNullOrEmpty(queryParams)) {
140+
Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams);
141+
if (matcher.matches()) {
142+
return Optional.of(matcher.group(1));
143+
}
144+
}
135145
if (matchingStatementPath.isPresent() || path.startsWith(V1_QUERY_PATH)) {
136146
path = path.replace(matchingStatementPath.orElse(V1_QUERY_PATH), "");
137147
String[] tokens = path.split("/");
@@ -140,27 +150,20 @@ public static String extractQueryIdIfPresent(String path, String queryParams, Li
140150
|| tokens[1].equals("scheduled")
141151
|| tokens[1].equals("executing")
142152
|| tokens[1].equals("partialCancel")) {
143-
queryId = tokens[2];
153+
return Optional.of(tokens[2]);
144154
}
145155
else {
146-
queryId = tokens[1];
156+
return Optional.of(tokens[1]);
147157
}
148158
}
149159
}
150160
else if (path.startsWith(TRINO_UI_PATH)) {
151161
Matcher matcher = QUERY_ID_PATH_PATTERN.matcher(path);
152162
if (matcher.matches()) {
153-
queryId = matcher.group(1);
154-
}
155-
}
156-
if (!isNullOrEmpty(queryParams)) {
157-
Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams);
158-
if (matcher.matches()) {
159-
queryId = matcher.group(1);
163+
return Optional.of(matcher.group(1));
160164
}
161165
}
162-
log.debug("Query id in URL [%s]", queryId);
163-
return queryId;
166+
return Optional.empty();
164167
}
165168

166169
public static String buildUriWithNewBackend(String backendHost, HttpServletRequest request)

gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public class RoutingTargetHandler
4747
private final RoutingGroupSelector routingGroupSelector;
4848
private final List<String> statementPaths;
4949
private final List<Pattern> extraWhitelistPaths;
50+
private final boolean requestAnalyserClientsUseV2Format;
51+
private final int requestAnalyserMaxBodySize;
5052
private final boolean cookiesEnabled;
5153

5254
@Inject
@@ -57,8 +59,10 @@ public RoutingTargetHandler(
5759
{
5860
this.routingManager = requireNonNull(routingManager);
5961
this.routingGroupSelector = requireNonNull(routingGroupSelector);
60-
this.statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
61-
this.extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
62+
statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths());
63+
extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList());
64+
requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format();
65+
requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize();
6266
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
6367
}
6468

@@ -96,9 +100,9 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)
96100

97101
private Optional<String> getPreviousBackend(HttpServletRequest request)
98102
{
99-
String queryId = extractQueryIdIfPresent(request, statementPaths);
100-
if (!isNullOrEmpty(queryId)) {
101-
return Optional.of(routingManager.findBackendForQueryId(queryId));
103+
Optional<String> queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize);
104+
if (queryId.isPresent()) {
105+
return queryId.map(routingManager::findBackendForQueryId);
102106
}
103107
if (cookiesEnabled && request.getCookies() != null) {
104108
List<GatewayCookie> cookies = Arrays.stream(request.getCookies())
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.ha.router;
15+
16+
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonProperty;
18+
import io.trino.sql.tree.CallArgument;
19+
import io.trino.sql.tree.Expression;
20+
import io.trino.sql.tree.Identifier;
21+
22+
import java.util.Optional;
23+
24+
import static java.util.Objects.requireNonNull;
25+
26+
// CallArgument is final, so this class just replicates it. Location is not preserved, since it would require
27+
// additional complexity and is not meaningful in this context
28+
public class SerializableCallArgument
29+
{
30+
private final Optional<Identifier> name;
31+
private final SerializableExpression value;
32+
33+
@JsonCreator
34+
public SerializableCallArgument(
35+
@JsonProperty("name") Optional<String> name,
36+
@JsonProperty("value") SerializableExpression value)
37+
{
38+
this.name = requireNonNull(name, "name is null").map(Identifier::new);
39+
this.value = requireNonNull(value, "value is null");
40+
}
41+
42+
public SerializableCallArgument(CallArgument callArgument)
43+
{
44+
this.name = callArgument.getName();
45+
this.value = new SerializableExpression(callArgument.getValue());
46+
}
47+
48+
@JsonProperty
49+
public SerializableExpression getValue()
50+
{
51+
return value;
52+
}
53+
54+
@JsonProperty
55+
public Optional<String> getName()
56+
{
57+
return name.map(Identifier::getValue);
58+
}
59+
}
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.ha.router;
15+
16+
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonProperty;
18+
import com.google.common.collect.ImmutableList;
19+
import io.trino.sql.tree.Expression;
20+
import io.trino.sql.tree.Node;
21+
import io.trino.sql.tree.NodeLocation;
22+
import io.trino.sql.tree.StringLiteral;
23+
import net.minidev.json.annotate.JsonIgnore;
24+
25+
import java.util.List;
26+
import java.util.Objects;
27+
import java.util.Optional;
28+
29+
public class SerializableExpression
30+
extends Expression
31+
{
32+
private final String value;
33+
private final Class originalClass;
34+
35+
@JsonCreator
36+
public SerializableExpression(@JsonProperty("value") String value)
37+
{
38+
this(Optional.empty(), value);
39+
}
40+
41+
protected SerializableExpression(Optional<NodeLocation> location, String value)
42+
{
43+
super(location);
44+
this.value = value;
45+
this.originalClass = SerializableExpression.class;
46+
}
47+
48+
public SerializableExpression(Expression expression)
49+
{
50+
super(expression.getLocation());
51+
originalClass = expression.getClass();
52+
if (expression instanceof StringLiteral) {
53+
// special handling for this common case so that quotes do not need to be stripped
54+
value = ((StringLiteral) expression).getValue();
55+
} else {
56+
value = expression.toString();
57+
}
58+
}
59+
60+
@JsonProperty
61+
public String getValue()
62+
{
63+
return value;
64+
}
65+
66+
@JsonProperty
67+
public Class getOriginalClass()
68+
{
69+
return originalClass;
70+
}
71+
72+
@JsonIgnore
73+
@Override
74+
public List<? extends Node> getChildren()
75+
{
76+
return ImmutableList.of();
77+
}
78+
79+
@Override
80+
public int hashCode()
81+
{
82+
return value.hashCode();
83+
}
84+
85+
@Override
86+
public boolean equals(Object obj)
87+
{
88+
if (this == obj) {
89+
return true;
90+
}
91+
else if (obj != null && this.getClass() == obj.getClass()) {
92+
SerializableExpression that = (SerializableExpression) obj;
93+
return Objects.equals(this.value, that.value);
94+
}
95+
else {
96+
return false;
97+
}
98+
}
99+
}

0 commit comments

Comments
 (0)