Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;
import com.linkedin.coral.common.transformers.SqlCallTransformers;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.transformers.CastToNamedStructTransformer;
import com.linkedin.coral.spark.transformers.ExtractUnionFunctionTransformer;
import com.linkedin.coral.spark.transformers.FallBackToLinkedInHiveUDFTransformer;
import com.linkedin.coral.spark.transformers.TransportUDFTransformer;

Expand Down Expand Up @@ -153,7 +155,13 @@ public CoralToSparkSqlCallConverter(Set<SparkUDFInfo> sparkUDFInfos) {
new OperatorRenameSqlCallTransformer(SqlStdOperatorTable.CARDINALITY, 1, "size"),

// Fall back to the original Hive UDF defined in StaticHiveFunctionRegistry after failing to apply transformers above
new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos));
new FallBackToLinkedInHiveUDFTransformer(sparkUDFInfos),

// Transform `CAST(ROW: RECORD_TYPE)` to `named_struct`
new CastToNamedStructTransformer(),

// Transform `extract_union` to `coalesce_struct`
new ExtractUnionFunctionTransformer(sparkUDFInfos));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,19 @@
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelRecordType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.com.google.common.collect.Lists;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;
import com.linkedin.coral.spark.containers.SparkRelInfo;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.utils.RelDataTypeToHiveTypeStringConverter;
Expand Down Expand Up @@ -200,9 +190,8 @@ public RexNode visitCall(RexCall call) {
RexCall updatedCall = (RexCall) super.visitCall(call);

RexNode convertToNewNode =
convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertToNamedStruct(updatedCall).orElseGet(
() -> convertFuzzyUnionGenericProject(updatedCall).orElseGet(() -> swapExtractUnionFunction(updatedCall)
.orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)))));
convertToZeroBasedArrayIndex(updatedCall).orElseGet(() -> convertFuzzyUnionGenericProject(updatedCall)
.orElseGet(() -> removeCastToEnsureCorrectNullability(updatedCall).orElse(updatedCall)));

return convertToNewNode;
}
Expand All @@ -227,25 +216,6 @@ private Optional<RexNode> convertToZeroBasedArrayIndex(RexCall call) {
return Optional.empty();
}

// Convert CAST(ROW: RECORD_TYPE) to named_struct
// For each field of the target record type this emits a (name literal, CAST(value)) pair,
// producing named_struct('f1', CAST(v1), 'f2', CAST(v2), ...) which Spark can evaluate,
// whereas a raw CAST(ROW(...) AS ROW(...)) cannot be unparsed to valid Spark SQL.
private Optional<RexNode> convertToNamedStruct(RexCall call) {
  if (call.getOperator().equals(SqlStdOperatorTable.CAST)) {
    RexNode operand = call.getOperands().get(0);
    // Only rewrite when the CAST input is a ROW constructor; all other CASTs are left untouched
    if (operand instanceof RexCall && ((RexCall) operand).getOperator().equals(SqlStdOperatorTable.ROW)) {
      RelRecordType recordType = (RelRecordType) call.getType();
      List<RexNode> rowOperands = ((RexCall) operand).getOperands();
      // Two operands per field: the field-name literal followed by the casted value
      List<RexNode> newOperands = new ArrayList<>(recordType.getFieldCount() * 2);
      for (int i = 0; i < recordType.getFieldCount(); i += 1) {
        RelDataTypeField dataTypeField = recordType.getFieldList().get(i);
        newOperands.add(rexBuilder.makeLiteral(dataTypeField.getKey()));
        // Re-cast each element so the field keeps the exact type declared by the record type
        newOperands.add(rexBuilder.makeCast(dataTypeField.getType(), rowOperands.get(i)));
      }
      return Optional.of(rexBuilder.makeCall(call.getType(), new HiveNamedStructFunction(), newOperands));
    }
  }
  return Optional.empty();
}

/**
* Add the schema to GenericProject in Fuzzy Union
* @param call a given RexCall
Expand All @@ -270,44 +240,6 @@ private Optional<RexNode> convertFuzzyUnionGenericProject(RexCall call) {
return Optional.empty();
}

/**
 * Instead of leaving extract_union visible to (Hive)Spark, since we adopted the new exploded struct schema(
 * a.k.a struct_tr) that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
 * we need to swap the reference of "extract_union" to a new UDF that is coalescing the difference between
 * struct_tr and struct_ex.
 *
 * See com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY
 * (the return-type strategy actually used below) and its comments for more details.
 *
 * @param call the original extract_union function call.
 * @return A new {@link RexNode} replacing the original extract_union call.
 */
private Optional<RexNode> swapExtractUnionFunction(RexCall call) {
  if (call.getOperator().getName().equalsIgnoreCase("extract_union")) {
    // Only when there's a necessity to register coalesce_struct UDF,
    // i.e. only when an extract_union call is actually present in the plan
    sparkUDFInfos.add(new SparkUDFInfo("com.linkedin.coalescestruct.GenericUDFCoalesceStruct", "coalesce_struct",
        ImmutableList.of(URI.create("ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+")),
        SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));

    // one arg case: extract_union(field_name)
    if (call.getOperands().size() == 1) {
      return Optional.of(rexBuilder.makeCall(
          createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
          call.getOperands()));
    }
    // two arg case: extract_union(field_name, ordinal)
    // The ordinal is shifted by one — presumably coalesce_struct expects a 1-based ordinal
    // while extract_union's is 0-based; see CoalesceStructUtility for the schema mapping.
    else if (call.getOperands().size() == 2) {
      int ordinal = ((RexLiteral) call.getOperands().get(1)).getValueAs(Integer.class) + 1;
      List<RexNode> operandsCopy = Lists.newArrayList(call.getOperands());
      operandsCopy.set(1, rexBuilder.makeExactLiteral(new BigDecimal(ordinal)));
      return Optional.of(rexBuilder.makeCall(
          createUDF("coalesce_struct", CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY),
          operandsCopy));
    }
  }
  return Optional.empty();
}

/**
* Calcite entails the nullability of an expression by casting it to the correct nullable type.
* However, for complex types like ARRAY<STRING NOT NULL> (element non-nullable, but top-level nullable),
Expand Down Expand Up @@ -336,10 +268,5 @@ private Optional<RexNode> removeCastToEnsureCorrectNullability(RexCall call) {
}
return Optional.empty();
}

// Builds a bare user-defined SqlOperator with only a name and a return-type inference strategy;
// operand type inference/checking and parameter types are left null — the call is constructed
// directly by this shuttle rather than going through validation.
private static SqlOperator createUDF(String udfName, SqlReturnTypeInference typeInference) {
  return new SqlUserDefinedFunction(new SqlIdentifier(ImmutableList.of(udfName), SqlParserPos.ZERO), typeInference,
      null, null, null, null);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.spark.transformers;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlRowTypeNameSpec;
import org.apache.calcite.sql.SqlRowTypeSpec;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;


/**
* This transformer transforms `CAST(ROW: RECORD_TYPE)` to `named_struct` in Spark.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: transforms Coral IR function ..... to Spark compatible operator named_struct....

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it's suitable to call it Coral IR function, I think it's a function in CoralSqlNode converted from Coral IR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the path:
source SQL -> sourceSqlNode1 -> CoralSqlNode1 -> CoralRelNode -> CoralSqlNode2 -> TargetLangSqlNode -> target SQL

if an operator is present in CoralSqlNode1, CoralRelNode, and CoralSqlNode2, and we transform that operator when going to TargetLangSqlNode, then such an operator can be called a Coral IR function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as per the current status of our code base -
hive SQL -> HiveSqlNode (same as CoralSqlNode1) -> CoralRelNode -> CoralSqlNode2 -> SparkSqlNode -> Spark SQL
I think this operator qualifies as Coral Operator

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification — do you mean an operator is a Coral IR function if it appears in any of CoralSqlNode1 / CoralRelNode / CoralSqlNode2, or only if it appears in all of CoralSqlNode1 & CoralRelNode & CoralSqlNode2?
I'm not sure, because this one doesn't appear in CoralSqlNode1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for a corresponding input hive SQL, what is the operator we use in the SqlNode representation? how do we represent the named_struct on LHS?

but generally yes, an operator present in CoralSqlNode1 would be present in CoralRelNode and then generated back in CoralSqlNode2.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still represented as named_struct in CoralSqlNode1, and Calcite converts it to CAST(ROW: RECORD_TYPE) in RelNode.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. In that case, could you investigate whether it's possible to add a lightweight override in CoralRelNodeToCoralSqlNode such that the generated coralSqlNode2 uses the same operator named_struct as coralSqlNode1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #357 for tracking.

* For example, the SqlCall `CAST(ROW(123, 'xyz') AS ROW(`abc` INTEGER, `def` CHAR(3) CHARACTER SET `ISO-8859-1`))`
* will be transformed to `named_struct('abc', 123, 'def', 'xyz')`
*/
/**
 * Transforms a {@code CAST(ROW(...) AS ROW(...))} SqlCall into a call to Spark's
 * {@code named_struct} function.
 *
 * For example, {@code CAST(ROW(123, 'xyz') AS ROW(`abc` INTEGER, `def` CHAR(3) ...))}
 * becomes {@code named_struct('abc', 123, 'def', 'xyz')}.
 */
public class CastToNamedStructTransformer extends SqlCallTransformer {
  @Override
  protected boolean condition(SqlCall sqlCall) {
    // Match only CAST calls whose source is a ROW constructor and whose target is a row type spec
    if (sqlCall.getOperator().getKind() != SqlKind.CAST) {
      return false;
    }
    final SqlNode castSource = sqlCall.getOperandList().get(0);
    final SqlNode castTargetType = sqlCall.getOperandList().get(1);
    if (!(castSource instanceof SqlCall) || !(castTargetType instanceof SqlRowTypeSpec)) {
      return false;
    }
    return ((SqlCall) castSource).getOperator().getKind() == SqlKind.ROW;
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    // e.g. `ROW(123, 'xyz')` — carries the field values
    final SqlCall rowConstructor = (SqlCall) sqlCall.getOperandList().get(0);
    // e.g. `ROW(`abc` INTEGER, `def` CHAR(3) ...)` — carries the field names
    final SqlRowTypeSpec rowTypeSpec = (SqlRowTypeSpec) sqlCall.getOperandList().get(1);
    final SqlRowTypeNameSpec rowTypeNameSpec = (SqlRowTypeNameSpec) rowTypeSpec.getTypeNameSpec();

    final List<SqlNode> fieldValues = rowConstructor.getOperandList();
    final List<SqlNode> namedStructOperands = new ArrayList<>();
    for (int idx = 0; idx < fieldValues.size(); idx++) {
      final String fieldName = rowTypeNameSpec.getFieldNames().get(idx).names.get(0);
      // Single-quote the field name so it unparses as a string-literal argument to named_struct
      namedStructOperands.add(new SqlIdentifier("'" + fieldName + "'", SqlParserPos.ZERO));
      namedStructOperands.add(fieldValues.get(idx));
    }
    return HiveNamedStructFunction.NAMED_STRUCT.createCall(sqlCall.getParserPosition(), namedStructOperands);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.spark.transformers;

import java.net.URI;
import java.util.List;
import java.util.Set;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.hive.hive2rel.functions.CoalesceStructUtility;
import com.linkedin.coral.spark.containers.SparkUDFInfo;


/**
 * This transformer transforms `extract_union` to `coalesce_struct`.
 * Instead of leaving `extract_union` visible to Spark, since we adopted the new exploded struct schema (a.k.a struct_tr)
 * that is different from extract_union's output (a.k.a struct_ex) to interpret union in Coral IR,
 * we need to swap the reference of `extract_union` to a new UDF that is coalescing the difference between
 * struct_tr and struct_ex.
 * See {@link CoalesceStructUtility#COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY} and its comments for more details.
 *
 * Check `CoralSparkTest#testUnionExtractUDF` for examples.
 */
public class ExtractUnionFunctionTransformer extends SqlCallTransformer {
  private static final String EXTRACT_UNION = "extract_union";
  private static final String COALESCE_STRUCT = "coalesce_struct";
  private static final String COALESCE_STRUCT_UDF_CLASS = "com.linkedin.coalescestruct.GenericUDFCoalesceStruct";
  private static final String COALESCE_STRUCT_IVY_COORDINATE =
      "ivy://com.linkedin.coalesce-struct:coalesce-struct-impl:+";

  // Shared registry of UDFs the rewritten query depends on; owned by the caller
  private final Set<SparkUDFInfo> sparkUDFInfos;

  public ExtractUnionFunctionTransformer(Set<SparkUDFInfo> sparkUDFInfos) {
    this.sparkUDFInfos = sparkUDFInfos;
  }

  @Override
  protected boolean condition(SqlCall sqlCall) {
    return EXTRACT_UNION.equalsIgnoreCase(sqlCall.getOperator().getName());
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    final List<SqlNode> operandList = sqlCall.getOperandList();
    if (operandList.size() != 1 && operandList.size() != 2) {
      // Unexpected arity: leave the call untouched and, importantly, do NOT register the
      // coalesce_struct UDF — the output SQL won't reference it in this case.
      return sqlCall;
    }

    // Register the coalesce_struct UDF only when the call is actually rewritten
    sparkUDFInfos.add(new SparkUDFInfo(COALESCE_STRUCT_UDF_CLASS, COALESCE_STRUCT,
        ImmutableList.of(URI.create(COALESCE_STRUCT_IVY_COORDINATE)), SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF));

    final SqlOperator coalesceStructFunction =
        createSqlOperator(COALESCE_STRUCT, CoalesceStructUtility.COALESCE_STRUCT_FUNCTION_RETURN_STRATEGY);

    if (operandList.size() == 1) {
      // one arg case: extract_union(field_name)
      return coalesceStructFunction.createCall(sqlCall.getParserPosition(), operandList);
    }
    // two arg case: extract_union(field_name, ordinal)
    // The ordinal is shifted by one — presumably coalesce_struct expects a 1-based ordinal
    // while extract_union's is 0-based; see CoalesceStructUtility for the schema mapping.
    final int newOrdinal = ((SqlNumericLiteral) operandList.get(1)).getValueAs(Integer.class) + 1;
    return coalesceStructFunction.createCall(sqlCall.getParserPosition(), ImmutableList.of(operandList.get(0),
        SqlNumericLiteral.createExactNumeric(String.valueOf(newOrdinal), SqlParserPos.ZERO)));
  }
}