Coral-Spark: Migrate some operator transformers from RelNode layer to SqlNode layer
ljfgem committed Feb 6, 2023
1 parent 7c23b8d commit 394979a
Showing 21 changed files with 591 additions and 530 deletions.
@@ -0,0 +1,29 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.common.transformers;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.parser.SqlParserPos;


/**
* This class is a subclass of {@link SourceOperatorMatchSqlCallTransformer} which renames the source operator
*/
public class OperatorRenameSqlCallTransformer extends SourceOperatorMatchSqlCallTransformer {
private final String targetOpName;

public OperatorRenameSqlCallTransformer(String sourceOpName, int numOperands, String targetOpName) {
super(sourceOpName, numOperands);
this.targetOpName = targetOpName;
}

@Override
protected SqlCall transform(SqlCall sqlCall) {
return createSqlOperatorOfFunction(targetOpName, sqlCall.getOperator().getReturnTypeInference())
.createCall(new SqlNodeList(sqlCall.getOperandList(), SqlParserPos.ZERO));
}
}
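For illustration, a minimal usage sketch of the new transformer (the operator names below are hypothetical and not part of this commit): it matches a two-operand call by name and rewrites it to a call of the target operator, reusing the source operator's return-type inference.

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.transformers.OperatorRenameSqlCallTransformer;

public class OperatorRenameExample {
  public static void main(String[] args) {
    // Build a two-operand call to the standard "||" operator, used purely as a stand-in source call.
    SqlCall sourceCall = SqlStdOperatorTable.CONCAT.createCall(SqlParserPos.ZERO,
        SqlLiteral.createCharString("a", SqlParserPos.ZERO),
        SqlLiteral.createCharString("b", SqlParserPos.ZERO));

    // Hypothetical rename: match "||" with 2 operands and rewrite it to a "concat" call.
    SqlCall renamed = new OperatorRenameSqlCallTransformer("||", 2, "concat").apply(sourceCall);
    System.out.println(renamed); // prints the rewritten call, now using the "concat" operator
  }
}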
@@ -0,0 +1,30 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.common.transformers;

import org.apache.calcite.sql.SqlCall;


/**
* This class is a subclass of {@link SqlCallTransformer} which transforms a function operator on SqlNode layer
* if the signature of the operator to be transformed, including both the name and the number of operands,
* matches the target values in the condition function.
*/
public abstract class SourceOperatorMatchSqlCallTransformer extends SqlCallTransformer {
protected final String sourceOpName;
protected final int numOperands;

public SourceOperatorMatchSqlCallTransformer(String sourceOpName, int numOperands) {
this.sourceOpName = sourceOpName;
this.numOperands = numOperands;
}

@Override
protected boolean condition(SqlCall sqlCall) {
return sourceOpName.equalsIgnoreCase(sqlCall.getOperator().getName())
&& sqlCall.getOperandList().size() == numOperands;
}
}
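As a hedged sketch of how this matcher is meant to be extended (the function name and the operand swap are purely illustrative, not part of this commit), a subclass only supplies the signature to match and the rewrite itself:

import java.util.List;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.transformers.SourceOperatorMatchSqlCallTransformer;

/** Illustrative only: matches a hypothetical two-operand "locate_in" call and swaps its operands. */
public class OperandSwapTransformer extends SourceOperatorMatchSqlCallTransformer {
  public OperandSwapTransformer() {
    // condition() from the parent matches on case-insensitive operator name and operand count.
    super("locate_in", 2);
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    List<SqlNode> operands = sqlCall.getOperandList();
    // Keep the same operator but emit the operands in reverse order.
    return sqlCall.getOperator().createCall(SqlParserPos.ZERO, operands.get(1), operands.get(0));
  }
}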
@@ -8,11 +8,18 @@
import java.util.ArrayList;
import java.util.List;

import com.google.common.collect.ImmutableList;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import org.apache.calcite.sql.validate.SqlValidator;


@@ -32,9 +39,9 @@ public SqlCallTransformer(SqlValidator sqlValidator) {
}

/**
- * Predicate of the transformer, it’s used to determine if the SqlCall should be transformed or not
+ * Condition of the transformer, it’s used to determine if the SqlCall should be transformed or not
 */
- protected abstract boolean predicate(SqlCall sqlCall);
+ protected abstract boolean condition(SqlCall sqlCall);

/**
* Implementation of the transformation, returns the transformed SqlCall
@@ -49,7 +56,7 @@ public SqlCall apply(SqlCall sqlCall) {
if (sqlCall instanceof SqlSelect) {
this.topSelectNodes.add((SqlSelect) sqlCall);
}
- if (predicate(sqlCall)) {
+ if (condition(sqlCall)) {
return transform(sqlCall);
} else {
return sqlCall;
@@ -96,4 +103,9 @@ protected RelDataType getRelDataType(SqlNode sqlNode) {
}
throw new RuntimeException("Failed to derive the RelDataType for SqlNode " + sqlNode);
}

protected static SqlOperator createSqlOperatorOfFunction(String functionName, SqlReturnTypeInference typeInference) {
SqlIdentifier sqlIdentifier = new SqlIdentifier(ImmutableList.of(functionName), SqlParserPos.ZERO);
return new SqlUserDefinedFunction(sqlIdentifier, typeInference, null, null, null, null);
}
}
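A hedged sketch of how the renamed condition/transform pair and the new createSqlOperatorOfFunction helper fit together in a custom transformer (the "to_date"/"safe_to_date" names are hypothetical; transformers that need type derivation would instead use the validator-based constructor shown above):

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.transformers.SqlCallTransformer;

/** Illustrative only: rewraps a hypothetical one-operand "to_date" call as "safe_to_date". */
public class SafeToDateTransformer extends SqlCallTransformer {
  @Override
  protected boolean condition(SqlCall sqlCall) {
    // Decide whether this call should be rewritten at all.
    return "to_date".equalsIgnoreCase(sqlCall.getOperator().getName())
        && sqlCall.getOperandList().size() == 1;
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    // Build a new operator with the same return-type inference and re-create the call.
    return createSqlOperatorOfFunction("safe_to_date", sqlCall.getOperator().getReturnTypeInference())
        .createCall(new SqlNodeList(sqlCall.getOperandList(), SqlParserPos.ZERO));
  }
}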
@@ -15,7 +15,7 @@
/**
* Container for SqlCallTransformer
*/
- public class SqlCallTransformers {
+ public final class SqlCallTransformers {
private final ImmutableList<SqlCallTransformer> sqlCallTransformers;

SqlCallTransformers(ImmutableList<SqlCallTransformer> sqlCallTransformers) {
@@ -6,6 +6,7 @@
package com.linkedin.coral.transformers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

@@ -25,6 +26,7 @@
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlCall;
@@ -43,6 +45,7 @@
import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.com.google.common.collect.ImmutableMap;
import com.linkedin.coral.common.functions.CoralSqlUnnestOperator;
import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;


/**
@@ -345,4 +348,41 @@ private SqlNode generateRightChildForSqlJoinWithLateralViews(BiRel e, Result rig

return SqlStdOperatorTable.AS.createCall(POS, asOperands);
}

/**
* Override this method to handle the conversion for RelNode `f(x).y.z` where `f` is an UDF, which
* returns a struct containing field `y`, `y` is also a struct containing field `z`.
*
* Calcite will convert this RelNode to a SqlIdentifier directly (check
* {@link org.apache.calcite.rel.rel2sql.SqlImplementor.Context#toSql(RexProgram, RexNode)}),
* which is not aligned with our expectation since we want to apply transformations on `f(x)` with
* {@link com.linkedin.coral.common.transformers.SqlCallTransformer}. Therefore, we override this
* method to convert `f(x)` to SqlCall, `.` to {@link com.linkedin.coral.common.functions.FunctionFieldReferenceOperator#DOT}
* and `y.z` to SqlIdentifier.
*/
@Override
public Context aliasContext(Map<String, RelDataType> aliases, boolean qualified) {
return new AliasContext(INSTANCE, aliases, qualified) {
@Override
public SqlNode toSql(RexProgram program, RexNode rex) {
if (rex.getKind() == SqlKind.FIELD_ACCESS) {
final List<String> accessNames = new ArrayList<>();
RexNode referencedExpr = rex;
// Use the loop to get the top-level struct (`f(x)` in the example above),
// and store the accessed field names ([`z`, `y`] in the example above, needs to be reversed)
while (referencedExpr.getKind() == SqlKind.FIELD_ACCESS) {
accessNames.add(((RexFieldAccess) referencedExpr).getField().getName());
referencedExpr = ((RexFieldAccess) referencedExpr).getReferenceExpr();
}
if (referencedExpr.getKind() == SqlKind.OTHER_FUNCTION) {
SqlNode functionCall = toSql(program, referencedExpr);
Collections.reverse(accessNames);
return FunctionFieldReferenceOperator.DOT.createCall(SqlParserPos.ZERO, functionCall,
new SqlIdentifier(String.join(".", accessNames), POS));
}
}
return super.toSql(program, rex);
}
};
}
}
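For reference, a hedged sketch of the SqlNode shape the overridden toSql aims to produce for an expression like `f(x).y.z`, built manually here with a standard operator standing in for the UDF `f` (none of these names come from this commit):

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;

import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;

public class FieldAccessShapeExample {
  public static void main(String[] args) {
    // Stand-in for the UDF call f(x); UPPER is used only so the example compiles.
    SqlCall functionCall = SqlStdOperatorTable.UPPER.createCall(SqlParserPos.ZERO,
        new SqlIdentifier("x", SqlParserPos.ZERO));

    // The accessed fields are kept as a single dotted identifier ("y.z"), and the whole
    // expression becomes DOT(f(x), y.z) instead of one flat SqlIdentifier.
    SqlNode fieldAccess = FunctionFieldReferenceOperator.DOT.createCall(SqlParserPos.ZERO,
        functionCall, new SqlIdentifier("y.z", SqlParserPos.ZERO));
    System.out.println(fieldAccess);
  }
}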
@@ -31,7 +31,7 @@ public ShiftArrayIndexTransformer(SqlValidator sqlValidator) {
}

@Override
- public boolean predicate(SqlCall sqlCall) {
+ public boolean condition(SqlCall sqlCall) {
if (ITEM_OPERATOR.equalsIgnoreCase(sqlCall.getOperator().getName())) {
final SqlNode columnNode = sqlCall.getOperandList().get(0);
return getRelDataType(columnNode) instanceof ArraySqlType;

This file was deleted.

26 changes: 16 additions & 10 deletions coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
@@ -6,6 +6,7 @@
package com.linkedin.coral.spark;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.avro.Schema;
@@ -16,6 +17,7 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlSelect;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.spark.containers.SparkRelInfo;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.dialect.SparkSqlDialect;
@@ -63,11 +65,11 @@ private CoralSpark(List<String> baseTables, List<SparkUDFInfo> sparkUDFInfoList,
*/
public static CoralSpark create(RelNode irRelNode) {
SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
+ Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
- String sparkSQL = constructSparkSQL(sparkRelNode);
+ String sparkSQL = constructSparkSQL(sparkRelNode, sparkUDFInfos);
List<String> baseTables = constructBaseTables(sparkRelNode);
- List<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfoList();
- return new CoralSpark(baseTables, sparkUDFInfos, sparkSQL);
+ return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL);
}

/**
Expand All @@ -86,11 +88,11 @@ public static CoralSpark create(RelNode irRelNode, Schema schema) {

private static CoralSpark createWithAlias(RelNode irRelNode, List<String> aliases) {
SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode);
+ Set<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfos();
RelNode sparkRelNode = sparkRelInfo.getSparkRelNode();
- String sparkSQL = constructSparkSQLWithExplicitAlias(sparkRelNode, aliases);
+ String sparkSQL = constructSparkSQLWithExplicitAlias(sparkRelNode, aliases, sparkUDFInfos);
List<String> baseTables = constructBaseTables(sparkRelNode);
- List<SparkUDFInfo> sparkUDFInfos = sparkRelInfo.getSparkUDFInfoList();
- return new CoralSpark(baseTables, sparkUDFInfos, sparkSQL);
+ return new CoralSpark(baseTables, ImmutableList.copyOf(sparkUDFInfos), sparkSQL);
}

/**
Expand All @@ -105,21 +107,25 @@ private static CoralSpark createWithAlias(RelNode irRelNode, List<String> aliase
*
* @param sparkRelNode A Spark compatible RelNode
*
* @param sparkUDFInfos A set of Spark UDF information objects
* @return SQL String in Spark SQL dialect which is 'completely expanded'
*/
- private static String constructSparkSQL(RelNode sparkRelNode) {
+ private static String constructSparkSQL(RelNode sparkRelNode, Set<SparkUDFInfo> sparkUDFInfos) {
CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter();
SqlNode coralSqlNode = rel2sql.convert(sparkRelNode);
- SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter());
+ SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter())
+     .accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));
SqlNode rewrittenSparkSqlNode = sparkSqlNode.accept(new SparkSqlRewriter());
return rewrittenSparkSqlNode.toSqlString(SparkSqlDialect.INSTANCE).getSql();
}
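For context, a hedged sketch of how the migrated pipeline is typically driven end to end; the getSparkSql() accessor is an assumption about the existing CoralSpark API and is not shown in this diff:

import org.apache.calcite.rel.RelNode;

import com.linkedin.coral.spark.CoralSpark;

public final class CoralSparkUsageSketch {
  private CoralSparkUsageSketch() {
  }

  // Assumes the caller already produced a Coral IR RelNode (e.g. from a Hive view);
  // getSparkSql() is assumed, not part of this diff.
  public static String toSparkSql(RelNode coralIrRelNode) {
    // create() now collects Spark UDF info into a Set and threads it through
    // constructSparkSQL(), where the SqlNode-layer transformers are applied.
    CoralSpark coralSpark = CoralSpark.create(coralIrRelNode);
    return coralSpark.getSparkSql();
  }
}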

- private static String constructSparkSQLWithExplicitAlias(RelNode sparkRelNode, List<String> aliases) {
+ private static String constructSparkSQLWithExplicitAlias(RelNode sparkRelNode, List<String> aliases,
+     Set<SparkUDFInfo> sparkUDFInfos) {
CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter();
// Create temporary objects r and rewritten to make debugging easier
SqlNode coralSqlNode = rel2sql.convert(sparkRelNode);
- SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter());
+ SqlNode sparkSqlNode = coralSqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter())
+     .accept(new CoralToSparkSqlCallConverter(sparkUDFInfos));

SqlNode rewritten = sparkSqlNode.accept(new SparkSqlRewriter());
// Use a second pass visit to add explicit alias names,
