Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.common.transformers;

import javax.annotation.Nonnull;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.parser.SqlParserPos;

import static com.linkedin.coral.common.calcite.CalciteUtil.*;


/**
* This class is a subclass of SqlCallTransformer which transforms a function operator on SqlNode layer
* if the signature of the operator to be transformed, including both the name and the number of operands,
* matches the target values in the condition function.
*/
public class OperatorBasedSqlCallTransformer extends SqlCallTransformer {
public final String fromOperatorName;
public final int numOperands;
public final SqlOperator targetOperator;

public OperatorBasedSqlCallTransformer(@Nonnull String fromOperatorName, int numOperands,
@Nonnull SqlOperator targetOperator) {
this.fromOperatorName = fromOperatorName;
this.numOperands = numOperands;
this.targetOperator = targetOperator;
}

public OperatorBasedSqlCallTransformer(@Nonnull SqlOperator coralOp, int numOperands, @Nonnull String trinoFuncName) {
this(coralOp.getName(), numOperands, createSqlUDF(trinoFuncName, coralOp.getReturnTypeInference()));
}

@Override
protected boolean condition(SqlCall sqlCall) {
return fromOperatorName.equalsIgnoreCase(sqlCall.getOperator().getName())
&& sqlCall.getOperandList().size() == numOperands;
}

@Override
protected SqlCall transform(SqlCall sqlCall) {
return createCall(targetOperator, sqlCall.getOperandList(), SqlParserPos.ZERO);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,48 @@
import java.util.ArrayList;
import java.util.List;

import com.google.common.base.Preconditions;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import org.apache.calcite.sql.validate.SqlValidator;

import com.linkedin.coral.common.functions.FunctionReturnTypes;


/**
* Abstract class for generic transformations on SqlCalls
*/
public abstract class SqlCallTransformer {
public static final SqlOperator TIMESTAMP_OPERATOR =
new SqlUserDefinedFunction(new SqlIdentifier("timestamp", SqlParserPos.ZERO), FunctionReturnTypes.TIMESTAMP, null,
OperandTypes.STRING, null, null) {
@Override
public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) {
// for timestamp operator, we need to construct `CAST(x AS TIMESTAMP)`
Preconditions.checkState(call.operandCount() == 1);
final SqlWriter.Frame frame = writer.startFunCall("CAST");
call.operand(0).unparse(writer, 0, 0);
writer.sep("AS");
writer.literal("TIMESTAMP");
writer.endFunCall(frame);
}
};

public static final SqlOperator DATE_OPERATOR = new SqlUserDefinedFunction(
new SqlIdentifier("date", SqlParserPos.ZERO), ReturnTypes.DATE, null, OperandTypes.STRING, null, null);

private SqlValidator sqlValidator;
private final List<SqlSelect> topSelectNodes = new ArrayList<>();

Expand All @@ -32,9 +62,9 @@ public SqlCallTransformer(SqlValidator sqlValidator) {
}

/**
* Predicate of the transformer, it’s used to determine if the SqlCall should be transformed or not
* Condition of the transformer, it’s used to determine if the SqlCall should be transformed or not
*/
protected abstract boolean predicate(SqlCall sqlCall);
protected abstract boolean condition(SqlCall sqlCall);

/**
* Implementation of the transformation, returns the transformed SqlCall
Expand All @@ -49,7 +79,7 @@ public SqlCall apply(SqlCall sqlCall) {
if (sqlCall instanceof SqlSelect) {
this.topSelectNodes.add((SqlSelect) sqlCall);
}
if (predicate(sqlCall)) {
if (condition(sqlCall)) {
return transform(sqlCall);
} else {
return sqlCall;
Expand Down Expand Up @@ -96,4 +126,10 @@ protected RelDataType getRelDataType(SqlNode sqlNode) {
}
throw new RuntimeException("Failed to derive the RelDataType for SqlNode " + sqlNode);
}

public static SqlOperator createSqlUDF(String functionName, SqlReturnTypeInference typeInference) {
SqlIdentifier sqlIdentifier = new SqlIdentifier(
com.linkedin.coral.com.google.common.collect.ImmutableList.of(functionName), SqlParserPos.ZERO);
return new SqlUserDefinedFunction(sqlIdentifier, typeInference, null, null, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
/**
* Container for SqlCallTransformer
*/
public class SqlCallTransformers {
public final class SqlCallTransformers {
private final ImmutableList<SqlCallTransformer> sqlCallTransformers;

SqlCallTransformers(ImmutableList<SqlCallTransformer> sqlCallTransformers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public ShiftArrayIndexTransformer(SqlValidator sqlValidator) {
}

@Override
public boolean predicate(SqlCall sqlCall) {
public boolean condition(SqlCall sqlCall) {
if (ITEM_OPERATOR.equalsIgnoreCase(sqlCall.getOperator().getName())) {
final SqlNode columnNode = sqlCall.getOperandList().get(0);
return getRelDataType(columnNode) instanceof ArraySqlType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@
import com.linkedin.coral.com.google.common.collect.Multimap;
import com.linkedin.coral.common.functions.FunctionReturnTypes;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.trino.rel2trino.functions.GenericProjectToTrinoConverter;

import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
import static com.linkedin.coral.trino.rel2trino.UDFMapUtils.createUDF;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MULTIPLY;
import static org.apache.calcite.sql.type.ReturnTypes.explicit;
import static org.apache.calcite.sql.type.SqlTypeName.*;
Expand Down Expand Up @@ -248,13 +248,7 @@ public RexNode visitCall(RexCall call) {
}
}

final UDFTransformer transformer = CalciteTrinoUDFMap.getUDFTransformer(operatorName, call.operands.size());
if (transformer != null && shouldTransformOperator(operatorName)) {
return adjustReturnTypeWithCast(rexBuilder,
super.visitCall((RexCall) transformer.transformCall(rexBuilder, call.getOperands())));
}

if (operatorName.startsWith("com.linkedin") && transformer == null) {
if (operatorName.startsWith("com.linkedin")) {
return visitUnregisteredUDF(call);
}

Expand Down Expand Up @@ -301,8 +295,8 @@ private RexNode visitUnregisteredUDF(RexCall call) {

private Optional<RexNode> visitCollectListOrSetFunction(RexCall call) {
List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
final SqlOperator arrayAgg = createUDF("array_agg", FunctionReturnTypes.ARRAY_OF_ARG0_TYPE);
final SqlOperator arrayDistinct = createUDF("array_distinct", ReturnTypes.ARG0_NULLABLE);
final SqlOperator arrayAgg = SqlCallTransformer.createSqlUDF("array_agg", FunctionReturnTypes.ARRAY_OF_ARG0_TYPE);
final SqlOperator arrayDistinct = SqlCallTransformer.createSqlUDF("array_distinct", ReturnTypes.ARG0_NULLABLE);
final String operatorName = call.getOperator().getName();
if (operatorName.equalsIgnoreCase("collect_list")) {
return Optional.of(rexBuilder.makeCall(arrayAgg, convertedOperands));
Expand All @@ -313,8 +307,8 @@ private Optional<RexNode> visitCollectListOrSetFunction(RexCall call) {

private Optional<RexNode> visitFromUnixtime(RexCall call) {
List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
SqlOperator formatDatetime = createUDF("format_datetime", FunctionReturnTypes.STRING);
SqlOperator fromUnixtime = createUDF("from_unixtime", explicit(TIMESTAMP));
SqlOperator formatDatetime = SqlCallTransformer.createSqlUDF("format_datetime", FunctionReturnTypes.STRING);
SqlOperator fromUnixtime = SqlCallTransformer.createSqlUDF("from_unixtime", explicit(TIMESTAMP));
if (convertedOperands.size() == 1) {
return Optional
.of(rexBuilder.makeCall(formatDatetime, rexBuilder.makeCall(fromUnixtime, call.getOperands().get(0)),
Expand All @@ -338,13 +332,17 @@ private Optional<RexNode> visitFromUtcTimestampCall(RexCall call) {
// In below definitions we should use `TIMESTATMP WITH TIME ZONE`. As calcite is lacking
// this type we use `TIMESTAMP` instead. It does not have any practical implications as result syntax tree
// is not type-checked, and only used for generating output SQL for a view query.
SqlOperator trinoAtTimeZone = createUDF("at_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoWithTimeZone = createUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoToUnixTime = createUDF("to_unixtime", explicit(DOUBLE));
SqlOperator trinoAtTimeZone =
SqlCallTransformer.createSqlUDF("at_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoWithTimeZone =
SqlCallTransformer.createSqlUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoToUnixTime = SqlCallTransformer.createSqlUDF("to_unixtime", explicit(DOUBLE));
SqlOperator trinoFromUnixtimeNanos =
createUDF("from_unixtime_nanos", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoFromUnixTime = createUDF("from_unixtime", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoCanonicalizeHiveTimezoneId = createUDF("$canonicalize_hive_timezone_id", explicit(VARCHAR));
SqlCallTransformer.createSqlUDF("from_unixtime_nanos", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoFromUnixTime =
SqlCallTransformer.createSqlUDF("from_unixtime", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoCanonicalizeHiveTimezoneId =
SqlCallTransformer.createSqlUDF("$canonicalize_hive_timezone_id", explicit(VARCHAR));

RelDataType bigintType = typeFactory.createSqlType(BIGINT);
RelDataType doubleType = typeFactory.createSqlType(DOUBLE);
Expand Down Expand Up @@ -420,8 +418,9 @@ private Optional<RexNode> visitCast(RexCall call) {
// Trino does not allow for such conversion, but we can achieve the same behavior by first calling "to_unixtime"
// on the TIMESTAMP and then casting it to DECIMAL after.
if (call.getType().getSqlTypeName() == DECIMAL && leftOperand.getType().getSqlTypeName() == TIMESTAMP) {
SqlOperator trinoToUnixTime = createUDF("to_unixtime", explicit(DOUBLE));
SqlOperator trinoWithTimeZone = createUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
SqlOperator trinoToUnixTime = SqlCallTransformer.createSqlUDF("to_unixtime", explicit(DOUBLE));
SqlOperator trinoWithTimeZone =
SqlCallTransformer.createSqlUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */));
return Optional.of(rexBuilder.makeCast(call.getType(), rexBuilder.makeCall(trinoToUnixTime,
rexBuilder.makeCall(trinoWithTimeZone, leftOperand, rexBuilder.makeLiteral("UTC")))));
}
Expand All @@ -431,7 +430,7 @@ private Optional<RexNode> visitCast(RexCall call) {
if ((call.getType().getSqlTypeName() == VARCHAR || call.getType().getSqlTypeName() == CHAR)
&& (leftOperand.getType().getSqlTypeName() == VARBINARY
|| leftOperand.getType().getSqlTypeName() == BINARY)) {
SqlOperator fromUTF8 = createUDF("from_utf8", explicit(VARCHAR));
SqlOperator fromUTF8 = SqlCallTransformer.createSqlUDF("from_utf8", explicit(VARCHAR));
return Optional.of(rexBuilder.makeCall(fromUTF8, leftOperand));
}

Expand Down Expand Up @@ -482,10 +481,6 @@ private RexNode convertMapValueConstructor(RexBuilder rexBuilder, RexCall call)
return rexBuilder.makeCall(call.getOperator(), results);
}

private boolean shouldTransformOperator(String operatorName) {
return !("to_date".equalsIgnoreCase(operatorName) && configs.getOrDefault(AVOID_TRANSFORM_TO_DATE_UDF, false));
}

/**
* This method is to cast the converted call to the same return type in Hive with certain version.
* e.g. `datediff` in Hive returns int type, but the corresponding function `date_diff` in Trino returns bigint type
Expand All @@ -497,13 +492,14 @@ private RexNode adjustReturnTypeWithCast(RexBuilder rexBuilder, RexNode call) {
}
final String lowercaseOperatorName = ((RexCall) call).getOperator().getName().toLowerCase(Locale.ROOT);
final ImmutableMap<String, RelDataType> operatorsToAdjust =
ImmutableMap.of("date_diff", typeFactory.createSqlType(INTEGER), "cardinality",
ImmutableMap.of("datediff", typeFactory.createSqlType(INTEGER), "cardinality",
typeFactory.createSqlType(INTEGER), "ceil", typeFactory.createSqlType(BIGINT), "ceiling",
typeFactory.createSqlType(BIGINT), "floor", typeFactory.createSqlType(BIGINT));
if (operatorsToAdjust.containsKey(lowercaseOperatorName)) {
return rexBuilder.makeCast(operatorsToAdjust.get(lowercaseOperatorName), call);
}
if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false) && lowercaseOperatorName.equals("date_add")) {
if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false)
&& (lowercaseOperatorName.equals("date_add") || lowercaseOperatorName.equals("date_sub"))) {
return rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), call);
}
return call;
Expand Down
Loading