Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand All @@ -22,7 +22,7 @@

/**
* A utility class to coalesce the {@link RelDataType} of struct between Trino's representation and
* hive's extract_union UDF's representation on exploded union.
* Hive/Spark's extract_union UDF's representation on exploded union.
*
*/
public class CoalesceStructUtility {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
import com.linkedin.coral.common.transformers.SqlCallTransformers;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.transformers.CastToNamedStructTransformer;
import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;
import com.linkedin.coral.spark.transformers.FallBackToLinkedInHiveUDFTransformer;
import com.linkedin.coral.spark.transformers.TransportUDFTransformer;

Expand Down Expand Up @@ -153,7 +155,13 @@ public CoralToSparkSqlCallConverter(Set<SparkUDFInfo> sparkUDFInfos) {
new OperatorRenameSqlCallTransformer(SqlStdOperatorTable.CARDINALITY, 1, "size"),

// Fall back to the original Hive UDF defined in StaticHiveFunctionRegistry after failing to apply transformers above
new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos));
new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos),

// Transform `CAST(ROW: RECORD_TYPE)` to `named_struct`
new CastToNamedStructTransformer(),

// Transform `extract_union` to `coalesce_struct`
new ExtractUnionFunctionTransformer(sparkUDFInfos));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,19 @@
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelRecordType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.com.google.common.collect.Lists;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;
import com.linkedin.coral.spark.containers.SparkRelInfo;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.utils.RelDataTypeToHiveTypeStringConverter;
Expand Down Expand Up @@ -200,9 +190,8 @@ public RexNode visitCall(RexCall call) {
RexCall updatedCall = (RexCall) super.visitCall(call);

RexNode convertToNewNode =
convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertToNamedStruct(updatedCall).orElseGet(
() -> convertFuzzyUnionGenericProject(updatedCall).orElseGet(() -> swapExtractUnionFunction(updatedCall)
.orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)))));
convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertFuzzyUnionGenericProject(updatedCall)
.orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)));

return convertToNewNode;
}
Expand All @@ -227,25 +216,6 @@ private Optional<RexNode> convertToZeroBasedArrayIndex(RexCall call) {
return Optional.empty();
}

// Convert CAST(ROW: RECORD_TYPE) to named_struct
private Optional<RexNode> convertToNamedStruct(RexCall call) {
if (call.getOperator().equals(SqlStdOperatorTable.CAST)) {
RexNode operand = call.getOperands().get(0);
if (operand instanceof RexCall && ((RexCall) operand).getOperator().equals(SqlStdOperatorTable.ROW)) {
RelRecordType recordType = (RelRecordType) call.getType();
List<RexNode> rowOperands = ((RexCall) operand).getOperands();
List<RexNode> newOperands = new ArrayList<>(recordType.getFieldCount() * 2);
for (int i = 0; i < recordType.getFieldCount(); i += 1) {
RelDataTypeField dataTypeField = recordType.getFieldList().get(i);
newOperands.add(rexBuilder.makeLiteral(dataTypeField.getKey()));
newOperands.add(rexBuilder.makeCast(dataTypeField.getType(), rowOperands.get(i)));
}
return Optional.of(rexBuilder.makeCall(call.getType(), new HiveNamedStructFunction(), newOperands));
}
}
return Optional.empty();
}

/**
* Add the schema to GenericProject in Fuzzy Union
* @param call a given RexCall
Expand All @@ -270,44 +240,6 @@ private Optional<RexNode> convertFuzzyUnionGenericProject(RexCall call) {
return Optional.empty();
}

/**
* Instead of leaving extract_union visible to (Hive)Spark, since we adopted the new exploded struct schema(
* a.k.a struct_tr) that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
* we need to swap the reference of "extract_union" to a new UDF that is coalescing the difference between
* struct_tr and struct_ex.
*
* See com.linkedin.coral.common.functions.FunctionReturnTypes#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY
* and its comments for more details.
*
* @param call the original extract_union function call.
* @return A new {@link RexNode} replacing the original extract_union call.
*/
private Optional<RexNode> swapExtractUnionFunction(RexCall call) {
if (call.getOperator().getName().equalsIgnoreCase("extract_union")) {
// Only when there's a necessity to register coalesce_struct UDF
sparkUDFInfos.add(new SparkUDFInfo("com.linkedin.coalescestruct.GenericUDFCoalesceStruct", "coalesce_struct",
ImmutableList.of(URI.create("ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+")),
SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));

// one arg case: extract_union(field_name)
if (call.getOperands().size() == 1) {
return Optional.of(rexBuilder.makeCall(
createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
call.getOperands()));
}
// two arg case: extract_union(field_name, ordinal)
else if (call.getOperands().size() == 2) {
int ordinal = ((RexLiteral) call.getOperands().get(1)).getValueAs(Integer.class) + 1;
List<RexNode> operandsCopy = Lists.newArrayList(call.getOperands());
operandsCopy.set(1, rexBuilder.makeExactLiteral(new BigDecimal(ordinal)));
return Optional.of(rexBuilder.makeCall(
createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
operandsCopy));
}
}
return Optional.empty();
}

/**
* Calcite entails the nullability of an expression by casting it to the correct nullable type.
* However, for complex types like ARRAY<STRING NOT NULL> (element non-nullable, but top-level nullable),
Expand Down Expand Up @@ -336,10 +268,5 @@ private Optional<RexNode> removeCastToEnsureCorrectNullability(RexCall call) {
}
return Optional.empty();
}

private static SqlOperator createUDF(String udfName, SqlReturnTypeInference typeInference) {
return new SqlUserDefinedFunction(new SqlIdentifier(ImmutableList.of(udfName), SqlParserPos.ZERO), typeInference,
null, null, null, null);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.spark.transformers;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlRowTypeNameSpec;
import org.apache.calcite.sql.SqlRowTypeSpec;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;


/**
* This transformer transforms Coral IR function `CAST(ROW: RECORD_TYPE)` to Spark compatible function `named_struct`.
* For example, the SqlCall `CAST(ROW(123, 'xyz') AS ROW(`abc` INTEGER, `def` CHAR(3) CHARACTER SET `ISO-8859-1`))`
* will be transformed to `named_struct('abc', 123, 'def', 'xyz')`
*/
public class CastToNamedStructTransformer extends SqlCallTransformer {
  @Override
  protected boolean condition(SqlCall sqlCall) {
    // Fire only on CAST calls whose source operand is a ROW constructor and whose
    // target operand is a row type spec; anything else is left for other transformers.
    if (sqlCall.getOperator().getKind() != SqlKind.CAST) {
      return false;
    }
    final List<SqlNode> operands = sqlCall.getOperandList();
    final SqlNode source = operands.get(0);
    final SqlNode targetType = operands.get(1);
    if (!(source instanceof SqlCall) || !(targetType instanceof SqlRowTypeSpec)) {
      return false;
    }
    return ((SqlCall) source).getOperator().getKind() == SqlKind.ROW;
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    final SqlCall rowCall = (SqlCall) sqlCall.getOperandList().get(0); // e.g. `ROW(123, 'xyz')`
    final SqlRowTypeSpec rowTypeSpec = (SqlRowTypeSpec) sqlCall.getOperandList().get(1); // the CAST target row type
    final SqlRowTypeNameSpec rowTypeNameSpec = (SqlRowTypeNameSpec) rowTypeSpec.getTypeNameSpec();
    final List<SqlNode> namedStructOperands = new ArrayList<>();
    int fieldIndex = 0;
    for (SqlNode fieldValue : rowCall.getOperandList()) {
      // Field name for position i comes from the target row type's field list.
      final String fieldName = rowTypeNameSpec.getFieldNames().get(fieldIndex).names.get(0);
      // named_struct takes alternating (name, value) operands; the name must be single-quoted.
      namedStructOperands.add(new SqlIdentifier("'" + fieldName + "'", SqlParserPos.ZERO));
      namedStructOperands.add(fieldValue);
      fieldIndex++;
    }
    return HiveNamedStructFunction.NAMED_STRUCT.createCall(sqlCall.getParserPosition(), namedStructOperands);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.spark.transformers;

import java.net.URI;
import java.util.List;
import java.util.Set;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
import com.linkedin.coral.spark.containers.SparkUDFInfo;


/**
* This transformer transforms `extract_union` to `coalesce_struct`.
* Instead of leaving `extract_union` visible to Spark, since we adopted the new exploded struct schema (a.k.a struct_tr)
* that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
* we need to swap the reference of `extract_union` to a new UDF that is coalescing the difference between
* struct_tr and struct_ex.
* See {@link CoalesceStructUtility#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY} and its comments for more details.
*
* Check `CoralSparkTest#testUnionExtractUDF` for examples.
*/
public class ExtractUnionFunctionTransformer extends SqlCallTransformer {
  private static final String EXTRACT_UNION = "extract_union";
  private static final String COALESCE_STRUCT = "coalesce_struct";

  // Shared registry of UDFs the rewritten query will need at runtime; mutated as a side effect.
  private final Set<SparkUDFInfo> sparkUDFInfos;

  public ExtractUnionFunctionTransformer(Set<SparkUDFInfo> sparkUDFInfos) {
    this.sparkUDFInfos = sparkUDFInfos;
  }

  @Override
  protected boolean condition(SqlCall sqlCall) {
    // Case-insensitive match on the operator name, mirroring Hive's function-name handling.
    return EXTRACT_UNION.equalsIgnoreCase(sqlCall.getOperator().getName());
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    // Record that the rewritten plan depends on the coalesce_struct Hive UDF so callers can register it.
    // NOTE: this registration happens before the arity check below, matching the original behavior.
    sparkUDFInfos.add(new SparkUDFInfo("com.linkedin.coalescestruct.GenericUDFCoalesceStruct", "coalesce_struct",
        ImmutableList.of(URI.create("ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+")),
        SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));

    final List<SqlNode> operands = sqlCall.getOperandList();
    final SqlOperator coalesceStruct =
        createSqlOperator(COALESCE_STRUCT, CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY);
    switch (operands.size()) {
      case 1:
        // extract_union(field_name) -> coalesce_struct(field_name)
        return coalesceStruct.createCall(sqlCall.getParserPosition(), operands);
      case 2:
        // extract_union(field_name, ordinal) -> coalesce_struct(field_name, ordinal + 1):
        // extract_union ordinals are zero-based while coalesce_struct's are one-based.
        final int shiftedOrdinal = ((SqlNumericLiteral) operands.get(1)).getValueAs(Integer.class) + 1;
        final SqlNode ordinalLiteral =
            SqlNumericLiteral.createExactNumeric(String.valueOf(shiftedOrdinal), SqlParserPos.ZERO);
        return coalesceStruct.createCall(sqlCall.getParserPosition(),
            ImmutableList.of(operands.get(0), ordinalLiteral));
      default:
        // Unexpected arity: pass the call through unchanged.
        return sqlCall;
    }
  }
}