diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/CoalesceStructUtility.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/CoalesceStructUtility.java
index 981d0f05b..e96bfa406 100644
--- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/CoalesceStructUtility.java
+++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/CoalesceStructUtility.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
+ * Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -22,7 +22,7 @@
 
 /**
  * A utility class to coalesce the {@link RelDataType} of struct between Trino's representation and
- * hive's extract_union UDF's representation on exploded union.
+ * Hive/Spark's extract_union UDF's representation on exploded union.
  *
  */
 public class CoalesceStructUtility {
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
index c92e8d9e8..94b960b63 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java
@@ -15,6 +15,8 @@
 import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
 import com.linkedin.coral.common.transformers.SqlCallTransformers;
 import com.linkedin.coral.spark.containers.SparkUDFInfo;
+import com.linkedin.coral.spark.transformers.CastToNamedStructTransformer;
+import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;
 import com.linkedin.coral.spark.transformers.FallBackToLinkedInHiveUDFTransformer;
 import com.linkedin.coral.spark.transformers.TransportUDFTransformer;
 
@@ -153,7 +155,13 @@ public CoralToSparkSqlCallConverter(Set<SparkUDFInfo> sparkUDFInfos) {
         new OperatorRenameSqlCallTransformer(SqlStdOperatorTable.CARDINALITY, 1, "size"),
 
         // Fall back to the original Hive UDF defined in StaticHiveFunctionRegistry after failing to apply transformers above
-        new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos));
+        new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos),
+
+        // Transform `CAST(ROW: RECORD_TYPE)` to `named_struct`
+        new CastToNamedStructTransformer(),
+
+        // Transform `extract_union` to `coalesce_struct`
+        new ExtractUnionFunctionTransformer(sparkUDFInfos));
   }
 
   @Override
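
Note on the chain above: each SqlCallTransformer contributes a condition() check and a transform() rewrite. The snippet below is a minimal, self-contained sketch of that pattern using hypothetical stand-in types (Call, Transformer, applyAll), not Coral's actual SqlCallTransformers API; it assumes transformers are tried in registration order and applied whenever their condition matches.

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;

// Minimal, self-contained sketch of the condition()/transform() chaining idea used above.
// "Call" and "Transformer" are hypothetical stand-ins, not Coral's SqlCall / SqlCallTransformers types.
public final class TransformerChainSketch {

  static final class Call {
    final String operator;
    final List<String> operands;

    Call(String operator, List<String> operands) {
      this.operator = operator;
      this.operands = operands;
    }
  }

  static final class Transformer {
    final Predicate<Call> condition;      // mirrors SqlCallTransformer#condition
    final Function<Call, Call> transform; // mirrors SqlCallTransformer#transform

    Transformer(Predicate<Call> condition, Function<Call, Call> transform) {
      this.condition = condition;
      this.transform = transform;
    }
  }

  // Assumption: transformers are tried in registration order and applied whenever their condition matches.
  static Call applyAll(List<Transformer> chain, Call call) {
    for (Transformer t : chain) {
      if (t.condition.test(call)) {
        call = t.transform.apply(call);
      }
    }
    return call;
  }

  public static void main(String[] args) {
    Transformer extractUnionToCoalesceStruct = new Transformer(
        c -> "extract_union".equalsIgnoreCase(c.operator),
        c -> new Call("coalesce_struct", c.operands));
    Call rewritten =
        applyAll(Arrays.asList(extractUnionToCoalesceStruct), new Call("extract_union", Arrays.asList("foo")));
    System.out.println(rewritten.operator); // prints: coalesce_struct
  }
}

Running main prints `coalesce_struct`, mirroring the extract_union rewrite registered above.
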
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java
index 8ad4fadfa..a4345eee9 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java
@@ -31,29 +31,19 @@
 import org.apache.calcite.rel.logical.LogicalUnion;
 import org.apache.calcite.rel.logical.LogicalValues;
 import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rel.type.RelRecordType;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexShuttle;
 import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.sql.SqlIdentifier;
 import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.type.ArraySqlType;
-import org.apache.calcite.sql.type.SqlReturnTypeInference;
 import org.apache.calcite.sql.type.SqlTypeName;
-import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
 
 import com.linkedin.coral.com.google.common.collect.ImmutableList;
-import com.linkedin.coral.com.google.common.collect.Lists;
 import com.linkedin.coral.common.functions.GenericProjectFunction;
-import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
-import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;
 import com.linkedin.coral.spark.containers.SparkRelInfo;
 import com.linkedin.coral.spark.containers.SparkUDFInfo;
 import com.linkedin.coral.spark.utils.RelDataTypeToHiveTypeStringConverter;
@@ -200,9 +190,8 @@ public RexNode visitCall(RexCall call) {
       RexCall updatedCall = (RexCall) super.visitCall(call);
 
       RexNode convertToNewNode =
-          convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertToNamedStruct(updatedCall).orElseGet(
-              () -> convertFuzzyUnionGenericProject(updatedCall).orElseGet(() -> swapExtractUnionFunction(updatedCall)
-                  .orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)))));
+          convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertFuzzyUnionGenericProject(updatedCall)
+              .orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)));
       return convertToNewNode;
     }
 
@@ -227,25 +216,6 @@ private Optional<RexNode> convertToZeroBasedArrayIndex(RexCall call) {
       return Optional.empty();
     }
 
-    // Convert CAST(ROW: RECORD_TYPE) to named_struct
-    private Optional<RexNode> convertToNamedStruct(RexCall call) {
-      if (call.getOperator().equals(SqlStdOperatorTable.CAST)) {
-        RexNode operand = call.getOperands().get(0);
-        if (operand instanceof RexCall && ((RexCall) operand).getOperator().equals(SqlStdOperatorTable.ROW)) {
-          RelRecordType recordType = (RelRecordType) call.getType();
-          List<RexNode> rowOperands = ((RexCall) operand).getOperands();
-          List<RexNode> newOperands = new ArrayList<>(recordType.getFieldCount() * 2);
-          for (int i = 0; i < recordType.getFieldCount(); i += 1) {
-            RelDataTypeField dataTypeField = recordType.getFieldList().get(i);
-            newOperands.add(rexBuilder.makeLiteral(dataTypeField.getKey()));
-            newOperands.add(rexBuilder.makeCast(dataTypeField.getType(), rowOperands.get(i)));
-          }
-          return Optional.of(rexBuilder.makeCall(call.getType(), new HiveNamedStructFunction(), newOperands));
-        }
-      }
-      return Optional.empty();
-    }
-
     /**
      * Add the schema to GenericProject in Fuzzy Union
      * @param call a given RexCall
@@ -270,44 +240,6 @@ private Optional<RexNode> convertFuzzyUnionGenericProject(RexCall call) {
       return Optional.empty();
     }
 
-    /**
-     * Instead of leaving extract_union visible to (Hive)Spark, since we adopted the new exploded struct schema(
-     * a.k.a struct_tr) that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
-     * we need to swap the reference of "extract_union" to a new UDF that is coalescing the difference between
-     * struct_tr and struct_ex.
-     *
-     * See com.linkedin.coral.common.functions.FunctionReturnTypes#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY
-     * and its comments for more details.
-     *
-     * @param call the original extract_union function call.
-     * @return A new {@link RexNode} replacing the original extract_union call.
-     */
-    private Optional<RexNode> swapExtractUnionFunction(RexCall call) {
-      if (call.getOperator().getName().equalsIgnoreCase("extract_union")) {
-        // Only when there's a necessity to register coalesce_struct UDF
-        sparkUDFInfos.add(new SparkUDFInfo("com.linkedin.coalescestruct.GenericUDFCoalesceStruct", "coalesce_struct",
-            ImmutableList.of(URI.create("ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+")),
-            SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));
-
-        // one arg case: extract_union(field_name)
-        if (call.getOperands().size() == 1) {
-          return Optional.of(rexBuilder.makeCall(
-              createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
-              call.getOperands()));
-        }
-        // two arg case: extract_union(field_name, ordinal)
-        else if (call.getOperands().size() == 2) {
-          int ordinal = ((RexLiteral) call.getOperands().get(1)).getValueAs(Integer.class) + 1;
-          List<RexNode> operandsCopy = Lists.newArrayList(call.getOperands());
-          operandsCopy.set(1, rexBuilder.makeExactLiteral(new BigDecimal(ordinal)));
-          return Optional.of(rexBuilder.makeCall(
-              createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
-              operandsCopy));
-        }
-      }
-      return Optional.empty();
-    }
-
     /**
      * Calcite entails the nullability of an expression by casting it to the correct nullable type.
      * However, for complex types like ARRAY (element non-nullable, but top-level nullable),
@@ -336,10 +268,5 @@ private Optional<RexNode> removeCastToEnsureCorrectNullability(RexCall call) {
       }
       return Optional.empty();
     }
-
-    private static SqlOperator createUDF(String udfName, SqlReturnTypeInference typeInference) {
-      return new SqlUserDefinedFunction(new SqlIdentifier(ImmutableList.of(udfName), SqlParserPos.ZERO), typeInference,
-          null, null, null, null);
-    }
   }
 }
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/CastToNamedStructTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/CastToNamedStructTransformer.java
new file mode 100644
index 000000000..324b7b3a6
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/CastToNamedStructTransformer.java
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.transformers;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlRowTypeNameSpec;
+import org.apache.calcite.sql.SqlRowTypeSpec;
+import org.apache.calcite.sql.parser.SqlParserPos;
+
+import com.linkedin.coral.common.transformers.SqlCallTransformer;
+import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;
+
+
+/**
+ * This transformer transforms Coral IR function `CAST(ROW: RECORD_TYPE)` to Spark compatible function `named_struct`.
+ * For example, the SqlCall `CAST(ROW(123, 'xyz') AS ROW(`abc` INTEGER, `def` CHAR(3) CHARACTER SET `ISO-8859-1`))`
+ * will be transformed to `named_struct('abc', 123, 'def', 'xyz')`
+ */
+public class CastToNamedStructTransformer extends SqlCallTransformer {
+  @Override
+  protected boolean condition(SqlCall sqlCall) {
+    if (sqlCall.getOperator().getKind() == SqlKind.CAST) {
+      final SqlNode firstOperand = sqlCall.getOperandList().get(0);
+      final SqlNode secondOperand = sqlCall.getOperandList().get(1);
+      return firstOperand instanceof SqlCall && ((SqlCall) firstOperand).getOperator().getKind() == SqlKind.ROW
+          && secondOperand instanceof SqlRowTypeSpec;
+    }
+    return false;
+  }
+
+  @Override
+  protected SqlCall transform(SqlCall sqlCall) {
+    List<SqlNode> newOperands = new ArrayList<>();
+    final SqlCall rowCall = (SqlCall) sqlCall.getOperandList().get(0); // like `ROW(123, 'xyz')` in above example
+    final SqlRowTypeSpec sqlRowTypeSpec = (SqlRowTypeSpec) sqlCall.getOperandList().get(1); // like `ROW(`abc` INTEGER, `def` CHAR(3) CHARACTER SET `ISO-8859-1`)` in above example
+    for (int i = 0; i < rowCall.getOperandList().size(); ++i) {
+      final String fieldName =
+          ((SqlRowTypeNameSpec) sqlRowTypeSpec.getTypeNameSpec()).getFieldNames().get(i).names.get(0);
+      newOperands.add(new SqlIdentifier("'" + fieldName + "'", SqlParserPos.ZERO)); // need to single-quote the field name
+      newOperands.add(rowCall.getOperandList().get(i));
+    }
+    return HiveNamedStructFunction.NAMED_STRUCT.createCall(sqlCall.getParserPosition(), newOperands);
+  }
+}
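
To make the operand layout above concrete: named_struct takes alternating field-name and value arguments, which is what the transform() loop builds. The helper below is an illustrative sketch over plain strings (the namedStructCall method is hypothetical, not the Calcite SqlNode code above) showing that interleaving for the javadoc example.

import java.util.Arrays;
import java.util.List;
import java.util.StringJoiner;

// Illustrative sketch of the alternating 'field name' / value argument layout produced above.
public final class NamedStructLayoutSketch {

  // Hypothetical helper: interleaves single-quoted field names with the ROW operands,
  // mirroring how CastToNamedStructTransformer builds its new operand list.
  static String namedStructCall(List<String> fieldNames, List<String> rowOperands) {
    StringJoiner args = new StringJoiner(", ", "named_struct(", ")");
    for (int i = 0; i < fieldNames.size(); i++) {
      args.add("'" + fieldNames.get(i) + "'"); // field name, single-quoted
      args.add(rowOperands.get(i));            // corresponding ROW operand
    }
    return args.toString();
  }

  public static void main(String[] args) {
    // The example from the class javadoc: CAST(ROW(123, 'xyz') AS ROW(`abc` INTEGER, `def` CHAR(3)))
    System.out.println(namedStructCall(Arrays.asList("abc", "def"), Arrays.asList("123", "'xyz'")));
    // prints: named_struct('abc', 123, 'def', 'xyz')
  }
}
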
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java
new file mode 100644
index 000000000..27d6884b1
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.transformers;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNumericLiteral;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.parser.SqlParserPos;
+
+import com.linkedin.coral.com.google.common.collect.ImmutableList;
+import com.linkedin.coral.common.transformers.SqlCallTransformer;
+import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
+import com.linkedin.coral.spark.containers.SparkUDFInfo;
+
+
+/**
+ * This transformer transforms `extract_union` to `coalesce_struct`.
+ * Instead of leaving `extract_union` visible to Spark, since we adopted the new exploded struct schema (a.k.a struct_tr)
+ * that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
+ * we need to swap the reference of `extract_union` to a new UDF that is coalescing the difference between
+ * struct_tr and struct_ex.
+ * See {@link CoalesceStructUtility#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY} and its comments for more details.
+ *
+ * Check `CoralSparkTest#testUnionExtractUDF` for examples.
+ */
+public class ExtractUnionFunctionTransformer extends SqlCallTransformer {
+  private static final String EXTRACT_UNION = "extract_union";
+  private static final String COALESCE_STRUCT = "coalesce_struct";
+
+  private final Set<SparkUDFInfo> sparkUDFInfos;
+
+  public ExtractUnionFunctionTransformer(Set<SparkUDFInfo> sparkUDFInfos) {
+    this.sparkUDFInfos = sparkUDFInfos;
+  }
+
+  @Override
+  protected boolean condition(SqlCall sqlCall) {
+    return EXTRACT_UNION.equalsIgnoreCase(sqlCall.getOperator().getName());
+  }
+
+  @Override
+  protected SqlCall transform(SqlCall sqlCall) {
+    sparkUDFInfos.add(new SparkUDFInfo("com.linkedin.coalescestruct.GenericUDFCoalesceStruct", "coalesce_struct",
+        ImmutableList.of(URI.create("ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+")),
+        SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));
+    final List<SqlNode> operandList = sqlCall.getOperandList();
+    final SqlOperator coalesceStructFunction =
+        createSqlOperator(COALESCE_STRUCT, CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY);
+    if (operandList.size() == 1) {
+      // one arg case: extract_union(field_name)
+      return coalesceStructFunction.createCall(sqlCall.getParserPosition(), operandList);
+    } else if (operandList.size() == 2) {
+      // two arg case: extract_union(field_name, ordinal)
+      final int newOrdinal = ((SqlNumericLiteral) operandList.get(1)).getValueAs(Integer.class) + 1;
+      return coalesceStructFunction.createCall(sqlCall.getParserPosition(), ImmutableList.of(operandList.get(0),
+          SqlNumericLiteral.createExactNumeric(String.valueOf(newOrdinal), SqlParserPos.ZERO)));
+    } else {
+      return sqlCall;
+    }
+  }
+}
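
The two-argument branch above increments the ordinal by one before handing it to coalesce_struct (extract_union ordinals are 0-based, so this reads like a shift to 1-based indexing). The sketch below restates the one-arg and two-arg rewrites over plain strings; rewrite() is a hypothetical helper for illustration, not part of the PR.

import java.util.Arrays;
import java.util.List;

// Rough sketch of the rewrite performed above: swap the function name and,
// in the two-argument form, bump the extract_union ordinal by one.
public final class ExtractUnionRewriteSketch {

  static String rewrite(List<String> operands) {
    if (operands.size() == 1) {
      return "coalesce_struct(" + operands.get(0) + ")";
    } else if (operands.size() == 2) {
      int newOrdinal = Integer.parseInt(operands.get(1)) + 1; // shift the 0-based ordinal up by one
      return "coalesce_struct(" + operands.get(0) + ", " + newOrdinal + ")";
    }
    return "extract_union(" + String.join(", ", operands) + ")"; // unexpected arity: left unchanged
  }

  public static void main(String[] args) {
    System.out.println(rewrite(Arrays.asList("foo")));      // coalesce_struct(foo)
    System.out.println(rewrite(Arrays.asList("foo", "0"))); // coalesce_struct(foo, 1)
  }
}

For example, `extract_union(foo, 0)` becomes `coalesce_struct(foo, 1)`, while `extract_union(foo)` simply becomes `coalesce_struct(foo)`.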