Skip to content

Commit 5861e19

Browse files
committed
Fix parsing failure on WITH clauses
1 parent 2add137 commit 5861e19

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import io.trino.sql.tree.Node;
4848
import io.trino.sql.tree.NodeLocation;
4949
import io.trino.sql.tree.QualifiedName;
50+
import io.trino.sql.tree.Query;
5051
import io.trino.sql.tree.RenameMaterializedView;
5152
import io.trino.sql.tree.RenameSchema;
5253
import io.trino.sql.tree.RenameTable;
@@ -63,6 +64,7 @@
6364
import io.trino.sql.tree.StringLiteral;
6465
import io.trino.sql.tree.Table;
6566
import io.trino.sql.tree.TableFunctionInvocation;
67+
import io.trino.sql.tree.WithQuery;
6668
import jakarta.servlet.http.HttpServletRequest;
6769
import jakarta.ws.rs.HttpMethod;
6870

@@ -71,6 +73,7 @@
7173
import java.net.URLDecoder;
7274
import java.util.ArrayList;
7375
import java.util.Enumeration;
76+
import java.util.HashSet;
7477
import java.util.List;
7578
import java.util.Map;
7679
import java.util.Optional;
@@ -208,8 +211,9 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
208211
ImmutableSet.Builder<String> catalogBuilder = ImmutableSet.builder();
209212
ImmutableSet.Builder<String> schemaBuilder = ImmutableSet.builder();
210213
ImmutableSet.Builder<String> catalogSchemaBuilder = ImmutableSet.builder();
214+
Set<QualifiedName> temporaryTables = new HashSet<>();
211215

212-
visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
216+
visitNode(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables);
213217
tables = tableBuilder.build();
214218
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
215219
catalogs = catalogBuilder.build();
@@ -273,7 +277,8 @@ private String decodePreparedStatementFromHeader(String headerValue)
273277
private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuilder,
274278
ImmutableSet.Builder<String> catalogBuilder,
275279
ImmutableSet.Builder<String> schemaBuilder,
276-
ImmutableSet.Builder<String> catalogSchemaBuilder)
280+
ImmutableSet.Builder<String> catalogSchemaBuilder,
281+
Set<QualifiedName> temporaryTables)
277282
throws RequestParsingException
278283
{
279284
switch (node) {
@@ -289,6 +294,7 @@ private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuild
289294
case DropCatalog s -> catalogBuilder.add(s.getCatalogName().getValue());
290295
case DropSchema s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSchemaName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
291296
case DropTable s -> tableBuilder.add(qualifyName(s.getTableName()));
297+
case Query q -> q.getWith().ifPresent(with -> temporaryTables.addAll(with.getQueries().stream().map(WithQuery::getName).map(Identifier::getValue).map(QualifiedName::of).toList()));
292298
case RenameMaterializedView s -> {
293299
tableBuilder.add(qualifyName(s.getSource()));
294300
tableBuilder.add(qualifyName(s.getTarget()));
@@ -347,13 +353,18 @@ private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuild
347353
case SetSchemaAuthorization s -> setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(s.getSource()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
348354
case SetTableAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
349355
case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
350-
case Table s -> tableBuilder.add(qualifyName(s.getName()));
356+
case Table s -> {
357+
// ignore temporary tables as they can have various table parts
358+
if (!temporaryTables.contains(s.getName())) {
359+
tableBuilder.add(qualifyName(s.getName()));
360+
}
361+
}
351362
case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName()));
352363
default -> {}
353364
}
354365

355366
for (Node child : node.getChildren()) {
356-
visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
367+
visitNode(child, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder, temporaryTables);
357368
}
358369
}
359370

gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,38 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set<String> catalogs,
407407
assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs);
408408
}
409409

410+
@Test
411+
void testWithQueryNameExcluded()
412+
throws IOException
413+
{
414+
String query = """
415+
WITH dos AS (SELECT c1 from cat.schem.tbl1),
416+
uno as (SELECT c1 FROM dos)
417+
SELECT c1 FROM uno, dos
418+
""";
419+
HttpServletRequest mockRequestWithDefaults = prepareMockRequest();
420+
when(mockRequestWithDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
421+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG);
422+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA);
423+
424+
TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties(
425+
mockRequestWithDefaults,
426+
requestAnalyzerConfig.isClientsUseV2Format(),
427+
requestAnalyzerConfig.getMaxBodySize());
428+
Set<QualifiedName> tablesWithDefaults = trinoQueryPropertiesWithDefaults.getTables();
429+
assertThat(tablesWithDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));
430+
431+
HttpServletRequest mockRequestNoDefaults = prepareMockRequest();
432+
when(mockRequestNoDefaults.getReader()).thenReturn(new BufferedReader(new StringReader(query)));
433+
434+
TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties(
435+
mockRequestNoDefaults,
436+
requestAnalyzerConfig.isClientsUseV2Format(),
437+
requestAnalyzerConfig.getMaxBodySize());
438+
Set<QualifiedName> tablesNoDefaults = trinoQueryPropertiesNoDefaults.getTables();
439+
assertThat(tablesNoDefaults).containsExactly(QualifiedName.of("cat", "schem", "tbl1"));
440+
}
441+
410442
private HttpServletRequest prepareMockRequest()
411443
{
412444
HttpServletRequest mockRequest = mock(HttpServletRequest.class);

0 commit comments

Comments
 (0)