Skip to content

Commit 9f55135

Browse files
committed
Fix parsing failure on WITH clauses
1 parent cbfa643 commit 9f55135

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import io.trino.sql.tree.Node;
4545
import io.trino.sql.tree.NodeLocation;
4646
import io.trino.sql.tree.QualifiedName;
47+
import io.trino.sql.tree.Query;
4748
import io.trino.sql.tree.RenameMaterializedView;
4849
import io.trino.sql.tree.RenameSchema;
4950
import io.trino.sql.tree.RenameTable;
@@ -59,6 +60,7 @@
5960
import io.trino.sql.tree.Statement;
6061
import io.trino.sql.tree.Table;
6162
import io.trino.sql.tree.TableFunctionInvocation;
63+
import io.trino.sql.tree.WithQuery;
6264
import jakarta.servlet.http.HttpServletRequest;
6365
import jakarta.ws.rs.HttpMethod;
6466

@@ -68,6 +70,7 @@
6870
import java.util.ArrayList;
6971
import java.util.Arrays;
7072
import java.util.Enumeration;
73+
import java.util.HashSet;
7174
import java.util.List;
7275
import java.util.Map;
7376
import java.util.Optional;
@@ -198,10 +201,13 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
198201
ImmutableSet.Builder<String> catalogBuilder = ImmutableSet.builder();
199202
ImmutableSet.Builder<String> schemaBuilder = ImmutableSet.builder();
200203
ImmutableSet.Builder<String> catalogSchemaBuilder = ImmutableSet.builder();
204+
Set<QualifiedName> temporaryTables = new HashSet<>();
201205

202-
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
206+
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables);
203207
tables = tableBuilder.build();
208+
204209
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
210+
205211
catalogs = catalogBuilder.build();
206212
schemaBuilder.addAll(tables.stream().map(q -> q.getParts().get(1)).iterator());
207213
schemas = schemaBuilder.build();
@@ -263,7 +269,8 @@ private String decodePreparedStatementFromHeader(String headerValue)
263269
private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
264270
ImmutableSet.Builder<String> catalogBuilder,
265271
ImmutableSet.Builder<String> schemaBuilder,
266-
ImmutableSet.Builder<String> catalogSchemaBuilder)
272+
ImmutableSet.Builder<String> catalogSchemaBuilder,
273+
Set<QualifiedName> temporaryTables)
267274
throws RequestParsingException
268275
{
269276
switch (node) {
@@ -278,6 +285,7 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
278285
case DropCatalog s -> catalogBuilder.add(s.getCatalogName().getValue());
279286
case DropSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
280287
case DropTable s -> tableBuilder.add(qualifyName(s.getTableName()));
288+
case Query q -> q.getWith().ifPresent(with -> temporaryTables.addAll(with.getQueries().stream().map(WithQuery::getName).map(Identifier::getValue).map(QualifiedName::of).toList()));
281289
case RenameMaterializedView s -> {
282290
tableBuilder.add(qualifyName(s.getSource()));
283291
tableBuilder.add(qualifyName(s.getTarget()));
@@ -336,13 +344,18 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
336344
case SetSchemaAuthorization s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSource()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
337345
case SetTableAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
338346
case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
339-
case Table s -> tableBuilder.add(qualifyName(s.getName()));
347+
case Table s -> {
348+
// ignore temporary tables as they can have various table parts
349+
if (!temporaryTables.contains(s.getName())) {
350+
tableBuilder.add(qualifyName(s.getName()));
351+
}
352+
}
340353
case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName()));
341354
default -> {}
342355
}
343356

344357
for (Node child : node.getChildren()) {
345-
getNames(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
358+
getNames(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables);
346359
}
347360
}
348361

gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,33 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set<String> catalogs,
404404
assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs);
405405
}
406406

407+
@Test
408+
void testWithQueryNameExcluded()
409+
throws IOException
410+
{
411+
String query = """
412+
WITH dos AS (SELECT c1 from cat.schem.tbl1),
413+
uno as (SELECT c1 FROM dos)
414+
SELECT c1 FROM uno, dos
415+
""";
416+
HttpServletRequest mockRequestWithDefaults = prepareMockRequest();
417+
when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
418+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG);
419+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA);
420+
421+
TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties(mockRequestWithDefaults, requestAnalyzerConfig);
422+
423+
Set<QualifiedName> tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables();
424+
assertThat(tablesWithDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));
425+
426+
HttpServletRequest mockRequestNoDefaults = prepareMockRequest();
427+
when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
428+
429+
TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties(mockRequestNoDefaults, requestAnalyzerConfig);
430+
Set<QualifiedName> tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables();
431+
assertThat(tablesNoDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));
432+
}
433+
407434
private HttpServletRequest prepareMockRequest()
408435
{
409436
HttpServletRequest mockRequest = mock(HttpServletRequest.class);

0 commit comments

Comments
 (0)