From 2e3c49d60b3518660e69a0c41256892c60a95d35 Mon Sep 17 00:00:00 2001 From: Will Morrison Date: Wed, 16 Oct 2024 22:49:58 -0400 Subject: [PATCH 1/3] Use List for tables --- .../ha/router/TrinoQueryProperties.java | 6 +-- .../ha/router/TestTrinoQueryProperties.java | 37 ++++++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java index 272c21067..698eac9c7 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java @@ -66,7 +66,6 @@ import java.io.IOException; import java.net.URLDecoder; import java.util.ArrayList; -import java.util.Arrays; import java.util.Enumeration; import java.util.List; import java.util.Map; @@ -107,7 +106,7 @@ public TrinoQueryProperties( @JsonProperty("body") String body, @JsonProperty("queryType") String queryType, @JsonProperty("resourceGroupQueryType") String resourceGroupQueryType, - @JsonProperty("tables") String[] tables, + @JsonProperty("tables") List tables, @JsonProperty("defaultCatalog") Optional defaultCatalog, @JsonProperty("defaultSchema") Optional defaultSchema, @JsonProperty("catalogs") Set catalogs, @@ -119,7 +118,8 @@ public TrinoQueryProperties( this.body = requireNonNullElse(body, ""); this.queryType = requireNonNullElse(queryType, ""); this.resourceGroupQueryType = resourceGroupQueryType; - this.tables = Arrays.stream(requireNonNullElse(tables, new String[] {})).map(this::parseIdentifierStringToQualifiedName).collect(Collectors.toSet()); + List defaultTables = ImmutableList.of(); + this.tables = requireNonNullElse(tables, defaultTables).stream().map(this::parseIdentifierStringToQualifiedName).collect(Collectors.toSet()); this.defaultCatalog = requireNonNullElse(defaultCatalog, Optional.empty()); this.defaultSchema = requireNonNullElse(defaultSchema, Optional.empty()); this.catalogs = requireNonNullElse(catalogs, ImmutableSet.of()); diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java index ab4f0f9e9..018dfd99f 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java @@ -13,6 +13,7 @@ */ package io.trino.gateway.ha.router; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.airlift.json.JsonCodec; import org.junit.jupiter.api.Test; @@ -31,7 +32,7 @@ void testJsonCreator() "SELECT c1 from c.s.t1", "SELECT", "SELECT", - new String[] {"c.s.t1"}, + ImmutableList.of("c.s.t1"), Optional.empty(), Optional.empty(), ImmutableSet.of("c"), @@ -56,4 +57,38 @@ void testJsonCreator() assertThat(deserializedTrinoQueryProperties.isQueryParsingSuccessful()).isEqualTo(trinoQueryProperties.isQueryParsingSuccessful()); assertThat(deserializedTrinoQueryProperties.getErrorMessage()).isEqualTo(trinoQueryProperties.getErrorMessage()); } + + @Test + void testJsonCreatorWithEmptyProperties() + { + JsonCodec codec = JsonCodec.jsonCodec(TrinoQueryProperties.class); + TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( + "SELECT c1 from c.s.t1", + "SELECT", + "SELECT", + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + ImmutableSet.of(), + ImmutableSet.of(), + true, + Optional.empty()); + + String trinoQueryPropertiesJson = codec.toJson(trinoQueryProperties); + TrinoQueryProperties deserializedTrinoQueryProperties = codec.fromJson(trinoQueryPropertiesJson); + + assertThat(deserializedTrinoQueryProperties.getBody()).isEqualTo(trinoQueryProperties.getBody()); + assertThat(deserializedTrinoQueryProperties.getQueryType()).isEqualTo(trinoQueryProperties.getQueryType()); + assertThat(deserializedTrinoQueryProperties.getResourceGroupQueryType()).isEqualTo(trinoQueryProperties.getResourceGroupQueryType()); + assertThat(deserializedTrinoQueryProperties.getTables()).isEqualTo(trinoQueryProperties.getTables()); + assertThat(deserializedTrinoQueryProperties.getDefaultCatalog()).isEqualTo(trinoQueryProperties.getDefaultCatalog()); + assertThat(deserializedTrinoQueryProperties.getDefaultSchema()).isEqualTo(trinoQueryProperties.getDefaultSchema()); + assertThat(deserializedTrinoQueryProperties.getSchemas()).isEqualTo(trinoQueryProperties.getSchemas()); + assertThat(deserializedTrinoQueryProperties.getCatalogs()).isEqualTo(trinoQueryProperties.getCatalogs()); + assertThat(deserializedTrinoQueryProperties.getCatalogSchemas()).isEqualTo(trinoQueryProperties.getCatalogSchemas()); + assertThat(deserializedTrinoQueryProperties.isNewQuerySubmission()).isEqualTo(trinoQueryProperties.isNewQuerySubmission()); + assertThat(deserializedTrinoQueryProperties.isQueryParsingSuccessful()).isEqualTo(trinoQueryProperties.isQueryParsingSuccessful()); + assertThat(deserializedTrinoQueryProperties.getErrorMessage()).isEqualTo(trinoQueryProperties.getErrorMessage()); + } } From 7c8c14013cfd4fdce7ce15407390a41294f8b0b5 Mon Sep 17 00:00:00 2001 From: Will Morrison Date: Mon, 29 Jul 2024 15:03:32 -0400 Subject: [PATCH 2/3] Extract query ID from all kill_query procedure variations Co-authored-by: Yuya Ebihara --- .../trino/gateway/ha/handler/ProxyUtils.java | 69 +++--- .../ha/handler/RoutingTargetHandler.java | 14 +- .../router/ExternalRoutingGroupSelector.java | 2 +- .../RuleReloadingRoutingGroupSelector.java | 5 +- .../ha/router/TrinoQueryProperties.java | 68 ++++-- .../TestQueryIdCachingProxyHandler.java | 219 ++++++++++++++++-- .../ha/router/TestRoutingGroupSelector.java | 11 +- .../TestRoutingGroupSelectorExternal.java | 5 +- .../ha/router/TestTrinoQueryProperties.java | 6 +- 9 files changed, 316 insertions(+), 83 deletions(-) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java index 3de5fac2e..13c7c6596 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java @@ -15,8 +15,11 @@ import com.google.common.io.CharStreams; import io.airlift.log.Logger; +import io.trino.gateway.ha.router.TrinoQueryProperties; import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.HttpMethod; +import java.io.IOException; import java.io.InputStreamReader; import java.util.List; import java.util.Optional; @@ -49,51 +52,50 @@ public final class ProxyUtils * capitalization. */ private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*"); - private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'"); private ProxyUtils() {} - public static String extractQueryIdIfPresent(HttpServletRequest request, List statementPaths) + public static Optional extractQueryIdIfPresent( + HttpServletRequest request, + List statementPaths, + boolean requestAnalyserClientsUseV2Format, + int requestAnalyserMaxBodySize) { String path = request.getRequestURI(); String queryParams = request.getQueryString(); + if (!request.getMethod().equals(HttpMethod.POST)) { + return extractQueryIdIfPresent(path, queryParams, statementPaths); + } + String queryText; try { - String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8)); - if (!isNullOrEmpty(queryText) - && queryText.toLowerCase(ENGLISH).contains("system.runtime.kill_query")) { - // extract and return the queryId - String[] parts = queryText.split(","); - for (String part : parts) { - if (part.contains("query_id")) { - Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part); - if (matcher.find()) { - String queryQuoted = matcher.group(); - if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) { - return queryQuoted.substring(1, queryQuoted.length() - 1); - } - } - } - } - } + queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8)); } - catch (Exception e) { - log.error(e, "Error extracting query payload from request"); + catch (IOException e) { + throw new RuntimeException("Error reading request body", e); } - - return extractQueryIdIfPresent(path, queryParams, statementPaths); + if (!isNullOrEmpty(queryText) && queryText.toLowerCase(ENGLISH).contains("kill_query")) { + TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize); + return trinoQueryProperties.getQueryId(); + } + return Optional.empty(); } - public static String extractQueryIdIfPresent(String path, String queryParams, List statementPaths) + public static Optional extractQueryIdIfPresent(String path, String queryParams, List statementPaths) { if (path == null) { - return null; + return Optional.empty(); } - String queryId = null; log.debug("Trying to extract query id from path [%s] or queryString [%s]", path, queryParams); // matchingStatementPath should match paths such as /v1/statement/executing/query_id/nonce/sequence_number, // and if custom paths are supplied using the statementPaths configuration, paths such as // /custom/statement/path/executing/query_id/nonce/sequence_number Optional matchingStatementPath = statementPaths.stream().filter(path::startsWith).findAny(); + if (!isNullOrEmpty(queryParams)) { + Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams); + if (matcher.matches()) { + return Optional.of(matcher.group(1)); + } + } if (matchingStatementPath.isPresent() || path.startsWith(V1_QUERY_PATH)) { path = path.replace(matchingStatementPath.orElse(V1_QUERY_PATH), ""); String[] tokens = path.split("/"); @@ -102,27 +104,20 @@ public static String extractQueryIdIfPresent(String path, String queryParams, Li || tokens[1].equals("scheduled") || tokens[1].equals("executing") || tokens[1].equals("partialCancel")) { - queryId = tokens[2]; + return Optional.of(tokens[2]); } else { - queryId = tokens[1]; + return Optional.of(tokens[1]); } } } else if (path.startsWith(TRINO_UI_PATH)) { Matcher matcher = QUERY_ID_PATH_PATTERN.matcher(path); if (matcher.matches()) { - queryId = matcher.group(1); - } - } - if (!isNullOrEmpty(queryParams)) { - Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams); - if (matcher.matches()) { - queryId = matcher.group(1); + return Optional.of(matcher.group(1)); } } - log.debug("Query id in URL [%s]", queryId); - return queryId; + return Optional.empty(); } public static String buildUriWithNewBackend(String backendHost, HttpServletRequest request) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java index 1f3b12c90..4b3d3e808 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/RoutingTargetHandler.java @@ -47,6 +47,8 @@ public class RoutingTargetHandler private final RoutingGroupSelector routingGroupSelector; private final List statementPaths; private final List extraWhitelistPaths; + private final boolean requestAnalyserClientsUseV2Format; + private final int requestAnalyserMaxBodySize; private final boolean cookiesEnabled; @Inject @@ -57,8 +59,10 @@ public RoutingTargetHandler( { this.routingManager = requireNonNull(routingManager); this.routingGroupSelector = requireNonNull(routingGroupSelector); - this.statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths()); - this.extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList()); + statementPaths = requireNonNull(haGatewayConfiguration.getStatementPaths()); + extraWhitelistPaths = requireNonNull(haGatewayConfiguration.getExtraWhitelistPaths()).stream().map(Pattern::compile).collect(toImmutableList()); + requestAnalyserClientsUseV2Format = haGatewayConfiguration.getRequestAnalyzerConfig().isClientsUseV2Format(); + requestAnalyserMaxBodySize = haGatewayConfiguration.getRequestAnalyzerConfig().getMaxBodySize(); cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled(); } @@ -96,9 +100,9 @@ private String getBackendFromRoutingGroup(HttpServletRequest request) private Optional getPreviousBackend(HttpServletRequest request) { - String queryId = extractQueryIdIfPresent(request, statementPaths); - if (!isNullOrEmpty(queryId)) { - return Optional.of(routingManager.findBackendForQueryId(queryId)); + Optional queryId = extractQueryIdIfPresent(request, statementPaths, requestAnalyserClientsUseV2Format, requestAnalyserMaxBodySize); + if (queryId.isPresent()) { + return queryId.map(routingManager::findBackendForQueryId); } if (cookiesEnabled && request.getCookies() != null) { List cookies = Arrays.stream(request.getCookies()) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java index ba68280b3..c083f3122 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/ExternalRoutingGroupSelector.java @@ -118,7 +118,7 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request) TrinoQueryProperties trinoQueryProperties = null; TrinoRequestUser trinoRequestUser = null; if (requestAnalyzerConfig.isAnalyzeRequest()) { - trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig); + trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig.isClientsUseV2Format(), requestAnalyzerConfig.getMaxBodySize()); trinoRequestUser = trinoRequestUserProvider.getInstance(request); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/RuleReloadingRoutingGroupSelector.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/RuleReloadingRoutingGroupSelector.java index 1056972ce..4be89f39c 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/RuleReloadingRoutingGroupSelector.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/RuleReloadingRoutingGroupSelector.java @@ -96,7 +96,10 @@ public String findRoutingGroup(HttpServletRequest request) facts.put("request", request); if (requestAnalyzerConfig.isAnalyzeRequest()) { - TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig); + TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( + request, + requestAnalyzerConfig.isClientsUseV2Format(), + requestAnalyzerConfig.getMaxBodySize()); TrinoRequestUser trinoRequestUser = trinoRequestUserProvider.getInstance(request); facts.put("trinoQueryProperties", trinoQueryProperties); facts.put("trinoRequestUser", trinoRequestUser); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java index 698eac9c7..791d92bbd 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java @@ -14,21 +14,23 @@ package io.trino.gateway.ha.router; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.ser.std.StdSerializer; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.compress.zstd.ZstdDecompressor; import io.airlift.json.JsonCodec; import io.airlift.log.Logger; -import io.trino.gateway.ha.config.RequestAnalyzerConfig; import io.trino.sql.parser.ParsingException; import io.trino.sql.parser.SqlParser; import io.trino.sql.tree.AddColumn; import io.trino.sql.tree.Analyze; +import io.trino.sql.tree.Call; import io.trino.sql.tree.CreateCatalog; import io.trino.sql.tree.CreateMaterializedView; import io.trino.sql.tree.CreateSchema; @@ -40,6 +42,7 @@ import io.trino.sql.tree.DropTable; import io.trino.sql.tree.Execute; import io.trino.sql.tree.ExecuteImmediate; +import io.trino.sql.tree.Expression; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeLocation; @@ -57,6 +60,7 @@ import io.trino.sql.tree.ShowSchemas; import io.trino.sql.tree.ShowTables; import io.trino.sql.tree.Statement; +import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.Table; import io.trino.sql.tree.TableFunctionInvocation; import jakarta.servlet.http.HttpServletRequest; @@ -73,6 +77,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.io.BaseEncoding.base64Url; import static io.airlift.json.JsonCodec.jsonCodec; import static java.lang.Math.toIntExact; @@ -85,6 +90,7 @@ public class TrinoQueryProperties { private final Logger log = Logger.get(TrinoQueryProperties.class); private final boolean isClientsUseV2Format; + private final int maxBodySize; private String body = ""; private String queryType = ""; private String resourceGroupQueryType = ""; @@ -96,6 +102,7 @@ public class TrinoQueryProperties private Set catalogSchemas = ImmutableSet.of(); private boolean isNewQuerySubmission; private Optional errorMessage = Optional.empty(); + private Optional queryId = Optional.empty(); public static final String TRINO_CATALOG_HEADER_NAME = "X-Trino-Catalog"; public static final String TRINO_SCHEMA_HEADER_NAME = "X-Trino-Schema"; @@ -128,21 +135,24 @@ public TrinoQueryProperties( this.isNewQuerySubmission = isNewQuerySubmission; this.errorMessage = requireNonNullElse(errorMessage, Optional.empty()); isClientsUseV2Format = false; + maxBodySize = -1; } - public TrinoQueryProperties(HttpServletRequest request, RequestAnalyzerConfig config) + public TrinoQueryProperties(HttpServletRequest request, boolean isClientsUseV2Format, int maxBodySize) { - isClientsUseV2Format = config.isClientsUseV2Format(); + requireNonNull(request, "request is null"); + this.isClientsUseV2Format = isClientsUseV2Format; + this.maxBodySize = maxBodySize; defaultCatalog = Optional.ofNullable(request.getHeader(TRINO_CATALOG_HEADER_NAME)); defaultSchema = Optional.ofNullable(request.getHeader(TRINO_SCHEMA_HEADER_NAME)); if (request.getMethod().equals(HttpMethod.POST)) { isNewQuerySubmission = true; - processRequestBody(request, config); + processRequestBody(request); } } - private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfig config) + private void processRequestBody(HttpServletRequest request) { try (BufferedReader reader = request.getReader()) { if (reader == null) { @@ -153,11 +163,11 @@ private void processRequestBody(HttpServletRequest request, RequestAnalyzerConfi Map preparedStatements = getPreparedStatements(request); SqlParser parser = new SqlParser(); - reader.mark(config.getMaxBodySize()); - char[] buffer = new char[config.getMaxBodySize()]; - int nChars = reader.read(buffer, 0, config.getMaxBodySize()); + reader.mark(maxBodySize); + char[] buffer = new char[maxBodySize]; + int nChars = reader.read(buffer, 0, maxBodySize); reader.reset(); - if (nChars == config.getMaxBodySize()) { + if (nChars == maxBodySize) { log.warn("Query length greater or equal to requestAnalyzerConfig.maxBodySize detected"); return; //The body is truncated - there is a chance that it could still be syntactically valid SQL, for example if truncated on @@ -199,7 +209,7 @@ else if (statement instanceof ExecuteImmediate executeImmediate) { ImmutableSet.Builder schemaBuilder = ImmutableSet.builder(); ImmutableSet.Builder catalogSchemaBuilder = ImmutableSet.builder(); - getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder); + visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder); tables = tableBuilder.build(); catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator()); catalogs = catalogBuilder.build(); @@ -260,7 +270,7 @@ private String decodePreparedStatementFromHeader(String headerValue) return new String(preparedStatement, UTF_8); } - private void getNames(Node node, ImmutableSet.Builder tableBuilder, + private void visitNode(Node node, ImmutableSet.Builder tableBuilder, ImmutableSet.Builder catalogBuilder, ImmutableSet.Builder schemaBuilder, ImmutableSet.Builder catalogSchemaBuilder) @@ -269,6 +279,7 @@ private void getNames(Node node, ImmutableSet.Builder tableBuilde switch (node) { case AddColumn s -> tableBuilder.add(qualifyName(s.getName())); case Analyze s -> tableBuilder.add(qualifyName(s.getTableName())); + case Call call -> queryId = extractQueryIdFromCall(call); case CreateCatalog s -> catalogBuilder.add(s.getCatalogName().getValue()); case CreateMaterializedView s -> tableBuilder.add(qualifyName(s.getName())); case CreateSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder); @@ -342,10 +353,22 @@ private void getNames(Node node, ImmutableSet.Builder tableBuilde } for (Node child : node.getChildren()) { - getNames(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder); + visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder); } } + private Optional extractQueryIdFromCall(Call call) + throws RequestParsingException + { + QualifiedName callName = qualifyName(call.getName()); + if (callName.equals(QualifiedName.of("system", "runtime", "kill_query"))) { + Expression argument = call.getArguments().getFirst().getValue(); + checkArgument(argument instanceof StringLiteral, "Unable to route kill_query procedures where the first argument is not a String Literal"); + return Optional.of(((StringLiteral) argument).getValue()); + } + return Optional.empty(); + } + private void setCatalogAndSchemaNameFromSchemaQualifiedName( Optional schemaOptional, ImmutableSet.Builder catalogBuilder, @@ -381,15 +404,16 @@ private RequestParsingException unsetDefaultExceptionSupplier() return new RequestParsingException("Name not fully qualified"); } - private QualifiedName qualifyName(QualifiedName table) + private QualifiedName qualifyName(QualifiedName name) throws RequestParsingException { - List tableParts = table.getParts(); - return switch (tableParts.size()) { - case 1 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst()); - case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst(), tableParts.get(1)); - case 3 -> QualifiedName.of(tableParts.getFirst(), tableParts.get(1), tableParts.get(2)); - default -> throw new RequestParsingException("Unexpected table name: " + table.getParts()); + List nameParts = name.getParts(); + return switch (nameParts.size()) { + case 1 -> + QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), nameParts.getFirst()); + case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), nameParts.getFirst(), nameParts.get(1)); + case 3 -> QualifiedName.of(nameParts.getFirst(), nameParts.get(1), nameParts.get(2)); + default -> throw new RequestParsingException("Unexpected qualified name: " + name.getParts()); }; } @@ -520,6 +544,12 @@ public Optional getErrorMessage() return errorMessage; } + @JsonIgnore + public Optional getQueryId() + { + return queryId; + } + public static class AlternateStatementRequestBodyFormat { // Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java index 1f3f4543e..b5223d2c9 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java @@ -14,15 +14,27 @@ package io.trino.gateway.ha.handler; import com.google.common.collect.ImmutableList; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.HttpMethod; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; +import org.mockito.Mockito; +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.StringReader; import java.util.List; +import java.util.Optional; import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.when; @TestInstance(Lifecycle.PER_CLASS) final class TestQueryIdCachingProxyHandler @@ -33,31 +45,212 @@ void testExtractQueryIdFromUrl() { List statementPaths = ImmutableList.of("/v1/statement", "/custom/api/statement"); assertThat(extractQueryIdIfPresent("/v1/statement/executing/20200416_160256_03078_6b4yt/ya7e884929c67cdf86207a80e7a77ab2166fa2e7b/1368", null, statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/custom/api/statement/executing/20200416_160256_03078_6b4yt/ya7e884929c67cdf86207a80e7a77ab2166fa2e7b/1368", null, statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/v1/statement/queued/20200416_160256_03078_6b4yt/y0d7620a6941e78d3950798a1085383234258a566/1", null, statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt", null, statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt/killed", null, statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt/preempted", null, statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/v1/query/20200416_160256_03078_6b4yt", "pretty", statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/ui/troubleshooting", "queryId=20200416_160256_03078_6b4yt", statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/ui/query.html", "20200416_160256_03078_6b4yt", statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/login", "redirect=%2Fui%2Fapi%2Fquery%2F20200416_160256_03078_6b4yt", statementPaths)) - .isEqualTo("20200416_160256_03078_6b4yt"); + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/ui/api/query/myOtherThing", null, statementPaths)) - .isNull(); + .isEmpty(); assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_blah", "bogus_fictional_param", statementPaths)) - .isNull(); + .isEmpty(); assertThat(extractQueryIdIfPresent("/ui/", "lang=en&p=1&id=0_1_2_a", statementPaths)) - .isNull(); + .isEmpty(); + } + + @Test + void testQueryIdFromKill() + throws IOException + { + assertThat(extractQueryId(request("CALL system.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL system.runtime.kill_query(Query_id => '20200416_160256_03078_6b4yt', Message => 'If he dies, he dies')"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies')", "system", "runtime"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')", "system"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', 'kill_query(''20200416_160256_03078_7n5uy'')')"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt', '20200416_160256_03078_7n5uy')"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL system.runtime.kill_query(query_id=>'20200416_160256_03078_6b4yt')"))).hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL system.runtime.kill_query('20200416_160256_03078_6b4yt')"))).hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("CALL kill_query('20200416_160256_03078_6b4yt')", "system", "runtime"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request("call Kill_Query('20200416_160256_03078_6b4yt')", "system", "runtime"))) + .hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request( + "SELECT * FROM postgres.query_logs.queries WHERE sql LIKE '%kill_query(''20200416_160256%' ", + "system", + "runtime"))).isEmpty(); + + assertThat(extractQueryId(request( + "SELECT * FROM postgres.query_logs.queries WHERE sql LIKE '%kill_query(''20200416_160256_03078_6b4yt' ", + "system", + "runtime"))).isEmpty(); + + assertThat(extractQueryId(request( + "SELECT * FROM postgres.query_logs.queries WHERE sql LIKE 'CALL kill_query(_20200416_160256_03078_6b4yt_)' ", + "system", + "runtime"))).isEmpty(); + + assertThat(extractQueryId(request(""" + --CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies') + SELECT 1 + """, + "system", + "runtime"))).isEmpty(); + + assertThat(extractQueryId(request(""" + /* + CALL kill_query('20200416_160256_03078_6b4yt', 'If he dies, he dies') + */ + SELECT 1 + """, + "system", + "runtime"))).isEmpty(); + + assertThat(extractQueryId(request(""" + CALL KILL_QUERY('20200416_160256_03078_6b4yt', 'If he dies, he dies') + """, + "system", + "runtime"))).hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request(""" + CALL KILL_QUERY ('20200416_160256_03078_6b4yt', 'If he dies, he dies') + """, + "system", + "runtime"))).hasValue("20200416_160256_03078_6b4yt"); + + assertThat(extractQueryId(request(""" + CALL + KILL_QUERY + ( + -- this is a comment + '20200416_160256_03078_6b4yt' --this is a trailing comment + , + /* + this is + a multiline comment + */ + 'If he dies, he dies + ') + """, + "system", + "runtime"))).hasValue("20200416_160256_03078_6b4yt"); + + assertThatThrownBy(() -> extractQueryId(request(""" + CALL KILL_QUERY (lower('20200416_160256_03078_6b4yt'), 'If he dies, he dies') + """, + "system", + "runtime"))).isInstanceOf(IllegalArgumentException.class); + + assertThat(extractQueryId(request("CALL notsystem.runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')"))).isEmpty(); + + assertThat(extractQueryId(request("CALL runtime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')", "notsystem"))) + .isEmpty(); + + assertThat(extractQueryId(request("CALL notruntime.kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')", "system"))) + .isEmpty(); + + assertThat(extractQueryId(request( + "CALL kill_query(query_id => '20200416_160256_03078_6b4yt', message => 'If he dies, he dies')", + "system", + "notruntime"))) + .isEmpty(); + } + + private static Optional extractQueryId(HttpServletRequest request) + { + return extractQueryIdIfPresent(request, ImmutableList.of(), false, 1_000_000); + } + + private static HttpServletRequest request(String query, String defaultCatalog) + throws IOException + { + // Warning - this is not a fully featured mock of the behavior of HttpServlet with respect to headers. For example, + // getHeaderNames will return an empty list, and getHeader is not fully case-insensitive. This is only intended to be + // a minimal mock for this test. + HttpServletRequest request = request(query); + when(request.getHeader("X-Trino-Catalog")).thenReturn(defaultCatalog); + when(request.getHeader("X-trino-catalog")).thenReturn(defaultCatalog); + return request; + } + + private static HttpServletRequest request(String query, String defaultCatalog, String defaultSchema) + throws IOException + { + HttpServletRequest request = request(query); + when(request.getHeader("X-Trino-Catalog")).thenReturn(defaultCatalog); + when(request.getHeader("X-trino-catalog")).thenReturn(defaultCatalog); + when(request.getHeader("X-Trino-Schema")).thenReturn(defaultSchema); + when(request.getHeader("X-trino-schema")).thenReturn(defaultSchema); + return request; + } + + private static HttpServletRequest request(String query) + throws IOException + { + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(query.getBytes(UTF_8)); + when(request.getMethod()).thenReturn(HttpMethod.POST); + when(request.getInputStream()).thenReturn(new ServletInputStream() + { + @Override + public boolean isFinished() + { + return byteArrayInputStream.available() > 0; + } + + @Override + public boolean isReady() + { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) + {} + + @Override + public int read() + throws IOException + { + return byteArrayInputStream.read(); + } + }); + + when(request.getReader()).thenReturn(new BufferedReader(new StringReader(query))); + + when(request.getQueryString()).thenReturn(""); + + return request; } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java index 7262ec58c..1066d19a7 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java @@ -397,7 +397,10 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set catalogs, when(mockRequest.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG); when(mockRequest.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA); - TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(mockRequest, requestAnalyzerConfig); + TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( + mockRequest, + requestAnalyzerConfig.isClientsUseV2Format(), + requestAnalyzerConfig.getMaxBodySize()); assertThat(trinoQueryProperties.getTables()).isEqualTo(tables); assertThat(trinoQueryProperties.getSchemas()).isEqualTo(schemas); @@ -418,8 +421,10 @@ void testLongQuery() BufferedReader bufferedReader = new BufferedReader(new FileReader("src/test/resources/wide_select.sql", UTF_8)); HttpServletRequest mockRequest = prepareMockRequest(); when(mockRequest.getReader()).thenReturn(bufferedReader); - TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(mockRequest, requestAnalyzerConfig); - + TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( + mockRequest, + requestAnalyzerConfig.isClientsUseV2Format(), + requestAnalyzerConfig.getMaxBodySize()); assertThat(trinoQueryProperties.tablesContains("kat.schem.widetable")).isTrue(); } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java index 93df20252..a1c90179d 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelectorExternal.java @@ -224,7 +224,10 @@ private RoutingGroupExternalBody createRequestBody(HttpServletRequest request) TrinoQueryProperties trinoQueryProperties = null; TrinoRequestUser trinoRequestUser = null; if (requestAnalyzerConfig.isAnalyzeRequest()) { - trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig); + trinoQueryProperties = new TrinoQueryProperties( + request, + requestAnalyzerConfig.isClientsUseV2Format(), + requestAnalyzerConfig.getMaxBodySize()); trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(request); } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java index 018dfd99f..769b6d8f9 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoQueryProperties.java @@ -63,9 +63,9 @@ void testJsonCreatorWithEmptyProperties() { JsonCodec codec = JsonCodec.jsonCodec(TrinoQueryProperties.class); TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties( - "SELECT c1 from c.s.t1", - "SELECT", - "SELECT", + "", + "", + "", ImmutableList.of(), Optional.empty(), Optional.empty(), From b27491642bfecadd7499176284863d1fa24fa9bb Mon Sep 17 00:00:00 2001 From: Will Morrison Date: Thu, 26 Sep 2024 17:00:50 -0400 Subject: [PATCH 3/3] Update maxBodySize type in RequestAnalyzerConfig to match usage in TrinoQueryProperties --- .../io/trino/gateway/ha/config/RequestAnalyzerConfig.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RequestAnalyzerConfig.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RequestAnalyzerConfig.java index 80a8a7d2d..2277d129b 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/RequestAnalyzerConfig.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/RequestAnalyzerConfig.java @@ -17,7 +17,7 @@ public class RequestAnalyzerConfig { - private Integer maxBodySize = 1_000_000; + private int maxBodySize = 1_000_000; private boolean isClientsUseV2Format; private String tokenUserField = "email"; @@ -26,13 +26,13 @@ public class RequestAnalyzerConfig public RequestAnalyzerConfig() {} - public Integer getMaxBodySize() + public int getMaxBodySize() { return maxBodySize; } @Max(Integer.MAX_VALUE) - public void setMaxBodySize(Integer maxBodySize) + public void setMaxBodySize(int maxBodySize) { this.maxBodySize = maxBodySize; }