Skip to content

Commit ef058ef

Browse files
authored
Coral-Spark: Migrate 'CAST(ROW: RECORD_TYPE)' and 'extract_union' transformations from RelNode to SqlNode layer (#354)
1 parent cdbdd4e commit ef058ef

File tree

5 files changed

+135
-78
lines changed

5 files changed

+135
-78
lines changed

coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/CoalesceStructUtility.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
2+
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
33
* Licensed under the BSD-2 Clause license.
44
* See LICENSE in the project root for license information.
55
*/
@@ -22,7 +22,7 @@
2222

2323
/**
2424
* A utility class to coalesce the {@link RelDataType} of struct between Trino's representation and
25-
* hive's extract_union UDF's representation on exploded union.
25+
* Hive/Spark's extract_union UDF's representation on exploded union.
2626
*
2727
*/
2828
public class CoalesceStructUtility {

coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
1616
import com.linkedin.coral.common.transformers.SqlCallTransformers;
1717
import com.linkedin.coral.spark.containers.SparkUDFInfo;
18+
import com.linkedin.coral.spark.transformers.CastToNamedStructTransformer;
19+
import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;
1820
import com.linkedin.coral.spark.transformers.FallBackToLinkedInHiveUDFTransformer;
1921
import com.linkedin.coral.spark.transformers.TransportUDFTransformer;
2022

@@ -153,7 +155,13 @@ public CoralToSparkSqlCallConverter(Set<SparkUDFInfo> sparkUDFInfos) {
153155
new OperatorRenameSqlCallTransformer(SqlStdOperatorTable.CARDINALITY, 1, "size"),
154156

155157
// Fall back to the original Hive UDF defined in StaticHiveFunctionRegistry after failing to apply transformers above
156-
new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos));
158+
new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos),
159+
160+
// Transform `CAST(ROW: RECORD_TYPE)` to `named_struct`
161+
new CastToNamedStructTransformer(),
162+
163+
// Transform `extract_union` to `coalesce_struct`
164+
new ExtractUnionFunctionTransformer(sparkUDFInfos));
157165
}
158166

159167
@Override

coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java

Lines changed: 2 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -31,29 +31,19 @@
3131
import org.apache.calcite.rel.logical.LogicalUnion;
3232
import org.apache.calcite.rel.logical.LogicalValues;
3333
import org.apache.calcite.rel.type.RelDataType;
34-
import org.apache.calcite.rel.type.RelDataTypeField;
35-
import org.apache.calcite.rel.type.RelRecordType;
3634
import org.apache.calcite.rex.RexBuilder;
3735
import org.apache.calcite.rex.RexCall;
3836
import org.apache.calcite.rex.RexLiteral;
3937
import org.apache.calcite.rex.RexNode;
4038
import org.apache.calcite.rex.RexShuttle;
4139
import org.apache.calcite.rex.RexUtil;
42-
import org.apache.calcite.sql.SqlIdentifier;
4340
import org.apache.calcite.sql.SqlKind;
44-
import org.apache.calcite.sql.SqlOperator;
4541
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
46-
import org.apache.calcite.sql.parser.SqlParserPos;
4742
import org.apache.calcite.sql.type.ArraySqlType;
48-
import org.apache.calcite.sql.type.SqlReturnTypeInference;
4943
import org.apache.calcite.sql.type.SqlTypeName;
50-
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
5144

5245
import com.linkedin.coral.com.google.common.collect.ImmutableList;
53-
import com.linkedin.coral.com.google.common.collect.Lists;
5446
import com.linkedin.coral.common.functions.GenericProjectFunction;
55-
import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
56-
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;
5747
import com.linkedin.coral.spark.containers.SparkRelInfo;
5848
import com.linkedin.coral.spark.containers.SparkUDFInfo;
5949
import com.linkedin.coral.spark.utils.RelDataTypeToHiveTypeStringConverter;
@@ -200,9 +190,8 @@ public RexNode visitCall(RexCall call) {
200190
RexCall updatedCall = (RexCall) super.visitCall(call);
201191

202192
RexNode convertToNewNode =
203-
convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertToNamedStruct(updatedCall).orElseGet(
204-
() -> convertFuzzyUnionGenericProject(updatedCall).orElseGet(() -> swapExtractUnionFunction(updatedCall)
205-
.orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)))));
193+
convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertFuzzyUnionGenericProject(updatedCall)
194+
.orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)));
206195

207196
return convertToNewNode;
208197
}
@@ -227,25 +216,6 @@ private Optional<RexNode> convertToZeroBasedArrayIndex(RexCall call) {
227216
return Optional.empty();
228217
}
229218

230-
// Convert CAST(ROW: RECORD_TYPE) to named_struct
231-
private Optional<RexNode> convertToNamedStruct(RexCall call) {
232-
if (call.getOperator().equals(SqlStdOperatorTable.CAST)) {
233-
RexNode operand = call.getOperands().get(0);
234-
if (operand instanceof RexCall && ((RexCall) operand).getOperator().equals(SqlStdOperatorTable.ROW)) {
235-
RelRecordType recordType = (RelRecordType) call.getType();
236-
List<RexNode> rowOperands = ((RexCall) operand).getOperands();
237-
List<RexNode> newOperands = new ArrayList<>(recordType.getFieldCount() * 2);
238-
for (int i = 0; i < recordType.getFieldCount(); i += 1) {
239-
RelDataTypeField dataTypeField = recordType.getFieldList().get(i);
240-
newOperands.add(rexBuilder.makeLiteral(dataTypeField.getKey()));
241-
newOperands.add(rexBuilder.makeCast(dataTypeField.getType(), rowOperands.get(i)));
242-
}
243-
return Optional.of(rexBuilder.makeCall(call.getType(), new HiveNamedStructFunction(), newOperands));
244-
}
245-
}
246-
return Optional.empty();
247-
}
248-
249219
/**
250220
* Add the schema to GenericProject in Fuzzy Union
251221
* @param call a given RexCall
@@ -270,44 +240,6 @@ private Optional<RexNode> convertFuzzyUnionGenericProject(RexCall call) {
270240
return Optional.empty();
271241
}
272242

273-
/**
274-
* Instead of leaving extract_union visible to (Hive)Spark, since we adopted the new exploded struct schema(
275-
* a.k.a struct_tr) that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
276-
* we need to swap the reference of "extract_union" to a new UDF that is coalescing the difference between
277-
* struct_tr and struct_ex.
278-
*
279-
* See com.linkedin.coral.common.functions.FunctionReturnTypes#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY
280-
* and its comments for more details.
281-
*
282-
* @param call the original extract_union function call.
283-
* @return A new {@link RexNode} replacing the original extract_union call.
284-
*/
285-
private Optional<RexNode> swapExtractUnionFunction(RexCall call) {
286-
if (call.getOperator().getName().equalsIgnoreCase("extract_union")) {
287-
// Only when there's a necessity to register coalesce_struct UDF
288-
sparkUDFInfos.add(new SparkUDFInfo("com.linkedin.coalescestruct.GenericUDFCoalesceStruct", "coalesce_struct",
289-
ImmutableList.of(URI.create("ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+")),
290-
SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));
291-
292-
// one arg case: extract_union(field_name)
293-
if (call.getOperands().size() == 1) {
294-
return Optional.of(rexBuilder.makeCall(
295-
createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
296-
call.getOperands()));
297-
}
298-
// two arg case: extract_union(field_name, ordinal)
299-
else if (call.getOperands().size() == 2) {
300-
int ordinal = ((RexLiteral) call.getOperands().get(1)).getValueAs(Integer.class) + 1;
301-
List<RexNode> operandsCopy = Lists.newArrayList(call.getOperands());
302-
operandsCopy.set(1, rexBuilder.makeExactLiteral(new BigDecimal(ordinal)));
303-
return Optional.of(rexBuilder.makeCall(
304-
createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
305-
operandsCopy));
306-
}
307-
}
308-
return Optional.empty();
309-
}
310-
311243
/**
312244
* Calcite entails the nullability of an expression by casting it to the correct nullable type.
313245
* However, for complex types like ARRAY<STRING NOT NULL> (element non-nullable, but top-level nullable),
@@ -336,10 +268,5 @@ private Optional<RexNode> removeCastToEnsureCorrectNullability(RexCall call) {
336268
}
337269
return Optional.empty();
338270
}
339-
340-
private static SqlOperator createUDF(String udfName, SqlReturnTypeInference typeInference) {
341-
return new SqlUserDefinedFunction(new SqlIdentifier(ImmutableList.of(udfName), SqlParserPos.ZERO), typeInference,
342-
null, null, null, null);
343-
}
344271
}
345272
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/**
2+
* Copyright 2023 LinkedIn Corporation. All rights reserved.
3+
* Licensed under the BSD-2 Clause license.
4+
* See LICENSE in the project root for license information.
5+
*/
6+
package com.linkedin.coral.spark.transformers;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
11+
import org.apache.calcite.sql.SqlCall;
12+
import org.apache.calcite.sql.SqlIdentifier;
13+
import org.apache.calcite.sql.SqlKind;
14+
import org.apache.calcite.sql.SqlNode;
15+
import org.apache.calcite.sql.SqlRowTypeNameSpec;
16+
import org.apache.calcite.sql.SqlRowTypeSpec;
17+
import org.apache.calcite.sql.parser.SqlParserPos;
18+
19+
import com.linkedin.coral.common.transformers.SqlCallTransformer;
20+
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;
21+
22+
23+
/**
24+
* This transformer transforms Coral IR function `CAST(ROW: RECORD_TYPE)` to Spark compatible function `named_struct`.
25+
* For example, the SqlCall `CAST(ROW(123, 'xyz') AS ROW(`abc` INTEGER, `def` CHAR(3) CHARACTER SET `ISO-8859-1`))`
26+
* will be transformed to `named_struct('abc', 123, 'def', 'xyz')`
27+
*/
28+
public class CastToNamedStructTransformer extends SqlCallTransformer {
29+
@Override
30+
protected boolean condition(SqlCall sqlCall) {
31+
if (sqlCall.getOperator().getKind() == SqlKind.CAST) {
32+
final SqlNode firstOperand = sqlCall.getOperandList().get(0);
33+
final SqlNode secondOperand = sqlCall.getOperandList().get(1);
34+
return firstOperand instanceof SqlCall && ((SqlCall) firstOperand).getOperator().getKind() == SqlKind.ROW
35+
&& secondOperand instanceof SqlRowTypeSpec;
36+
}
37+
return false;
38+
}
39+
40+
@Override
41+
protected SqlCall transform(SqlCall sqlCall) {
42+
List<SqlNode> newOperands = new ArrayList<>();
43+
final SqlCall rowCall = (SqlCall) sqlCall.getOperandList().get(0); // like `ROW(123, 'xyz')` in above example
44+
final SqlRowTypeSpec sqlRowTypeSpec = (SqlRowTypeSpec) sqlCall.getOperandList().get(1); // like `ROW(`abc` INTEGER, `def` CHAR(3) CHARACTER SET `ISO-8859-1`))` in above example
45+
for (int i = 0; i < rowCall.getOperandList().size(); ++i) {
46+
final String fieldName =
47+
((SqlRowTypeNameSpec) sqlRowTypeSpec.getTypeNameSpec()).getFieldNames().get(i).names.get(0);
48+
newOperands.add(new SqlIdentifier("'" + fieldName + "'", SqlParserPos.ZERO)); // need to single-quote the field name
49+
newOperands.add(rowCall.getOperandList().get(i));
50+
}
51+
return HiveNamedStructFunction.NAMED_STRUCT.createCall(sqlCall.getParserPosition(), newOperands);
52+
}
53+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/**
2+
* Copyright 2023 LinkedIn Corporation. All rights reserved.
3+
* Licensed under the BSD-2 Clause license.
4+
* See LICENSE in the project root for license information.
5+
*/
6+
package com.linkedin.coral.spark.transformers;
7+
8+
import java.net.URI;
9+
import java.util.List;
10+
import java.util.Set;
11+
12+
import org.apache.calcite.sql.SqlCall;
13+
import org.apache.calcite.sql.SqlNode;
14+
import org.apache.calcite.sql.SqlNumericLiteral;
15+
import org.apache.calcite.sql.SqlOperator;
16+
import org.apache.calcite.sql.parser.SqlParserPos;
17+
18+
import com.linkedin.coral.com.google.common.collect.ImmutableList;
19+
import com.linkedin.coral.common.transformers.SqlCallTransformer;
20+
import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
21+
import com.linkedin.coral.spark.containers.SparkUDFInfo;
22+
23+
24+
/**
25+
* This transformer transforms `extract_union` to `coalesce_struct`.
26+
* Instead of leaving `extract_union` visible to Spark, since we adopted the new exploded struct schema (a.k.a struct_tr)
27+
* that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
28+
* we need to swap the reference of `extract_union` to a new UDF that is coalescing the difference between
29+
* struct_tr and struct_ex.
30+
* See {@link CoalesceStructUtility#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY} and its comments for more details.
31+
*
32+
* Check `CoralSparkTest#testUnionExtractUDF` for examples.
33+
*/
34+
public class ExtractUnionFunctionTransformer extends SqlCallTransformer {
35+
private static final String EXTRACT_UNION = "extract_union";
36+
private static final String COALESCE_STRUCT = "coalesce_struct";
37+
38+
private final Set<SparkUDFInfo> sparkUDFInfos;
39+
40+
public ExtractUnionFunctionTransformer(Set<SparkUDFInfo> sparkUDFInfos) {
41+
this.sparkUDFInfos = sparkUDFInfos;
42+
}
43+
44+
@Override
45+
protected boolean condition(SqlCall sqlCall) {
46+
return EXTRACT_UNION.equalsIgnoreCase(sqlCall.getOperator().getName());
47+
}
48+
49+
@Override
50+
protected SqlCall transform(SqlCall sqlCall) {
51+
sparkUDFInfos.add(new SparkUDFInfo("com.linkedin.coalescestruct.GenericUDFCoalesceStruct", "coalesce_struct",
52+
ImmutableList.of(URI.create("ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+")),
53+
SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));
54+
final List<SqlNode> operandList = sqlCall.getOperandList();
55+
final SqlOperator coalesceStructFunction =
56+
createSqlOperator(COALESCE_STRUCT, CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY);
57+
if (operandList.size() == 1) {
58+
// one arg case: extract_union(field_name)
59+
return coalesceStructFunction.createCall(sqlCall.getParserPosition(), operandList);
60+
} else if (operandList.size() == 2) {
61+
// two arg case: extract_union(field_name, ordinal)
62+
final int newOrdinal = ((SqlNumericLiteral) operandList.get(1)).getValueAs(Integer.class) + 1;
63+
return coalesceStructFunction.createCall(sqlCall.getParserPosition(), ImmutableList.of(operandList.get(0),
64+
SqlNumericLiteral.createExactNumeric(String.valueOf(newOrdinal), SqlParserPos.ZERO)));
65+
} else {
66+
return sqlCall;
67+
}
68+
}
69+
}

0 commit comments

Comments
 (0)