Skip to content

Commit 998379d

Browse files
committed
Fix parsing failure on WITH clauses
1 parent f05cea4 commit 998379d

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import io.trino.sql.tree.Statement;
6060
import io.trino.sql.tree.Table;
6161
import io.trino.sql.tree.TableFunctionInvocation;
62+
import io.trino.sql.tree.WithQuery;
6263
import jakarta.servlet.http.HttpServletRequest;
6364
import jakarta.ws.rs.HttpMethod;
6465

@@ -68,12 +69,14 @@
6869
import java.util.ArrayList;
6970
import java.util.Arrays;
7071
import java.util.Enumeration;
72+
import java.util.HashSet;
7173
import java.util.List;
7274
import java.util.Map;
7375
import java.util.Optional;
7476
import java.util.Set;
7577
import java.util.stream.Collectors;
7678

79+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
7780
import static com.google.common.io.BaseEncoding.base64Url;
7881
import static io.airlift.json.JsonCodec.jsonCodec;
7982
import static java.lang.Math.toIntExact;
@@ -90,6 +93,7 @@ public class TrinoQueryProperties
9093
private String queryType = "";
9194
private String resourceGroupQueryType = "";
9295
private Set<QualifiedName> tables = ImmutableSet.of();
96+
private final Set<QualifiedName> temporaryTables = new HashSet<>();
9397
private final Optional<String> defaultCatalog;
9498
private final Optional<String> defaultSchema;
9599
private Set<String> catalogs = ImmutableSet.of();
@@ -201,12 +205,17 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
201205

202206
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
203207
tables = tableBuilder.build();
204-
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
208+
209+
Set<QualifiedName> filteredTables = tables.stream()
210+
.filter(table -> !temporaryTables.contains(table))
211+
.collect(toImmutableSet());
212+
213+
catalogBuilder.addAll(filteredTables.stream().map(q -> q.getParts().getFirst()).iterator());
205214
catalogs = catalogBuilder.build();
206-
schemaBuilder.addAll(tables.stream().map(q -> q.getParts().get(1)).iterator());
215+
schemaBuilder.addAll(filteredTables.stream().map(q -> q.getParts().get(1)).iterator());
207216
schemas = schemaBuilder.build();
208217
catalogSchemaBuilder.addAll(
209-
tables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
218+
filteredTables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
210219
catalogSchemas = catalogSchemaBuilder.build();
211220
}
212221
catch (IOException e) {
@@ -338,6 +347,7 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
338347
case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
339348
case Table s -> tableBuilder.add(qualifyName(s.getName()));
340349
case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName()));
350+
case WithQuery withQuery -> temporaryTables.add(QualifiedName.of(withQuery.getName().getValue()));
341351
default -> {}
342352
}
343353

@@ -385,6 +395,12 @@ private QualifiedName qualifyName(QualifiedName table)
385395
throws RequestParsingException
386396
{
387397
List<String> tableParts = table.getParts();
398+
399+
// ignore temporary tables created by WITH clause as they can have various table parts
400+
if (temporaryTables.contains(table)) {
401+
return table;
402+
}
403+
388404
return switch (tableParts.size()) {
389405
case 1 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst());
390406
case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst(), tableParts.get(1));

gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,41 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set<String> catalogs,
404404
assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs);
405405
}
406406

407+
@Test
408+
void testWithQueryNameExcluded()
409+
throws IOException
410+
{
411+
String query = """
412+
WITH dos AS (SELECT c1 from cat.schem.tbl1),
413+
uno as (SELECT c1 FROM dos)
414+
SELECT c1 FROM uno, dos
415+
""";
416+
HttpServletRequest mockRequestWithDefaults = prepareMockRequest();
417+
when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
418+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG);
419+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA);
420+
421+
TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties(mockRequestWithDefaults, requestAnalyzerConfig);
422+
423+
Set<QualifiedName> tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables();
424+
assertThat(tablesWithDefaults).containsExactlyInAnyOrder(
425+
QualifiedName.of("uno"),
426+
QualifiedName.of("dos"),
427+
QualifiedName.of("cat", "schem", "tbl1")
428+
);
429+
430+
HttpServletRequest mockRequestNoDefaults = prepareMockRequest();
431+
when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
432+
433+
TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties(mockRequestNoDefaults, requestAnalyzerConfig);
434+
Set<QualifiedName> tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables();
435+
assertThat(tablesNoDefaults).containsExactlyInAnyOrder(
436+
QualifiedName.of("uno"),
437+
QualifiedName.of("dos"),
438+
QualifiedName.of("cat", "schem", "tbl1")
439+
);
440+
}
441+
407442
private HttpServletRequest prepareMockRequest()
408443
{
409444
HttpServletRequest mockRequest = mock(HttpServletRequest.class);

0 commit comments

Comments
 (0)