
Commit 394979a

Coral-Spark: Migrate some operator transformers from RelNode layer to SqlNode layer
1 parent 7c23b8d commit 394979a

21 files changed: 591 additions & 530 deletions

coral-common/src/main/java/com/linkedin/coral/common/transformers/OperatorRenameSqlCallTransformer.java

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.common.transformers;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.parser.SqlParserPos;
+
+
+/**
+ * This class is a subclass of {@link SourceOperatorMatchSqlCallTransformer} which renames the source operator
+ */
+public class OperatorRenameSqlCallTransformer extends SourceOperatorMatchSqlCallTransformer {
+  private final String targetOpName;
+
+  public OperatorRenameSqlCallTransformer(String sourceOpName, int numOperands, String targetOpName) {
+    super(sourceOpName, numOperands);
+    this.targetOpName = targetOpName;
+  }
+
+  @Override
+  protected SqlCall transform(SqlCall sqlCall) {
+    return createSqlOperatorOfFunction(targetOpName, sqlCall.getOperator().getReturnTypeInference())
+        .createCall(new SqlNodeList(sqlCall.getOperandList(), SqlParserPos.ZERO));
+  }
+}
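
To make the intended usage concrete, here is a hedged, self-contained sketch of constructing and applying the new transformer. The operator pair LOWER -> lcase and the sample call are hypothetical examples chosen for illustration; they are not mappings introduced by this commit.

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
import com.linkedin.coral.common.transformers.SqlCallTransformer;

public class OperatorRenameExample {
  public static void main(String[] args) {
    // Hypothetical rename: map the 1-operand operator "LOWER" to the target name "lcase".
    SqlCallTransformer rename = new OperatorRenameSqlCallTransformer("LOWER", 1, "lcase");

    // Build a sample call LOWER(col1) to feed through the transformer.
    SqlCall lowerCall = SqlStdOperatorTable.LOWER.createCall(
        SqlParserPos.ZERO, new SqlIdentifier("col1", SqlParserPos.ZERO));

    // apply() rewrites the call only when the operator name and operand count match;
    // otherwise the original SqlCall is returned unchanged.
    SqlCall renamed = rename.apply(lowerCall);
    System.out.println(renamed); // prints the rewritten call, e.g. lcase(col1), modulo dialect quoting
  }
}
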
coral-common/src/main/java/com/linkedin/coral/common/transformers/SourceOperatorMatchSqlCallTransformer.java

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.common.transformers;
+
+import org.apache.calcite.sql.SqlCall;
+
+
+/**
+ * This class is a subclass of {@link SqlCallTransformer} which transforms a function operator on SqlNode layer
+ * if the signature of the operator to be transformed, including both the name and the number of operands,
+ * matches the target values in the condition function.
+ */
+public abstract class SourceOperatorMatchSqlCallTransformer extends SqlCallTransformer {
+  protected final String sourceOpName;
+  protected final int numOperands;
+
+  public SourceOperatorMatchSqlCallTransformer(String sourceOpName, int numOperands) {
+    this.sourceOpName = sourceOpName;
+    this.numOperands = numOperands;
+  }
+
+  @Override
+  protected boolean condition(SqlCall sqlCall) {
+    return sourceOpName.equalsIgnoreCase(sqlCall.getOperator().getName())
+        && sqlCall.getOperandList().size() == numOperands;
+  }
+}
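
Because this class only supplies the matching condition (a case-insensitive operator name plus an exact operand count), concrete subclasses still have to implement transform(). A minimal hypothetical subclass is sketched below; the operator name MY_FUNC and the no-op transformation are assumptions for illustration only, and real subclasses such as OperatorRenameSqlCallTransformer above rewrite the matched call instead.

import org.apache.calcite.sql.SqlCall;

import com.linkedin.coral.common.transformers.SourceOperatorMatchSqlCallTransformer;

// Matches calls named MY_FUNC with exactly two operands and returns them unchanged.
public class MyFuncSqlCallTransformer extends SourceOperatorMatchSqlCallTransformer {
  public MyFuncSqlCallTransformer() {
    super("MY_FUNC", 2);
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    return sqlCall; // a real transformer would rewrite the matched call here
  }
}
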

coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java

Lines changed: 15 additions & 3 deletions

@@ -8,11 +8,18 @@
 import java.util.ArrayList;
 import java.util.List;

+import com.google.common.collect.ImmutableList;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlIdentifier;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.SqlSelect;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
 import org.apache.calcite.sql.validate.SqlValidator;


@@ -32,9 +39,9 @@ public SqlCallTransformer(SqlValidator sqlValidator) {
   }

   /**
-   * Predicate of the transformer, it’s used to determine if the SqlCall should be transformed or not
+   * Condition of the transformer, it’s used to determine if the SqlCall should be transformed or not
    */
-  protected abstract boolean predicate(SqlCall sqlCall);
+  protected abstract boolean condition(SqlCall sqlCall);

   /**
    * Implementation of the transformation, returns the transformed SqlCall
@@ -49,7 +56,7 @@ public SqlCall apply(SqlCall sqlCall) {
     if (sqlCall instanceof SqlSelect) {
       this.topSelectNodes.add((SqlSelect) sqlCall);
     }
-    if (predicate(sqlCall)) {
+    if (condition(sqlCall)) {
       return transform(sqlCall);
     } else {
       return sqlCall;
@@ -96,4 +103,9 @@ protected RelDataType getRelDataType(SqlNode sqlNode) {
     }
     throw new RuntimeException("Failed to derive the RelDataType for SqlNode " + sqlNode);
   }
+
+  protected static SqlOperator createSqlOperatorOfFunction(String functionName, SqlReturnTypeInference typeInference) {
+    SqlIdentifier sqlIdentifier = new SqlIdentifier(ImmutableList.of(functionName), SqlParserPos.ZERO);
+    return new SqlUserDefinedFunction(sqlIdentifier, typeInference, null, null, null, null);
+  }
 }
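
For context on the new createSqlOperatorOfFunction helper: it builds a SqlUserDefinedFunction that carries only the target name and a return-type inference, which is enough to render a call under that name. Below is a hedged sketch of a transformer using the helper directly; the matched operator OLD_FUNC, the target name my_udf, and the VARCHAR return type are illustrative assumptions, not mappings added by this commit.

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ReturnTypes;

import com.linkedin.coral.common.transformers.SqlCallTransformer;

// Rewrites any OLD_FUNC(...) call into my_udf(...) with the same operands.
public class MyUdfRewriteTransformer extends SqlCallTransformer {
  @Override
  protected boolean condition(SqlCall sqlCall) {
    return "OLD_FUNC".equalsIgnoreCase(sqlCall.getOperator().getName());
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    // The helper is protected static on SqlCallTransformer, so it is available to subclasses like this one.
    return createSqlOperatorOfFunction("my_udf", ReturnTypes.VARCHAR_2000)
        .createCall(new SqlNodeList(sqlCall.getOperandList(), SqlParserPos.ZERO));
  }
}
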

coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformers.java

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 /**
  * Container for SqlCallTransformer
  */
-public class SqlCallTransformers {
+public final class SqlCallTransformers {
   private final ImmutableList<SqlCallTransformer> sqlCallTransformers;

   SqlCallTransformers(ImmutableList<SqlCallTransformer> sqlCallTransformers) {

coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java

Lines changed: 40 additions & 0 deletions

@@ -6,6 +6,7 @@
 package com.linkedin.coral.transformers;

 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;

@@ -25,6 +26,7 @@
 import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.rex.RexFieldAccess;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
 import org.apache.calcite.sql.JoinConditionType;
 import org.apache.calcite.sql.JoinType;
 import org.apache.calcite.sql.SqlCall;
@@ -43,6 +45,7 @@
 import com.linkedin.coral.com.google.common.collect.ImmutableList;
 import com.linkedin.coral.com.google.common.collect.ImmutableMap;
 import com.linkedin.coral.common.functions.CoralSqlUnnestOperator;
+import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;


 /**
@@ -345,4 +348,41 @@ private SqlNode generateRightChildForSqlJoinWithLateralViews(BiRel e, Result rig

     return SqlStdOperatorTable.AS.createCall(POS, asOperands);
   }
+
+  /**
+   * Override this method to handle the conversion for RelNode `f(x).y.z` where `f` is an UDF, which
+   * returns a struct containing field `y`, `y` is also a struct containing field `z`.
+   *
+   * Calcite will convert this RelNode to a SqlIdentifier directly (check
+   * {@link org.apache.calcite.rel.rel2sql.SqlImplementor.Context#toSql(RexProgram, RexNode)}),
+   * which is not aligned with our expectation since we want to apply transformations on `f(x)` with
+   * {@link com.linkedin.coral.common.transformers.SqlCallTransformer}. Therefore, we override this
+   * method to convert `f(x)` to SqlCall, `.` to {@link com.linkedin.coral.common.functions.FunctionFieldReferenceOperator#DOT}
+   * and `y.z` to SqlIdentifier.
+   */
+  @Override
+  public Context aliasContext(Map<String, RelDataType> aliases, boolean qualified) {
+    return new AliasContext(INSTANCE, aliases, qualified) {
+      @Override
+      public SqlNode toSql(RexProgram program, RexNode rex) {
+        if (rex.getKind() == SqlKind.FIELD_ACCESS) {
+          final List<String> accessNames = new ArrayList<>();
+          RexNode referencedExpr = rex;
+          // Use the loop to get the top-level struct (`f(x)` in the example above),
+          // and store the accessed field names ([`z`, `y`] in the example above, needs to be reversed)
+          while (referencedExpr.getKind() == SqlKind.FIELD_ACCESS) {
+            accessNames.add(((RexFieldAccess) referencedExpr).getField().getName());
+            referencedExpr = ((RexFieldAccess) referencedExpr).getReferenceExpr();
+          }
+          if (referencedExpr.getKind() == SqlKind.OTHER_FUNCTION) {
+            SqlNode functionCall = toSql(program, referencedExpr);
+            Collections.reverse(accessNames);
+            return FunctionFieldReferenceOperator.DOT.createCall(SqlParserPos.ZERO, functionCall,
+                new SqlIdentifier(String.join(".", accessNames), POS));
+          }
+        }
+        return super.toSql(program, rex);
+      }
+    };
+  }
 }
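
To make the effect of this override concrete, here is a hedged sketch of the SqlNode shape it now produces for an expression like f(x).y.z. UPPER stands in for the UDF f, and the field names y and z are illustrative; only the DOT/SqlIdentifier structure mirrors the code above.

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;

public class FieldAccessShapeExample {
  public static void main(String[] args) {
    // Stand-in for the SqlCall that toSql(program, referencedExpr) would produce for `f(x)`.
    SqlCall functionCall = SqlStdOperatorTable.UPPER.createCall(
        SqlParserPos.ZERO, new SqlIdentifier("x", SqlParserPos.ZERO));

    // With the override, `f(x).y.z` becomes DOT(f(x), "y.z"): the function stays a SqlCall that
    // SqlCallTransformers can still match, and the remaining field path collapses into one SqlIdentifier.
    SqlNode converted = FunctionFieldReferenceOperator.DOT.createCall(
        SqlParserPos.ZERO, functionCall, new SqlIdentifier("y.z", SqlParserPos.ZERO));
    System.out.println(converted);
  }
}
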

coral-hive/src/main/java/com/linkedin/coral/transformers/ShiftArrayIndexTransformer.java

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ public ShiftArrayIndexTransformer(SqlValidator sqlValidator) {
   }

   @Override
-  public boolean predicate(SqlCall sqlCall) {
+  public boolean condition(SqlCall sqlCall) {
     if (ITEM_OPERATOR.equalsIgnoreCase(sqlCall.getOperator().getName())) {
       final SqlNode columnNode = sqlCall.getOperandList().get(0);
       return getRelDataType(columnNode) instanceof ArraySqlType;

coral-spark/src/main/java/com/linkedin/coral/spark/BuiltinUDFMap.java

Lines changed: 0 additions & 49 deletions
This file was deleted.

coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java

Lines changed: 16 additions & 10 deletions

@@ -6,6 +6,7 @@
 package com.linkedin.coral.spark;

 import java.util.List;
+import java.util.Set;
 import java.util.stream.Collectors;

 import org.apache.avro.Schema;
@@ -16,6 +17,7 @@
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlSelect;

+import com.linkedin.coral.com.google.common.collect.ImmutableList;
 import com.linkedin.coral.spark.containers.SparkRelInfo;
 import com.linkedin.coral.spark.containers.SparkUDFInfo;
 import com.linkedin.coral.spark.dialect.SparkSqlDialect;
@@ -63,11 +65,11 @@ private CoralSpark(List<String> baseTables, List<SparkUDFInfo> sparkUDFInfoList,
    */
   public static CoralSpark create(RelNode irRelNode) {
     SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
+    Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
     RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
-    String sparkSQL = constructSparkSQL(sparkRelNode);
+    String sparkSQL = constructSparkSQL(sparkRelNode, sparkUDFInfos);
     List<String> baseTables = constructBaseTables(sparkRelNode);
-    List<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfoList();
-    return new CoralSpark(baseTables, sparkUDFInfos, sparkSQL);
+    return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL);
   }

   /**
@@ -86,11 +88,11 @@ public static CoralSpark create(RelNode irRelNode, Schema schema) {

   private static CoralSpark createWithAlias(RelNode irRelNode, List<String> aliases) {
     SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
+    Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
     RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
-    String sparkSQL = constructSparkSQLWithExplicitAlias(sparkRelNode, aliases);
+    String sparkSQL = constructSparkSQLWithExplicitAlias(sparkRelNode, aliases, sparkUDFInfos);
     List<String> baseTables = constructBaseTables(sparkRelNode);
-    List<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfoList();
-    return new CoralSpark(baseTables, sparkUDFInfos, sparkSQL);
+    return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL);
   }

   /**
@@ -105,21 +107,25 @@ private static CoralSpark createWithAlias(RelNode irRelNode, List<String> aliase
    *
    * @param sparkRelNode A Spark compatible RelNode
    *
+   * @param sparkUDFInfos A set of Spark UDF information objects
    * @return SQL String in Spark SQL dialect which is 'completely expanded'
    */
-  private static String constructSparkSQL(RelNode sparkRelNode) {
+  private static String constructSparkSQL(RelNode sparkRelNode, Set<SparkUDFInfo> sparkUDFInfos) {
     CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter();
     SqlNode coralSqlNode = rel2sql.convert(sparkRelNode);
-    SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter());
+    SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter())
+        .accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));
     SqlNode rewrittenSparkSqlNode = sparkSqlNode.accept(new SparkSqlRewriter());
     return rewrittenSparkSqlNode.toSqlString(SparkSqlDialect.INSTANCE).getSql();
   }

-  private static String constructSparkSQLWithExplicitAlias(RelNode sparkRelNode, List<String> aliases) {
+  private static String constructSparkSQLWithExplicitAlias(RelNode sparkRelNode, List<String> aliases,
+      Set<SparkUDFInfo> sparkUDFInfos) {
     CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter();
     // Create temporary objects r and rewritten to make debugging easier
     SqlNode coralSqlNode = rel2sql.convert(sparkRelNode);
-    SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter());
+    SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter())
+        .accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));

     SqlNode rewritten = sparkSqlNode.accept(new SparkSqlRewriter());
     // Use a second pass visit to add explicit alias names,