Skip to content

Commit 40593ae

Browse files
committed
Fix parsing failure on WITH clauses
1 parent cbfa643 commit 40593ae

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import io.trino.sql.tree.Statement;
6060
import io.trino.sql.tree.Table;
6161
import io.trino.sql.tree.TableFunctionInvocation;
62+
import io.trino.sql.tree.WithQuery;
6263
import jakarta.servlet.http.HttpServletRequest;
6364
import jakarta.ws.rs.HttpMethod;
6465

@@ -68,12 +69,14 @@
6869
import java.util.ArrayList;
6970
import java.util.Arrays;
7071
import java.util.Enumeration;
72+
import java.util.HashSet;
7173
import java.util.List;
7274
import java.util.Map;
7375
import java.util.Optional;
7476
import java.util.Set;
7577
import java.util.stream.Collectors;
7678

79+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
7780
import static com.google.common.io.BaseEncoding.base64Url;
7881
import static io.airlift.json.JsonCodec.jsonCodec;
7982
import static java.lang.Math.toIntExact;
@@ -90,6 +93,7 @@ public class TrinoQueryProperties
9093
private String queryType = "";
9194
private String resourceGroupQueryType = "";
9295
private Set<QualifiedName> tables = ImmutableSet.of();
96+
private final Set<QualifiedName> temporaryTables = new HashSet<>();
9397
private final Optional<String> defaultCatalog;
9498
private final Optional<String> defaultSchema;
9599
private Set<String> catalogs = ImmutableSet.of();
@@ -201,12 +205,17 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
201205

202206
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
203207
tables = tableBuilder.build();
204-
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
208+
209+
Set<QualifiedName> filteredTables = tables.stream()
210+
.filter(table -> !temporaryTables.contains(table))
211+
.collect(toImmutableSet());
212+
213+
catalogBuilder.addAll(filteredTables.stream().map(q -> q.getParts().getFirst()).iterator());
205214
catalogs = catalogBuilder.build();
206-
schemaBuilder.addAll(tables.stream().map(q -> q.getParts().get(1)).iterator());
215+
schemaBuilder.addAll(filteredTables.stream().map(q -> q.getParts().get(1)).iterator());
207216
schemas = schemaBuilder.build();
208217
catalogSchemaBuilder.addAll(
209-
tables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
218+
filteredTables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
210219
catalogSchemas = catalogSchemaBuilder.build();
211220
}
212221
catch (IOException e) {
@@ -336,8 +345,14 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
336345
case SetSchemaAuthorization s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSource()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
337346
case SetTableAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
338347
case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
339-
case Table s -> tableBuilder.add(qualifyName(s.getName()));
348+
case Table s -> {
349+
// ignore temporary tables as they can have various table parts
350+
if (!temporaryTables.contains(s.getName())) {
351+
tableBuilder.add(qualifyName(s.getName()));
352+
}
353+
}
340354
case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName()));
355+
case WithQuery withQuery -> temporaryTables.add(QualifiedName.of(withQuery.getName().getValue()));
341356
default -> {}
342357
}
343358

gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,33 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set<String> catalogs,
404404
assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs);
405405
}
406406

407+
@Test
408+
void testWithQueryNameExcluded()
409+
throws IOException
410+
{
411+
String query = """
412+
WITH dos AS (SELECT c1 from cat.schem.tbl1),
413+
uno as (SELECT c1 FROM dos)
414+
SELECT c1 FROM uno, dos
415+
""";
416+
HttpServletRequest mockRequestWithDefaults = prepareMockRequest();
417+
when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
418+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG);
419+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA);
420+
421+
TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties(mockRequestWithDefaults, requestAnalyzerConfig);
422+
423+
Set<QualifiedName> tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables();
424+
assertThat(tablesWithDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));
425+
426+
HttpServletRequest mockRequestNoDefaults = prepareMockRequest();
427+
when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
428+
429+
TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties(mockRequestNoDefaults, requestAnalyzerConfig);
430+
Set<QualifiedName> tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables();
431+
assertThat(tablesNoDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));
432+
}
433+
407434
private HttpServletRequest prepareMockRequest()
408435
{
409436
HttpServletRequest mockRequest = mock(HttpServletRequest.class);

0 commit comments

Comments
 (0)