Skip to content

Commit b8ed756

Browse files
committed
Serialize procedure
1 parent 6588b66 commit b8ed756

File tree

6 files changed

+280
-173
lines changed

6 files changed

+280
-173
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public static Optional<String> extractQueryIdIfPresent(
119119
if (trinoQueryProperties.procedureNameEquals("system.runtime.kill_query")) {
120120
Expression argument = trinoQueryProperties.getProcedureArguments().getFirst().getValue();
121121
checkArgument(argument instanceof StringLiteral, "Unable to route kill_query procedures where the first argument is not a String Literal");
122-
return Optional.of(((StringLiteral) trinoQueryProperties.getProcedureArguments().getFirst().getValue()).getValue());
122+
return Optional.of(((StringLiteral) argument).getValue());
123123
}
124124
}
125125
return Optional.empty();
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.ha.router;
15+
16+
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonProperty;
18+
import io.trino.sql.tree.CallArgument;
19+
import io.trino.sql.tree.Expression;
20+
import io.trino.sql.tree.Identifier;
21+
22+
import java.util.Optional;
23+
24+
import static java.util.Objects.requireNonNull;
25+
26+
// CallArgument is final, so this class just replicates it. Location is not preserved, since it would require
27+
// additional complexity and is not meaningful in this context
28+
public class SerializableCallArgument
29+
{
30+
private final Optional<Identifier> name;
31+
private final SerializableExpression value;
32+
33+
public SerializableCallArgument(SerializableExpression value)
34+
{
35+
this(Optional.empty(), value);
36+
}
37+
38+
@JsonCreator
39+
public SerializableCallArgument(
40+
@JsonProperty("name") Optional<String> name,
41+
@JsonProperty("value") SerializableExpression value)
42+
{
43+
this.name = requireNonNull(name, "name is null").map(Identifier::new);
44+
this.value = requireNonNull(value, "value is null");
45+
}
46+
47+
public SerializableCallArgument(CallArgument callArgument)
48+
{
49+
this.name = callArgument.getName();
50+
this.value = new SerializableExpression(callArgument.getValue());
51+
}
52+
53+
public CallArgument toCallArgument()
54+
{
55+
return name.map(n -> new CallArgument(n, value)).orElse(new CallArgument(value));
56+
}
57+
58+
@JsonProperty
59+
public Expression getValue()
60+
{
61+
return value;
62+
}
63+
64+
@JsonProperty
65+
public Optional<String> getName()
66+
{
67+
return name.map(Identifier::getValue);
68+
}
69+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.ha.router;
15+
16+
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonProperty;
18+
import com.google.common.collect.ImmutableList;
19+
import io.trino.sql.tree.Expression;
20+
import io.trino.sql.tree.Node;
21+
import io.trino.sql.tree.NodeLocation;
22+
23+
import java.util.List;
24+
import java.util.Objects;
25+
import java.util.Optional;
26+
27+
public class SerializableExpression
28+
extends Expression
29+
{
30+
private final String value;
31+
32+
@JsonCreator
33+
public SerializableExpression(@JsonProperty("value") String value)
34+
{
35+
this(Optional.empty(), value);
36+
}
37+
38+
protected SerializableExpression(Optional<NodeLocation> location, String value)
39+
{
40+
super(location);
41+
this.value = value;
42+
}
43+
44+
public SerializableExpression(Expression expression)
45+
{
46+
super(expression.getLocation());
47+
if (expression instanceof SerializableExpression) {
48+
value = ((SerializableExpression) expression).getValue();
49+
}
50+
else {
51+
this.value = expression.toString();
52+
}
53+
}
54+
55+
@JsonProperty
56+
public String getValue()
57+
{
58+
return value;
59+
}
60+
61+
@Override
62+
public List<? extends Node> getChildren()
63+
{
64+
return ImmutableList.of();
65+
}
66+
67+
@Override
68+
public int hashCode()
69+
{
70+
return value.hashCode();
71+
}
72+
73+
@Override
74+
public boolean equals(Object obj)
75+
{
76+
if (this == obj) {
77+
return true;
78+
}
79+
else if (obj != null && this.getClass() == obj.getClass()) {
80+
SerializableExpression that = (SerializableExpression) obj;
81+
return Objects.equals(this.value, that.value);
82+
}
83+
else {
84+
return false;
85+
}
86+
}
87+
}

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoQueryProperties.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package io.trino.gateway.ha.router;
1515

1616
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonIgnore;
1718
import com.fasterxml.jackson.annotation.JsonProperty;
1819
import com.fasterxml.jackson.core.JsonGenerator;
1920
import com.fasterxml.jackson.databind.SerializerProvider;
@@ -119,6 +120,8 @@ public TrinoQueryProperties(
119120
@JsonProperty("catalogs") Set<String> catalogs,
120121
@JsonProperty("schemas") Set<String> schemas,
121122
@JsonProperty("catalogSchemas") Set<String> catalogSchemas,
123+
@JsonProperty("procedure") Optional<String> procedure,
124+
@JsonProperty("procedureArguments") SerializableCallArgument[] procedureArguments,
122125
@JsonProperty("isNewQuerySubmission") boolean isNewQuerySubmission,
123126
@JsonProperty("errorMessage") Optional<String> errorMessage)
124127
{
@@ -132,6 +135,8 @@ public TrinoQueryProperties(
132135
this.schemas = requireNonNullElse(schemas, ImmutableSet.of());
133136
this.catalogSchemas = requireNonNullElse(catalogSchemas, ImmutableSet.of());
134137
this.isNewQuerySubmission = isNewQuerySubmission;
138+
this.procedureArguments = Arrays.stream(requireNonNullElse(procedureArguments, new SerializableCallArgument[] {})).map(SerializableCallArgument::toCallArgument).toList();
139+
this.procedure = procedure.map(procedureName -> new Call(parseIdentifierStringToQualifiedName(procedureName), this.procedureArguments));
135140
this.errorMessage = requireNonNullElse(errorMessage, Optional.empty());
136141
isClientsUseV2Format = false;
137142
maxBodySize = -1;
@@ -284,6 +289,7 @@ private void visitNode(Node node, ImmutableSet.Builder<QualifiedName> tableBuild
284289
case AddColumn s -> tableBuilder.add(qualifyName(s.getName()));
285290
case Analyze s -> tableBuilder.add(qualifyName(s.getTableName()));
286291
case Call call -> {
292+
setCatalogAndSchemaNameFromSchemaQualifiedName(Optional.of(call.getName()), catalogBuilder, schemaBuilder, catalogSchemaBuilder);
287293
procedure = Optional.of(call);
288294
procedureArguments = call.getArguments();
289295
return;
@@ -497,11 +503,11 @@ public boolean tablesContains(String testName)
497503
}
498504
}
499505

500-
public boolean procedureNameEquals(String testName)
506+
public boolean procedureNameEquals(String name)
501507
{
502508
return procedure.map(p -> {
503509
try {
504-
return qualifyName(p.getName()).equals(parseIdentifierStringToQualifiedName(testName));
510+
return qualifyName(p.getName()).equals(parseIdentifierStringToQualifiedName(name));
505511
}
506512
catch (RequestParsingException e) {
507513
return false;
@@ -551,16 +557,30 @@ public Optional<String> getErrorMessage()
551557
return errorMessage;
552558
}
553559

560+
@JsonIgnore
554561
public Optional<Call> getProcedure()
555562
{
556563
return procedure;
557564
}
558565

566+
@JsonIgnore
559567
public List<CallArgument> getProcedureArguments()
560568
{
561569
return procedureArguments;
562570
}
563571

572+
@JsonProperty("procedure")
573+
public Optional<String> getProcedureName()
574+
{
575+
return procedure.map(p -> p.getName().toString());
576+
}
577+
578+
@JsonProperty("procedureArguments")
579+
public List<SerializableCallArgument> getSerializableProcedureArguments()
580+
{
581+
return procedureArguments.stream().map(SerializableCallArgument::new).toList();
582+
}
583+
564584
public static class AlternateStatementRequestBodyFormat
565585
{
566586
// Based on https://github.com/trinodb/trino/wiki/trino-v2-client-protocol, without session

0 commit comments

Comments
 (0)