Skip to content

Commit 104a68c

Browse files
committed
Coral-Spark: Migrate some operator transformers from RelNode layer to SqlNode layer
1 parent 7c23b8d commit 104a68c

File tree

20 files changed

+803
-481
lines changed

20 files changed

+803
-481
lines changed

coral-common/src/main/java/com/linkedin/coral/common/calcite/CalciteUtil.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
/**
2-
* Copyright 2021-2022 LinkedIn Corporation. All rights reserved.
2+
* Copyright 2021-2023 LinkedIn Corporation. All rights reserved.
33
* Licensed under the BSD-2 Clause license.
44
* See LICENSE in the project root for license information.
55
*/
66
package com.linkedin.coral.common.calcite;
77

88
import java.util.*;
99

10+
import com.google.common.collect.ImmutableList;
11+
1012
import org.apache.calcite.avatica.util.Casing;
1113
import org.apache.calcite.sql.*;
1214
import org.apache.calcite.sql.parser.SqlParseException;
1315
import org.apache.calcite.sql.parser.SqlParser;
1416
import org.apache.calcite.sql.parser.SqlParserPos;
17+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1518
import org.apache.calcite.sql.validate.SqlConformanceEnum;
19+
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
1620
import org.slf4j.Logger;
1721
import org.slf4j.LoggerFactory;
1822

@@ -123,4 +127,9 @@ public static String quoteReservedWords(String s) {
123127
s = s.replaceAll("(^|\\W)rank($|\\W)", "$1\"rank\"$2");
124128
return s;
125129
}
130+
131+
public static SqlOperator createSqlOperatorOfFunction(String functionName, SqlReturnTypeInference typeInference) {
132+
SqlIdentifier sqlIdentifier = new SqlIdentifier(ImmutableList.of(functionName), SqlParserPos.ZERO);
133+
return new SqlUserDefinedFunction(sqlIdentifier, typeInference, null, null, null, null);
134+
}
126135
}
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
/**
2+
* Copyright 2023 LinkedIn Corporation. All rights reserved.
3+
* Licensed under the BSD-2 Clause license.
4+
* See LICENSE in the project root for license information.
5+
*/
6+
package com.linkedin.coral.common.transformers;
7+
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
10+
import java.util.List;
11+
import java.util.Map;
12+
import java.util.regex.Pattern;
13+
import java.util.stream.Collectors;
14+
15+
import javax.annotation.Nonnull;
16+
import javax.annotation.Nullable;
17+
18+
import com.google.gson.JsonArray;
19+
import com.google.gson.JsonElement;
20+
import com.google.gson.JsonObject;
21+
import com.google.gson.JsonParser;
22+
import com.google.gson.JsonPrimitive;
23+
24+
import org.apache.calcite.sql.SqlCall;
25+
import org.apache.calcite.sql.SqlIdentifier;
26+
import org.apache.calcite.sql.SqlNode;
27+
import org.apache.calcite.sql.SqlOperator;
28+
import org.apache.calcite.sql.SqlWriter;
29+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
30+
import org.apache.calcite.sql.parser.SqlParserPos;
31+
import org.apache.calcite.sql.type.OperandTypes;
32+
import org.apache.calcite.sql.type.ReturnTypes;
33+
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
34+
35+
import com.linkedin.coral.com.google.common.base.Preconditions;
36+
import com.linkedin.coral.common.functions.FunctionReturnTypes;
37+
38+
import static com.linkedin.coral.common.calcite.CalciteUtil.*;
39+
40+
41+
/**
42+
* This class is a subclass of SqlCallTransformer which transforms a function operator on SqlNode layer
43+
* if the signature of the operator to be transformed, including both the name and the number of operands,
44+
* matches the target values in the condition function.
45+
*/
46+
public class OperatorBasedSqlCallTransformer extends SqlCallTransformer {
47+
private static final Map<String, SqlOperator> OP_MAP = new HashMap<>();
48+
49+
// Operators allowed in the transformation
50+
static {
51+
OP_MAP.put("+", SqlStdOperatorTable.PLUS);
52+
OP_MAP.put("-", SqlStdOperatorTable.MINUS);
53+
OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY);
54+
OP_MAP.put("/", SqlStdOperatorTable.DIVIDE);
55+
OP_MAP.put("^", SqlStdOperatorTable.POWER);
56+
OP_MAP.put("%", SqlStdOperatorTable.MOD);
57+
OP_MAP.put("date", new SqlUserDefinedFunction(new SqlIdentifier("date", SqlParserPos.ZERO), ReturnTypes.DATE, null,
58+
OperandTypes.STRING, null, null));
59+
OP_MAP.put("timestamp", new SqlUserDefinedFunction(new SqlIdentifier("timestamp", SqlParserPos.ZERO),
60+
FunctionReturnTypes.TIMESTAMP, null, OperandTypes.STRING, null, null) {
61+
@Override
62+
public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) {
63+
// for timestamp operator, we need to construct `CAST(x AS TIMESTAMP)`
64+
Preconditions.checkState(call.operandCount() == 1);
65+
final SqlWriter.Frame frame = writer.startFunCall("CAST");
66+
call.operand(0).unparse(writer, 0, 0);
67+
writer.sep("AS");
68+
writer.literal("TIMESTAMP");
69+
writer.endFunCall(frame);
70+
}
71+
});
72+
OP_MAP.put("hive_pattern_to_trino",
73+
new SqlUserDefinedFunction(new SqlIdentifier("hive_pattern_to_trino", SqlParserPos.ZERO),
74+
FunctionReturnTypes.STRING, null, OperandTypes.STRING, null, null));
75+
}
76+
77+
public static final String OPERATOR = "op";
78+
public static final String OPERANDS = "operands";
79+
/**
80+
* For input node:
81+
* - input equals 0 refers to the result
82+
* - input great than 0 refers to the index of source operand (starting from 1)
83+
*/
84+
public static final String INPUT = "input";
85+
public static final String VALUE = "value";
86+
public static final String REGEX = "regex";
87+
public static final String NAME = "name";
88+
89+
public final String fromOperatorName;
90+
public final int numOperands;
91+
public final SqlOperator targetOperator;
92+
public List<JsonObject> operandTransformers;
93+
public JsonObject resultTransformer;
94+
public List<JsonObject> operatorTransformers;
95+
96+
public OperatorBasedSqlCallTransformer(@Nonnull String fromOperatorName, int numOperands,
97+
@Nonnull SqlOperator targetOperator, @Nullable String operandTransformers, @Nullable String resultTransformer,
98+
@Nullable String operatorTransformers) {
99+
this.fromOperatorName = fromOperatorName;
100+
this.numOperands = numOperands;
101+
this.targetOperator = targetOperator;
102+
if (operandTransformers != null) {
103+
this.operandTransformers = parseJsonObjectsFromString(operandTransformers);
104+
}
105+
if (resultTransformer != null) {
106+
this.resultTransformer = new JsonParser().parse(resultTransformer).getAsJsonObject();
107+
}
108+
if (operatorTransformers != null) {
109+
this.operatorTransformers = parseJsonObjectsFromString(operatorTransformers);
110+
}
111+
}
112+
113+
public OperatorBasedSqlCallTransformer(@Nonnull SqlOperator coralOp, int numOperands, @Nonnull String targetOpName) {
114+
this(coralOp.getName(), numOperands, createSqlOperatorOfFunction(targetOpName, coralOp.getReturnTypeInference()),
115+
null, null, null);
116+
}
117+
118+
public OperatorBasedSqlCallTransformer(@Nonnull SqlOperator coralOp, int numOperands, @Nonnull String targetOpName,
119+
@Nullable String operandTransformers, @Nullable String resultTransformer, @Nullable String operatorTransformers) {
120+
this(coralOp.getName(), numOperands, createSqlOperatorOfFunction(targetOpName, coralOp.getReturnTypeInference()),
121+
operandTransformers, resultTransformer, operatorTransformers);
122+
}
123+
124+
public SqlOperator getTargetOperator() {
125+
return targetOperator;
126+
}
127+
128+
@Override
129+
protected boolean condition(SqlCall sqlCall) {
130+
return fromOperatorName.equalsIgnoreCase(sqlCall.getOperator().getName())
131+
&& sqlCall.getOperandList().size() == numOperands;
132+
}
133+
134+
@Override
135+
protected SqlCall transform(SqlCall sqlCall) {
136+
List<SqlNode> sourceOperands = sqlCall.getOperandList();
137+
final SqlOperator newTargetOperator = transformTargetOperator(targetOperator, sourceOperands);
138+
if (newTargetOperator == null || newTargetOperator.getName().isEmpty()) {
139+
String operands = sourceOperands.stream().map(SqlNode::toString).collect(Collectors.joining(","));
140+
throw new IllegalArgumentException(
141+
String.format("An equivalent operator in the target IR was not found for the function call: %s(%s)",
142+
fromOperatorName, operands));
143+
}
144+
final List<SqlNode> newOperands = transformOperands(sourceOperands);
145+
final SqlCall newCall = createCall(newTargetOperator, newOperands, SqlParserPos.ZERO);
146+
return (SqlCall) transformResult(newCall, sourceOperands);
147+
}
148+
149+
private List<SqlNode> transformOperands(List<SqlNode> sourceOperands) {
150+
if (operandTransformers == null) {
151+
return sourceOperands;
152+
}
153+
final List<SqlNode> sources = new ArrayList<>();
154+
// Add a dummy expression for input 0
155+
sources.add(null);
156+
sources.addAll(sourceOperands);
157+
final List<SqlNode> results = new ArrayList<>();
158+
for (JsonObject operandTransformer : operandTransformers) {
159+
results.add(transformExpression(operandTransformer, sources));
160+
}
161+
return results;
162+
}
163+
164+
private SqlNode transformResult(SqlNode result, List<SqlNode> sourceOperands) {
165+
if (resultTransformer == null) {
166+
return result;
167+
}
168+
final List<SqlNode> sources = new ArrayList<>();
169+
// Result will be input 0
170+
sources.add(result);
171+
sources.addAll(sourceOperands);
172+
return transformExpression(resultTransformer, sources);
173+
}
174+
175+
/**
176+
* Performs a single transformer.
177+
*/
178+
private SqlNode transformExpression(JsonObject transformer, List<SqlNode> sourceOperands) {
179+
if (transformer.get(OPERATOR) != null) {
180+
final List<SqlNode> inputOperands = new ArrayList<>();
181+
for (JsonElement inputOperand : transformer.getAsJsonArray(OPERANDS)) {
182+
if (inputOperand.isJsonObject()) {
183+
inputOperands.add(transformExpression(inputOperand.getAsJsonObject(), sourceOperands));
184+
}
185+
}
186+
final String operatorName = transformer.get(OPERATOR).getAsString();
187+
final SqlOperator op = OP_MAP.get(operatorName);
188+
if (op == null) {
189+
throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation");
190+
}
191+
return createCall(op, inputOperands, SqlParserPos.ZERO);
192+
}
193+
if (transformer.get(INPUT) != null) {
194+
int index = transformer.get(INPUT).getAsInt();
195+
if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) {
196+
throw new IllegalArgumentException(
197+
"Invalid input value: " + index + ". Number of source operands: " + sourceOperands.size());
198+
}
199+
return sourceOperands.get(index);
200+
}
201+
final JsonElement value = transformer.get(VALUE);
202+
if (value == null) {
203+
throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value");
204+
}
205+
if (!value.isJsonPrimitive()) {
206+
throw new IllegalArgumentException("Value should be of primitive type: " + value);
207+
}
208+
209+
final JsonPrimitive primitive = value.getAsJsonPrimitive();
210+
if (primitive.isString()) {
211+
return createStringLiteral(primitive.getAsString(), SqlParserPos.ZERO);
212+
}
213+
if (primitive.isBoolean()) {
214+
return createLiteralBoolean(primitive.getAsBoolean(), SqlParserPos.ZERO);
215+
}
216+
if (primitive.isNumber()) {
217+
return createLiteralNumber(value.getAsBigDecimal().longValue(), SqlParserPos.ZERO);
218+
}
219+
220+
throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive);
221+
}
222+
223+
/**
224+
* Returns a SqlOperator with a function name based on the value of the source operands.
225+
*/
226+
private SqlOperator transformTargetOperator(SqlOperator operator, List<SqlNode> sourceOperands) {
227+
if (operatorTransformers == null) {
228+
return operator;
229+
}
230+
231+
for (JsonObject operatorTransformer : operatorTransformers) {
232+
if (!operatorTransformer.has(REGEX) || !operatorTransformer.has(INPUT) || !operatorTransformer.has(NAME)) {
233+
throw new IllegalArgumentException(
234+
"JSON node for target operator transformer must have a matcher, input and name");
235+
}
236+
// We use the same convention as operand and result transformers.
237+
// Therefore, we start source index values at index 1 instead of index 0.
238+
// Acceptable index values are set to be [1, size]
239+
int index = operatorTransformer.get(INPUT).getAsInt() - 1;
240+
if (index < 0 || index >= sourceOperands.size()) {
241+
throw new IllegalArgumentException(
242+
String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size()));
243+
}
244+
String functionName = operatorTransformer.get(NAME).getAsString();
245+
if (functionName.isEmpty()) {
246+
throw new IllegalArgumentException("JSON node for transformation must have a non-empty name");
247+
}
248+
String matcher = operatorTransformer.get(REGEX).getAsString();
249+
250+
if (Pattern.matches(matcher, sourceOperands.get(index).toString())) {
251+
return createSqlOperatorOfFunction(functionName, operator.getReturnTypeInference());
252+
}
253+
}
254+
return operator;
255+
}
256+
257+
/**
258+
* Creates an ArrayList of JsonObjects from a string input.
259+
* The input string must be a serialized JSON array.
260+
*/
261+
private static List<JsonObject> parseJsonObjectsFromString(String s) {
262+
List<JsonObject> objects = new ArrayList<>();
263+
JsonArray transformerArray = new JsonParser().parse(s).getAsJsonArray();
264+
for (JsonElement object : transformerArray) {
265+
objects.add(object.getAsJsonObject());
266+
}
267+
return objects;
268+
}
269+
}

coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ public SqlCallTransformer(SqlValidator sqlValidator) {
3232
}
3333

3434
/**
35-
* Predicate of the transformer, it’s used to determine if the SqlCall should be transformed or not
35+
* Condition of the transformer; it is used to determine whether the SqlCall should be transformed
3636
*/
37-
protected abstract boolean predicate(SqlCall sqlCall);
37+
protected abstract boolean condition(SqlCall sqlCall);
3838

3939
/**
4040
* Implementation of the transformation, returns the transformed SqlCall
@@ -49,7 +49,7 @@ public SqlCall apply(SqlCall sqlCall) {
4949
if (sqlCall instanceof SqlSelect) {
5050
this.topSelectNodes.add((SqlSelect) sqlCall);
5151
}
52-
if (predicate(sqlCall)) {
52+
if (condition(sqlCall)) {
5353
return transform(sqlCall);
5454
} else {
5555
return sqlCall;

coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformers.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
/**
1616
* Container for SqlCallTransformer
1717
*/
18-
public class SqlCallTransformers {
18+
public final class SqlCallTransformers {
1919
private final ImmutableList<SqlCallTransformer> sqlCallTransformers;
2020

2121
SqlCallTransformers(ImmutableList<SqlCallTransformer> sqlCallTransformers) {

coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package com.linkedin.coral.transformers;
77

88
import java.util.ArrayList;
9+
import java.util.Collections;
910
import java.util.List;
1011
import java.util.Map;
1112

@@ -25,6 +26,7 @@
2526
import org.apache.calcite.rex.RexCorrelVariable;
2627
import org.apache.calcite.rex.RexFieldAccess;
2728
import org.apache.calcite.rex.RexNode;
29+
import org.apache.calcite.rex.RexProgram;
2830
import org.apache.calcite.sql.JoinConditionType;
2931
import org.apache.calcite.sql.JoinType;
3032
import org.apache.calcite.sql.SqlCall;
@@ -43,6 +45,7 @@
4345
import com.linkedin.coral.com.google.common.collect.ImmutableList;
4446
import com.linkedin.coral.com.google.common.collect.ImmutableMap;
4547
import com.linkedin.coral.common.functions.CoralSqlUnnestOperator;
48+
import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;
4649

4750

4851
/**
@@ -345,4 +348,41 @@ private SqlNode generateRightChildForSqlJoinWithLateralViews(BiRel e, Result rig
345348

346349
return SqlStdOperatorTable.AS.createCall(POS, asOperands);
347350
}
351+
352+
/**
353+
* Override this method to handle the conversion for RelNode `f(x).y.z` where `f` is an UDF, which
354+
* returns a struct containing field `y`, `y` is also a struct containing field `z`.
355+
*
356+
* For this kind of RelNode, Calcite will convert it to a SqlIdentifier directly (check
357+
* {@link org.apache.calcite.rel.rel2sql.SqlImplementor.Context#toSql(RexProgram, RexNode)}),
358+
* which is not aligned with our expectation since we want to apply transformations on `f(x)` with
359+
* {@link com.linkedin.coral.common.transformers.SqlCallTransformer}. Therefore, we override this
360+
* method to convert `f(x)` to a SqlCall, `.` to {@link com.linkedin.coral.common.functions.FunctionFieldReferenceOperator#DOT}
361+
* and `y.z` to SqlIdentifier.
362+
*/
363+
@Override
364+
public Context aliasContext(Map<String, RelDataType> aliases, boolean qualified) {
365+
return new AliasContext(INSTANCE, aliases, qualified) {
366+
@Override
367+
public SqlNode toSql(RexProgram program, RexNode rex) {
368+
if (rex.getKind() == SqlKind.FIELD_ACCESS) {
369+
final List<String> accessNames = new ArrayList<>();
370+
RexNode referencedExpr = rex;
371+
// Use the loop to get the top-level struct (`f(x)` in the example above),
372+
// and store the accessed field names ([`z`, `y`] in the example above, needs to be reversed)
373+
while (referencedExpr.getKind() == SqlKind.FIELD_ACCESS) {
374+
accessNames.add(((RexFieldAccess) referencedExpr).getField().getName());
375+
referencedExpr = ((RexFieldAccess) referencedExpr).getReferenceExpr();
376+
}
377+
if (referencedExpr.getKind() == SqlKind.OTHER_FUNCTION) {
378+
SqlNode functionCall = toSql(program, referencedExpr);
379+
Collections.reverse(accessNames);
380+
return FunctionFieldReferenceOperator.DOT.createCall(SqlParserPos.ZERO, functionCall,
381+
new SqlIdentifier(String.join(".", accessNames), POS));
382+
}
383+
}
384+
return super.toSql(program, rex);
385+
}
386+
};
387+
}
348388
}

0 commit comments

Comments
 (0)