Skip to content

Commit cfca1ed

Browse files
committed
Fix parsing failure on WITH clauses
1 parent f05cea4 commit cfca1ed

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import io.trino.sql.tree.Node;
4545
import io.trino.sql.tree.NodeLocation;
4646
import io.trino.sql.tree.QualifiedName;
47+
import io.trino.sql.tree.Query;
4748
import io.trino.sql.tree.RenameMaterializedView;
4849
import io.trino.sql.tree.RenameSchema;
4950
import io.trino.sql.tree.RenameTable;
@@ -59,6 +60,7 @@
5960
import io.trino.sql.tree.Statement;
6061
import io.trino.sql.tree.Table;
6162
import io.trino.sql.tree.TableFunctionInvocation;
63+
import io.trino.sql.tree.WithQuery;
6264
import jakarta.servlet.http.HttpServletRequest;
6365
import jakarta.ws.rs.HttpMethod;
6466

@@ -68,12 +70,14 @@
6870
import java.util.ArrayList;
6971
import java.util.Arrays;
7072
import java.util.Enumeration;
73+
import java.util.HashSet;
7174
import java.util.List;
7275
import java.util.Map;
7376
import java.util.Optional;
7477
import java.util.Set;
7578
import java.util.stream.Collectors;
7679

80+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
7781
import static com.google.common.io.BaseEncoding.base64Url;
7882
import static io.airlift.json.JsonCodec.jsonCodec;
7983
import static java.lang.Math.toIntExact;
@@ -90,6 +94,7 @@ public class TrinoQueryProperties
9094
private String queryType = "";
9195
private String resourceGroupQueryType = "";
9296
private Set<QualifiedName> tables = ImmutableSet.of();
97+
private final Set<QualifiedName> temporaryTables = new HashSet<>();
9398
private final Optional<String> defaultCatalog;
9499
private final Optional<String> defaultSchema;
95100
private Set<String> catalogs = ImmutableSet.of();
@@ -201,12 +206,17 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
201206

202207
getNames(statement, tableBuilder, catalogBuilder, schemaBuilder, catalogSchemaBuilder);
203208
tables = tableBuilder.build();
204-
catalogBuilder.addAll(tables.stream().map(q -> q.getParts().getFirst()).iterator());
209+
210+
Set<QualifiedName> filteredTables = tables.stream()
211+
.filter(table -> !temporaryTables.contains(table))
212+
.collect(toImmutableSet());
213+
214+
catalogBuilder.addAll(filteredTables.stream().map(q -> q.getParts().getFirst()).iterator());
205215
catalogs = catalogBuilder.build();
206-
schemaBuilder.addAll(tables.stream().map(q -> q.getParts().get(1)).iterator());
216+
schemaBuilder.addAll(filteredTables.stream().map(q -> q.getParts().get(1)).iterator());
207217
schemas = schemaBuilder.build();
208218
catalogSchemaBuilder.addAll(
209-
tables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
219+
filteredTables.stream().map(qualifiedName -> format("%s.%s", qualifiedName.getParts().getFirst(), qualifiedName.getParts().get(1))).iterator());
210220
catalogSchemas = catalogSchemaBuilder.build();
211221
}
212222
catch (IOException e) {
@@ -338,6 +348,7 @@ private void getNames(Node node, ImmutableSet.Builder<QualifiedName> tableBuilde
338348
case SetViewAuthorization s -> tableBuilder.add(qualifyName(s.getSource()));
339349
case Table s -> tableBuilder.add(qualifyName(s.getName()));
340350
case TableFunctionInvocation s -> tableBuilder.add(qualifyName(s.getName()));
351+
case WithQuery withQuery -> temporaryTables.add(QualifiedName.of(withQuery.getName().getValue()));
341352
default -> {}
342353
}
343354

@@ -385,6 +396,12 @@ private QualifiedName qualifyName(QualifiedName table)
385396
throws RequestParsingException
386397
{
387398
List<String> tableParts = table.getParts();
399+
400+
// ignore temporary tables created by WITH clause as they can have various table parts
401+
if (temporaryTables.contains(table)) {
402+
return table;
403+
}
404+
388405
return switch (tableParts.size()) {
389406
case 1 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), defaultSchema.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst());
390407
case 2 -> QualifiedName.of(defaultCatalog.orElseThrow(this::unsetDefaultExceptionSupplier), tableParts.getFirst(), tableParts.get(1));

gateway-ha/src/test/java/io/trino/gateway/ha/router/TestRoutingGroupSelector.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.util.stream.Stream;
4141

4242
import static io.trino.gateway.ha.router.RoutingGroupSelector.ROUTING_GROUP_HEADER;
43+
import static java.lang.String.format;
4344
import static java.nio.charset.StandardCharsets.UTF_8;
4445
import static org.assertj.core.api.Assertions.assertThat;
4546
import static org.mockito.Mockito.mock;
@@ -404,6 +405,36 @@ void testTrinoQueryPropertiesTableExtraction(String query, Set<String> catalogs,
404405
assertThat(trinoQueryProperties.getCatalogs()).isEqualTo(catalogs);
405406
}
406407

408+
@Test
409+
void testWithQueryNameExcluded()
410+
throws IOException
411+
{
412+
String query = """
413+
WITH dos AS (SELECT c1 from cat.schem.tbl1),
414+
uno as (SELECT c1 FROM dos)
415+
SELECT c1 FROM uno, dos
416+
""";
417+
BufferedReader bufferedReader = new BufferedReader(new StringReader(query));
418+
HttpServletRequest mockRequestWithDefaults = prepareMockRequest();
419+
when(mockRequestWithDefaults.getReader()).thenReturn(bufferedReader);
420+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_CATALOG_HEADER_NAME)).thenReturn(DEFAULT_CATALOG);
421+
when(mockRequestWithDefaults.getHeader(TrinoQueryProperties.TRINO_SCHEMA_HEADER_NAME)).thenReturn(DEFAULT_SCHEMA);
422+
423+
TrinoQueryProperties trinoQueryPropertiesWithDefaults = new TrinoQueryProperties(mockRequestWithDefaults, requestAnalyzerConfig);
424+
425+
assertThat(trinoQueryPropertiesWithDefaults.tablesContains(format("%s.%s.%s", DEFAULT_CATALOG, DEFAULT_SCHEMA, "uno"))).isFalse();
426+
assertThat(trinoQueryPropertiesWithDefaults.tablesContains(format("%s.%s.%s", DEFAULT_CATALOG, DEFAULT_SCHEMA, "does"))).isFalse();
427+
assertThat(trinoQueryPropertiesWithDefaults.tablesContains("cat.schem.tbl1")).isTrue();
428+
429+
HttpServletRequest mockRequestNoDefaults = prepareMockRequest();
430+
bufferedReader = new BufferedReader(new StringReader(query));
431+
when(mockRequestNoDefaults.getReader()).thenReturn(bufferedReader);
432+
433+
TrinoQueryProperties trinoQueryPropertiesNoDefaults = new TrinoQueryProperties(mockRequestNoDefaults, requestAnalyzerConfig);
434+
assertThat(trinoQueryPropertiesNoDefaults.getTables()).hasSize(3);
435+
assertThat(trinoQueryPropertiesNoDefaults.tablesContains("cat.schem.tbl1")).isTrue();
436+
}
437+
407438
private HttpServletRequest prepareMockRequest()
408439
{
409440
HttpServletRequest mockRequest = mock(HttpServletRequest.class);

0 commit comments

Comments
 (0)