Skip to content

Commit e4d4db9

Browse files
yiqianginYiqiang Dingwmoustafa
authored
Coral-Trino: Migrate function operator transformers defined in CalciteTrinoUDFMap from RelNode layer to SqlNode layer (#349)
* Migrate standard UDF operator transformers based on JSON infra from RelNode layer to SqlNode layer * address comments * fixing a typo of the class name * address comments * address comments * address comments * fix a typo * address comments * adding another constructor in LinkedInOperatorBasedSqlCallTransformer * Simplify coral-trino transformations * fix the regression test failures * address comments * add link of a class in comments * fix the regression test failures * Revert "fix the regression test failures" This reverts commit 440aa1c. * fix the regression test failures * rename a function and add some javadoc * add {@link} in javadoc of a function --------- Co-authored-by: Yiqiang Ding <[email protected]> Co-authored-by: Walaa Eldin Moustafa <[email protected]>
1 parent 0a90624 commit e4d4db9

File tree

17 files changed

+675
-955
lines changed

17 files changed

+675
-955
lines changed

coral-common/src/main/java/com/linkedin/coral/common/calcite/CalciteUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2021-2022 LinkedIn Corporation. All rights reserved.
2+
* Copyright 2021-2023 LinkedIn Corporation. All rights reserved.
33
* Licensed under the BSD-2 Clause license.
44
* See LICENSE in the project root for license information.
55
*/
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
/**
2+
* Copyright 2023 LinkedIn Corporation. All rights reserved.
3+
* Licensed under the BSD-2 Clause license.
4+
* See LICENSE in the project root for license information.
5+
*/
6+
package com.linkedin.coral.common.transformers;
7+
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
10+
import java.util.List;
11+
import java.util.Map;
12+
import java.util.regex.Pattern;
13+
import java.util.stream.Collectors;
14+
15+
import com.google.gson.JsonArray;
16+
import com.google.gson.JsonElement;
17+
import com.google.gson.JsonObject;
18+
import com.google.gson.JsonParser;
19+
import com.google.gson.JsonPrimitive;
20+
21+
import org.apache.calcite.sql.SqlCall;
22+
import org.apache.calcite.sql.SqlIdentifier;
23+
import org.apache.calcite.sql.SqlNode;
24+
import org.apache.calcite.sql.SqlOperator;
25+
import org.apache.calcite.sql.SqlWriter;
26+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
27+
import org.apache.calcite.sql.parser.SqlParserPos;
28+
import org.apache.calcite.sql.type.OperandTypes;
29+
import org.apache.calcite.sql.type.ReturnTypes;
30+
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
31+
32+
import com.linkedin.coral.com.google.common.base.Preconditions;
33+
import com.linkedin.coral.common.functions.FunctionReturnTypes;
34+
35+
import static com.linkedin.coral.common.calcite.CalciteUtil.*;
36+
37+
38+
/**
39+
* This class is a subclass of {@link SourceOperatorMatchSqlCallTransformer} which transforms the parts (e.g. name, operator and operands)
40+
* of the target operator according to multiple rules defined in JSON format respectively.
41+
*/
42+
public class JsonTransformSqlCallTransformer extends SourceOperatorMatchSqlCallTransformer {
43+
private static final Map<String, SqlOperator> OP_MAP = new HashMap<>();
44+
45+
// Registry of the operators that JSON transformation rules may reference through the "op" key
static {
  // Arithmetic operators are taken directly from Calcite's standard operator table.
  OP_MAP.put("+", SqlStdOperatorTable.PLUS);
  OP_MAP.put("-", SqlStdOperatorTable.MINUS);
  OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY);
  OP_MAP.put("/", SqlStdOperatorTable.DIVIDE);
  OP_MAP.put("^", SqlStdOperatorTable.POWER);
  OP_MAP.put("%", SqlStdOperatorTable.MOD);

  // `date` is a plain user-defined function taking a string operand.
  final SqlOperator dateOperator = new SqlUserDefinedFunction(new SqlIdentifier("date", SqlParserPos.ZERO),
      ReturnTypes.DATE, null, OperandTypes.STRING, null, null);
  OP_MAP.put("date", dateOperator);

  // `timestamp` overrides unparsing so that it is written out as `CAST(x AS TIMESTAMP)`.
  final SqlOperator timestampOperator =
      new SqlUserDefinedFunction(new SqlIdentifier("timestamp", SqlParserPos.ZERO), FunctionReturnTypes.TIMESTAMP,
          null, OperandTypes.STRING, null, null) {
        @Override
        public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) {
          // The cast form only makes sense for exactly one operand.
          Preconditions.checkState(call.operandCount() == 1);
          final SqlWriter.Frame castFrame = writer.startFunCall("CAST");
          call.operand(0).unparse(writer, 0, 0);
          writer.sep("AS");
          writer.literal("TIMESTAMP");
          writer.endFunCall(castFrame);
        }
      };
  OP_MAP.put("timestamp", timestampOperator);

  final SqlOperator hivePatternToTrinoOperator =
      new SqlUserDefinedFunction(new SqlIdentifier("hive_pattern_to_trino", SqlParserPos.ZERO),
          FunctionReturnTypes.STRING, null, OperandTypes.STRING, null, null);
  OP_MAP.put("hive_pattern_to_trino", hivePatternToTrinoOperator);
}
72+
73+
public static final String OPERATOR = "op";
74+
public static final String OPERANDS = "operands";
75+
/**
76+
* For input node:
77+
* - input equals 0 refers to the result
78+
* - input greater than 0 refers to the index of the source operand (starting from 1)
79+
*/
80+
public static final String INPUT = "input";
81+
public static final String VALUE = "value";
82+
public static final String REGEX = "regex";
83+
public static final String NAME = "name";
84+
85+
private final SqlOperator targetOperator;
86+
public List<JsonObject> operandTransformers;
87+
public JsonObject resultTransformer;
88+
public List<JsonObject> operatorTransformers;
89+
90+
/**
 * Creates a transformer that has no JSON rules attached: the operand, result and operator
 * transformers are all left unset.
 *
 * @param fromOperatorName name of the source operator this transformer applies to
 * @param numOperands number of operands a call must have for this transformer to apply
 * @param targetOperator operator used to build the transformed call
 */
public JsonTransformSqlCallTransformer(String fromOperatorName, int numOperands, SqlOperator targetOperator) {
  super(fromOperatorName, numOperands);
  this.targetOperator = targetOperator;
}

/**
 * Creates a transformer whose rules are given as serialized JSON. Any of the three rule
 * arguments may be {@code null}, in which case the corresponding transformation is not configured.
 *
 * @param coralOp source operator; its name is matched and its return type inference is reused for the target
 * @param numOperands number of operands a call must have for this transformer to apply
 * @param targetOpName name of the operator created for the transformed call
 * @param operandTransformers serialized JSON array of operand transformer objects, or {@code null}
 * @param resultTransformer serialized JSON object describing the result transformer, or {@code null}
 * @param operatorTransformers serialized JSON array of operator transformer objects, or {@code null}
 */
public JsonTransformSqlCallTransformer(SqlOperator coralOp, int numOperands, String targetOpName,
    String operandTransformers, String resultTransformer, String operatorTransformers) {
  this(coralOp.getName(), numOperands, createSqlOperator(targetOpName, coralOp.getReturnTypeInference()));
  // The fields start out as null, so assigning null for an absent rule leaves them unset.
  this.operandTransformers = operandTransformers == null ? null : parseJsonObjectsFromString(operandTransformers);
  this.resultTransformer =
      resultTransformer == null ? null : new JsonParser().parse(resultTransformer).getAsJsonObject();
  this.operatorTransformers =
      operatorTransformers == null ? null : parseJsonObjectsFromString(operatorTransformers);
}
108+
109+
@Override
110+
protected boolean condition(SqlCall sqlCall) {
111+
return sourceOpName.equalsIgnoreCase(sqlCall.getOperator().getName())
112+
&& sqlCall.getOperandList().size() == numOperands;
113+
}
114+
115+
@Override
116+
protected SqlCall transform(SqlCall sqlCall) {
117+
List<SqlNode> sourceOperands = sqlCall.getOperandList();
118+
final SqlOperator newTargetOperator = transformTargetOperator(targetOperator, sourceOperands);
119+
if (newTargetOperator == null || newTargetOperator.getName().isEmpty()) {
120+
String operands = sourceOperands.stream().map(SqlNode::toString).collect(Collectors.joining(","));
121+
throw new IllegalArgumentException(
122+
String.format("An equivalent operator in the target IR was not found for the function call: %s(%s)",
123+
sourceOpName, operands));
124+
}
125+
final List<SqlNode> newOperands = transformOperands(sourceOperands);
126+
final SqlCall newCall = createCall(newTargetOperator, newOperands, SqlParserPos.ZERO);
127+
return (SqlCall) transformResult(newCall, sourceOperands);
128+
}
129+
130+
private List<SqlNode> transformOperands(List<SqlNode> sourceOperands) {
131+
if (operandTransformers == null) {
132+
return sourceOperands;
133+
}
134+
final List<SqlNode> sources = new ArrayList<>();
135+
// Add a dummy expression for input 0
136+
sources.add(null);
137+
sources.addAll(sourceOperands);
138+
final List<SqlNode> results = new ArrayList<>();
139+
for (JsonObject operandTransformer : operandTransformers) {
140+
results.add(transformExpression(operandTransformer, sources));
141+
}
142+
return results;
143+
}
144+
145+
private SqlNode transformResult(SqlNode result, List<SqlNode> sourceOperands) {
146+
if (resultTransformer == null) {
147+
return result;
148+
}
149+
final List<SqlNode> sources = new ArrayList<>();
150+
// Result will be input 0
151+
sources.add(result);
152+
sources.addAll(sourceOperands);
153+
return transformExpression(resultTransformer, sources);
154+
}
155+
156+
/**
157+
* Performs a single transformer.
158+
*/
159+
private SqlNode transformExpression(JsonObject transformer, List<SqlNode> sourceOperands) {
160+
if (transformer.get(OPERATOR) != null) {
161+
final List<SqlNode> inputOperands = new ArrayList<>();
162+
for (JsonElement inputOperand : transformer.getAsJsonArray(OPERANDS)) {
163+
if (inputOperand.isJsonObject()) {
164+
inputOperands.add(transformExpression(inputOperand.getAsJsonObject(), sourceOperands));
165+
}
166+
}
167+
final String operatorName = transformer.get(OPERATOR).getAsString();
168+
final SqlOperator op = OP_MAP.get(operatorName);
169+
if (op == null) {
170+
throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation");
171+
}
172+
return createCall(op, inputOperands, SqlParserPos.ZERO);
173+
}
174+
if (transformer.get(INPUT) != null) {
175+
int index = transformer.get(INPUT).getAsInt();
176+
if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) {
177+
throw new IllegalArgumentException(
178+
"Invalid input value: " + index + ". Number of source operands: " + sourceOperands.size());
179+
}
180+
return sourceOperands.get(index);
181+
}
182+
final JsonElement value = transformer.get(VALUE);
183+
if (value == null) {
184+
throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value");
185+
}
186+
if (!value.isJsonPrimitive()) {
187+
throw new IllegalArgumentException("Value should be of primitive type: " + value);
188+
}
189+
190+
final JsonPrimitive primitive = value.getAsJsonPrimitive();
191+
if (primitive.isString()) {
192+
return createStringLiteral(primitive.getAsString(), SqlParserPos.ZERO);
193+
}
194+
if (primitive.isBoolean()) {
195+
return createLiteralBoolean(primitive.getAsBoolean(), SqlParserPos.ZERO);
196+
}
197+
if (primitive.isNumber()) {
198+
return createLiteralNumber(value.getAsBigDecimal().longValue(), SqlParserPos.ZERO);
199+
}
200+
201+
throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive);
202+
}
203+
204+
/**
205+
* Returns a SqlOperator with a function name based on the value of the source operands.
206+
*/
207+
private SqlOperator transformTargetOperator(SqlOperator operator, List<SqlNode> sourceOperands) {
208+
if (operatorTransformers == null) {
209+
return operator;
210+
}
211+
212+
for (JsonObject operatorTransformer : operatorTransformers) {
213+
if (!operatorTransformer.has(REGEX) || !operatorTransformer.has(INPUT) || !operatorTransformer.has(NAME)) {
214+
throw new IllegalArgumentException(
215+
"JSON node for target operator transformer must have a matcher, input and name");
216+
}
217+
// We use the same convention as operand and result transformers.
218+
// Therefore, we start source index values at index 1 instead of index 0.
219+
// Acceptable index values are set to be [1, size]
220+
int index = operatorTransformer.get(INPUT).getAsInt() - 1;
221+
if (index < 0 || index >= sourceOperands.size()) {
222+
throw new IllegalArgumentException(
223+
String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size()));
224+
}
225+
String functionName = operatorTransformer.get(NAME).getAsString();
226+
if (functionName.isEmpty()) {
227+
throw new IllegalArgumentException("JSON node for transformation must have a non-empty name");
228+
}
229+
String matcher = operatorTransformer.get(REGEX).getAsString();
230+
231+
if (Pattern.matches(matcher, sourceOperands.get(index).toString())) {
232+
return createSqlOperator(functionName, operator.getReturnTypeInference());
233+
}
234+
}
235+
return operator;
236+
}
237+
238+
/**
239+
* Creates an ArrayList of JsonObjects from a string input.
240+
* The input string must be a serialized JSON array.
241+
*/
242+
private static List<JsonObject> parseJsonObjectsFromString(String s) {
243+
List<JsonObject> objects = new ArrayList<>();
244+
JsonArray transformerArray = new JsonParser().parse(s).getAsJsonArray();
245+
for (JsonElement object : transformerArray) {
246+
objects.add(object.getAsJsonObject());
247+
}
248+
return objects;
249+
}
250+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/**
2+
* Copyright 2023 LinkedIn Corporation. All rights reserved.
3+
* Licensed under the BSD-2 Clause license.
4+
* See LICENSE in the project root for license information.
5+
*/
6+
package com.linkedin.coral.common.transformers;
7+
8+
import org.apache.calcite.sql.SqlCall;
9+
import org.apache.calcite.sql.SqlNodeList;
10+
import org.apache.calcite.sql.SqlOperator;
11+
import org.apache.calcite.sql.parser.SqlParserPos;
12+
13+
14+
/**
15+
* This class is a subclass of {@link SourceOperatorMatchSqlCallTransformer} which transform the operator name
16+
* from the name of the source operator to the name of the target operator
17+
*/
18+
public class OperatorRenameSqlCallTransformer extends SourceOperatorMatchSqlCallTransformer {
19+
private SqlOperator sourceOperator;
20+
private String targetOpName;
21+
22+
public OperatorRenameSqlCallTransformer(SqlOperator sourceOperator, int numOperands, String targetOpName) {
23+
super(sourceOperator.getName(), numOperands);
24+
this.sourceOperator = sourceOperator;
25+
this.targetOpName = targetOpName;
26+
}
27+
28+
@Override
29+
protected SqlCall transform(SqlCall sqlCall) {
30+
return createSqlOperator(targetOpName, sourceOperator.getReturnTypeInference())
31+
.createCall(new SqlNodeList(sqlCall.getOperandList(), SqlParserPos.ZERO));
32+
}
33+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/**
2+
* Copyright 2023 LinkedIn Corporation. All rights reserved.
3+
* Licensed under the BSD-2 Clause license.
4+
* See LICENSE in the project root for license information.
5+
*/
6+
package com.linkedin.coral.common.transformers;
7+
8+
import org.apache.calcite.sql.SqlCall;
9+
10+
import static com.linkedin.coral.common.calcite.CalciteUtil.*;
11+
12+
13+
/**
14+
* This class is a subclass of {@link SqlCallTransformer} which transforms a function operator on SqlNode layer
15+
* if the signature of the operator to be transformed, including both the name and the number of operands,
16+
* matches the target values in the condition function.
17+
*/
18+
public abstract class SourceOperatorMatchSqlCallTransformer extends SqlCallTransformer {
19+
protected final String sourceOpName;
20+
protected final int numOperands;
21+
22+
public SourceOperatorMatchSqlCallTransformer(String sourceOpName, int numOperands) {
23+
this.sourceOpName = sourceOpName;
24+
this.numOperands = numOperands;
25+
}
26+
27+
@Override
28+
protected boolean condition(SqlCall sqlCall) {
29+
return sourceOpName.equalsIgnoreCase(sqlCall.getOperator().getName())
30+
&& sqlCall.getOperandList().size() == numOperands;
31+
}
32+
}

coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@
88
import java.util.ArrayList;
99
import java.util.List;
1010

11+
import com.google.common.collect.ImmutableList;
12+
1113
import org.apache.calcite.rel.type.RelDataType;
1214
import org.apache.calcite.sql.SqlCall;
15+
import org.apache.calcite.sql.SqlIdentifier;
1316
import org.apache.calcite.sql.SqlNode;
1417
import org.apache.calcite.sql.SqlNodeList;
18+
import org.apache.calcite.sql.SqlOperator;
1519
import org.apache.calcite.sql.SqlSelect;
20+
import org.apache.calcite.sql.parser.SqlParserPos;
21+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
22+
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
1623
import org.apache.calcite.sql.validate.SqlValidator;
1724

1825

@@ -32,9 +39,9 @@ public SqlCallTransformer(SqlValidator sqlValidator) {
3239
}
3340

3441
/**
35-
* Predicate of the transformer, it’s used to determine if the SqlCall should be transformed or not
42+
* Condition of the transformer; it is used to determine whether the SqlCall should be transformed or not
3643
*/
37-
protected abstract boolean predicate(SqlCall sqlCall);
44+
protected abstract boolean condition(SqlCall sqlCall);
3845

3946
/**
4047
* Implementation of the transformation, returns the transformed SqlCall
@@ -49,7 +56,7 @@ public SqlCall apply(SqlCall sqlCall) {
4956
if (sqlCall instanceof SqlSelect) {
5057
this.topSelectNodes.add((SqlSelect) sqlCall);
5158
}
52-
if (predicate(sqlCall)) {
59+
if (condition(sqlCall)) {
5360
return transform(sqlCall);
5461
} else {
5562
return sqlCall;
@@ -96,4 +103,12 @@ protected RelDataType getRelDataType(SqlNode sqlNode) {
96103
}
97104
throw new RuntimeException("Failed to derive the RelDataType for SqlNode " + sqlNode);
98105
}
106+
107+
/**
108+
* This function creates a {@link SqlOperator} for a function with the function name and return type inference.
109+
*/
110+
protected static SqlOperator createSqlOperator(String functionName, SqlReturnTypeInference typeInference) {
111+
SqlIdentifier sqlIdentifier = new SqlIdentifier(ImmutableList.of(functionName), SqlParserPos.ZERO);
112+
return new SqlUserDefinedFunction(sqlIdentifier, typeInference, null, null, null, null);
113+
}
99114
}

coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformers.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
/**
1616
* Container for SqlCallTransformer
1717
*/
18-
public class SqlCallTransformers {
18+
public final class SqlCallTransformers {
1919
private final ImmutableList<SqlCallTransformer> sqlCallTransformers;
2020

2121
SqlCallTransformers(ImmutableList<SqlCallTransformer> sqlCallTransformers) {

coral-hive/src/main/java/com/linkedin/coral/transformers/ShiftArrayIndexTransformer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public ShiftArrayIndexTransformer(SqlValidator sqlValidator) {
3131
}
3232

3333
@Override
34-
public boolean predicate(SqlCall sqlCall) {
34+
public boolean condition(SqlCall sqlCall) {
3535
if (ITEM_OPERATOR.equalsIgnoreCase(sqlCall.getOperator().getName())) {
3636
final SqlNode columnNode = sqlCall.getOperandList().get(0);
3737
return getRelDataType(columnNode) instanceof ArraySqlType;

0 commit comments

Comments
 (0)