Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coral-Trino: Migrate function operator transformers defined in CalciteTrinoUDFMap from RelNode layer to SqlNode layer #349

Merged
merged 19 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.common.transformers;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

import com.linkedin.coral.com.google.common.base.Preconditions;
import com.linkedin.coral.common.functions.FunctionReturnTypes;

import static com.linkedin.coral.common.calcite.CalciteUtil.*;


/**
* A {@link SqlCallTransformer} that rewrites a function call on the SqlNode layer when the signature of the call,
* i.e. both the operator name (compared case-insensitively) and the number of operands,
* matches the operator name and operand count this transformer was constructed with.
*/
public class SignatureBasedConditionSqlCallTransformer extends SqlCallTransformer {
yiqiangin marked this conversation as resolved.
Show resolved Hide resolved
yiqiangin marked this conversation as resolved.
Show resolved Hide resolved
/** Maps the operator names usable in JSON transformer definitions to the SqlOperator they stand for. */
private static final Map<String, SqlOperator> OP_MAP = new HashMap<>();

// Operators allowed in the transformation
static {
  // Arithmetic operators are taken directly from the standard operator table.
  OP_MAP.put("+", SqlStdOperatorTable.PLUS);
  OP_MAP.put("-", SqlStdOperatorTable.MINUS);
  OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY);
  OP_MAP.put("/", SqlStdOperatorTable.DIVIDE);
  OP_MAP.put("^", SqlStdOperatorTable.POWER);
  OP_MAP.put("%", SqlStdOperatorTable.MOD);

  // `date` is a plain user-defined function over a single string operand.
  final SqlOperator dateOperator = new SqlUserDefinedFunction(new SqlIdentifier("date", SqlParserPos.ZERO),
      ReturnTypes.DATE, null, OperandTypes.STRING, null, null);
  OP_MAP.put("date", dateOperator);

  // `timestamp` overrides unparsing so that it is written out as `CAST(x AS TIMESTAMP)`.
  final SqlOperator timestampOperator = new SqlUserDefinedFunction(
      new SqlIdentifier("timestamp", SqlParserPos.ZERO), FunctionReturnTypes.TIMESTAMP, null, OperandTypes.STRING,
      null, null) {
    @Override
    public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) {
      // The cast form only makes sense for exactly one operand.
      Preconditions.checkState(call.operandCount() == 1);
      final SqlWriter.Frame castFrame = writer.startFunCall("CAST");
      call.operand(0).unparse(writer, 0, 0);
      writer.sep("AS");
      writer.literal("TIMESTAMP");
      writer.endFunCall(castFrame);
    }
  };
  OP_MAP.put("timestamp", timestampOperator);

  final SqlOperator hivePatternToTrinoOperator =
      new SqlUserDefinedFunction(new SqlIdentifier("hive_pattern_to_trino", SqlParserPos.ZERO),
          FunctionReturnTypes.STRING, null, OperandTypes.STRING, null, null);
  OP_MAP.put("hive_pattern_to_trino", hivePatternToTrinoOperator);
}

/** JSON key whose string value names the operator to apply; the name is looked up in OP_MAP. */
public static final String OPERATOR = "op";
/** JSON key holding the array of operand nodes of an operator node. */
public static final String OPERANDS = "operands";
/**
* For input node:
* - input equals 0 refers to the result
* - input greater than 0 refers to the index of source operand (starting from 1)
*
* Operator transformers use the same key with the same 1-based source operand index.
*/
public static final String INPUT = "input";
/** JSON key holding a primitive (string, boolean or number) that is turned into a literal. */
public static final String VALUE = "value";
/** JSON key of an operator transformer holding the regular expression matched against a source operand. */
public static final String REGEX = "regex";
/** JSON key of an operator transformer holding the function name to use when the regular expression matches. */
public static final String NAME = "name";

/** Name of the source operator this transformer applies to; compared case-insensitively. */
public final String fromOperatorName;
/** Number of operands the source call must have for this transformer to apply. */
public final int numOperands;
/** Operator of the rewritten call, unless an operator transformer selects a different function name. */
public final SqlOperator targetOperator;
/** Parsed operand transformers, one per operand of the rewritten call; null when none were supplied. */
public List<JsonObject> operandTransformers;
/** Parsed result transformer applied to the rewritten call; null when none was supplied. */
public JsonObject resultTransformer;
/** Parsed operator transformers (regex, input and name keys); null when none were supplied. */
public List<JsonObject> operatorTransformers;

/**
 * Creates a transformer for calls to {@code fromOperatorName} that have exactly {@code numOperands} operands.
 *
 * @param fromOperatorName name of the source operator to match
 * @param numOperands number of operands the source call must have
 * @param targetOperator operator to use for the rewritten call
 * @param operandTransformers serialized JSON array of operand transformer objects, or null for none
 * @param resultTransformer serialized JSON object describing the result transformer, or null for none
 * @param operatorTransformers serialized JSON array of operator transformer objects, or null for none
 */
public SignatureBasedConditionSqlCallTransformer(@Nonnull String fromOperatorName, int numOperands,
    @Nonnull SqlOperator targetOperator, @Nullable String operandTransformers, @Nullable String resultTransformer,
    @Nullable String operatorTransformers) {
  this.fromOperatorName = fromOperatorName;
  this.numOperands = numOperands;
  this.targetOperator = targetOperator;
  // Each optional definition stays null when absent; the transform methods treat null as "nothing to do".
  this.operandTransformers =
      operandTransformers == null ? null : parseJsonObjectsFromString(operandTransformers);
  this.resultTransformer =
      resultTransformer == null ? null : new JsonParser().parse(resultTransformer).getAsJsonObject();
  this.operatorTransformers =
      operatorTransformers == null ? null : parseJsonObjectsFromString(operatorTransformers);
}

/**
 * Returns true when the call's operator name equals {@code fromOperatorName} (ignoring case)
 * and the call has exactly {@code numOperands} operands.
 */
@Override
protected boolean condition(SqlCall sqlCall) {
  final String callOperatorName = sqlCall.getOperator().getName();
  if (!fromOperatorName.equalsIgnoreCase(callOperatorName)) {
    return false;
  }
  return sqlCall.getOperandList().size() == numOperands;
}

/**
 * Rewrites the call: resolves the target operator, transforms the operands, builds the new call
 * and finally applies the result transformer.
 *
 * @throws IllegalArgumentException if no target operator, or one with an empty name, was resolved
 */
@Override
protected SqlCall transform(SqlCall sqlCall) {
  final List<SqlNode> originalOperands = sqlCall.getOperandList();
  final SqlOperator resolvedOperator = transformTargetOperator(targetOperator, originalOperands);

  final boolean operatorUsable = resolvedOperator != null && !resolvedOperator.getName().isEmpty();
  if (!operatorUsable) {
    // Render the offending call so the message shows exactly what could not be translated.
    final String renderedOperands =
        originalOperands.stream().map(SqlNode::toString).collect(Collectors.joining(","));
    final String message =
        String.format("An equivalent operator in the target IR was not found for the function call: %s(%s)",
            fromOperatorName, renderedOperands);
    throw new IllegalArgumentException(message);
  }

  final List<SqlNode> rewrittenOperands = transformOperands(originalOperands);
  final SqlCall rewrittenCall = createCall(resolvedOperator, rewrittenOperands, SqlParserPos.ZERO);
  return (SqlCall) transformResult(rewrittenCall, originalOperands);
}

/**
 * Builds the operand list of the rewritten call by evaluating every operand transformer.
 * Returns the source operands unchanged when no operand transformers are configured.
 */
private List<SqlNode> transformOperands(List<SqlNode> sourceOperands) {
  if (operandTransformers == null) {
    return sourceOperands;
  }
  // Slot 0 is reserved for the result, which does not exist yet while operands are being built,
  // so it is filled with a null placeholder to keep source operands at indexes 1..n.
  final List<SqlNode> indexedInputs = new ArrayList<>(sourceOperands.size() + 1);
  indexedInputs.add(null);
  indexedInputs.addAll(sourceOperands);

  return operandTransformers.stream()
      .map(operandSpec -> transformExpression(operandSpec, indexedInputs))
      .collect(Collectors.toCollection(ArrayList::new));
}

/**
 * Applies the result transformer to the rewritten call.
 * Returns {@code result} unchanged when no result transformer is configured.
 */
private SqlNode transformResult(SqlNode result, List<SqlNode> sourceOperands) {
  if (resultTransformer != null) {
    // The rewritten call is addressable as input 0, the source operands as inputs 1..n.
    final List<SqlNode> indexedInputs = new ArrayList<>(sourceOperands.size() + 1);
    indexedInputs.add(result);
    indexedInputs.addAll(sourceOperands);
    return transformExpression(resultTransformer, indexedInputs);
  }
  return result;
}

/**
 * Evaluates a single JSON transformer node against the given inputs and returns the resulting SqlNode.
 *
 * <p>The node shapes are checked in this order:
 * <ul>
 *   <li>operator node: has an {@code "op"} key naming an operator from OP_MAP and an {@code "operands"} array
 *       whose elements are themselves transformer nodes, evaluated recursively;</li>
 *   <li>input node: has an {@code "input"} key holding an index into {@code sourceOperands};</li>
 *   <li>value node: has a {@code "value"} key holding a string, boolean or number primitive.</li>
 * </ul>
 *
 * @param transformer the JSON node to evaluate
 * @param sourceOperands inputs addressable by input nodes; an entry may be null, in which case referring to it
 *        is rejected
 * @return the SqlNode produced by the transformer
 * @throws IllegalArgumentException if the node is malformed: an operator node without an operands array or with
 *         an operand that is not a JSON object, an input index that is out of range or refers to a null entry,
 *         a node with none of the three keys, or a non-primitive value
 * @throws UnsupportedOperationException if the operator name is not registered in OP_MAP or the primitive is of
 *         an unsupported kind
 */
private SqlNode transformExpression(JsonObject transformer, List<SqlNode> sourceOperands) {
  final JsonElement operatorElement = transformer.get(OPERATOR);
  if (operatorElement != null) {
    // Fail with a descriptive error instead of a NullPointerException/ClassCastException when the
    // operands array is missing or has the wrong JSON type.
    final JsonElement operandsElement = transformer.get(OPERANDS);
    if (operandsElement == null || !operandsElement.isJsonArray()) {
      throw new IllegalArgumentException(
          "JSON node with an operator must have an array of operands: " + transformer);
    }
    final List<SqlNode> inputOperands = new ArrayList<>();
    for (JsonElement inputOperand : operandsElement.getAsJsonArray()) {
      // Silently skipping a malformed operand would build a call with the wrong number of operands.
      if (!inputOperand.isJsonObject()) {
        throw new IllegalArgumentException("Operand of a transformation should be a JSON object: " + inputOperand);
      }
      inputOperands.add(transformExpression(inputOperand.getAsJsonObject(), sourceOperands));
    }
    final String operatorName = operatorElement.getAsString();
    final SqlOperator op = OP_MAP.get(operatorName);
    if (op == null) {
      throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation");
    }
    return createCall(op, inputOperands, SqlParserPos.ZERO);
  }

  final JsonElement inputElement = transformer.get(INPUT);
  if (inputElement != null) {
    final int index = inputElement.getAsInt();
    // A null entry marks an input that is not available in this context and therefore cannot be referenced.
    if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) {
      throw new IllegalArgumentException(
          "Invalid input value: " + index + ". Number of source operands: " + sourceOperands.size());
    }
    return sourceOperands.get(index);
  }

  final JsonElement value = transformer.get(VALUE);
  if (value == null) {
    throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value");
  }
  if (!value.isJsonPrimitive()) {
    throw new IllegalArgumentException("Value should be of primitive type: " + value);
  }

  final JsonPrimitive primitive = value.getAsJsonPrimitive();
  if (primitive.isString()) {
    return createStringLiteral(primitive.getAsString(), SqlParserPos.ZERO);
  }
  if (primitive.isBoolean()) {
    return createLiteralBoolean(primitive.getAsBoolean(), SqlParserPos.ZERO);
  }
  if (primitive.isNumber()) {
    // NOTE(review): longValue() discards any fractional part, so only integral numeric literals survive intact.
    return createLiteralNumber(value.getAsBigDecimal().longValue(), SqlParserPos.ZERO);
  }

  throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive);
}

/**
 * Picks the operator of the rewritten call. Each operator transformer is tried in order; the first one whose
 * regular expression matches the string form of the referenced source operand yields a new operator carrying
 * the transformer's function name and the return type inference of {@code operator}. When no operator
 * transformers are configured, or none matches, {@code operator} itself is returned.
 *
 * @throws IllegalArgumentException if a transformer lacks a required key, refers to an operand index outside
 *         [1, number of source operands], or has an empty function name
 */
private SqlOperator transformTargetOperator(SqlOperator operator, List<SqlNode> sourceOperands) {
  if (operatorTransformers != null) {
    for (JsonObject operatorSpec : operatorTransformers) {
      final boolean hasAllKeys =
          operatorSpec.has(REGEX) && operatorSpec.has(INPUT) && operatorSpec.has(NAME);
      if (!hasAllKeys) {
        throw new IllegalArgumentException(
            "JSON node for target operator transformer must have a matcher, input and name");
      }

      // Same convention as operand and result transformers: source operands are numbered from 1,
      // so the acceptable values are [1, size].
      final int oneBasedIndex = operatorSpec.get(INPUT).getAsInt();
      if (oneBasedIndex < 1 || oneBasedIndex > sourceOperands.size()) {
        throw new IllegalArgumentException(
            String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size()));
      }

      final String replacementName = operatorSpec.get(NAME).getAsString();
      if (replacementName.isEmpty()) {
        throw new IllegalArgumentException("JSON node for transformation must have a non-empty name");
      }

      final String regex = operatorSpec.get(REGEX).getAsString();
      final String operandText = sourceOperands.get(oneBasedIndex - 1).toString();
      // The whole operand text has to match the expression, not just a part of it.
      if (Pattern.compile(regex).matcher(operandText).matches()) {
        return createOperator(replacementName, operator.getReturnTypeInference(), null);
      }
    }
  }
  return operator;
}

/**
 * Parses a serialized JSON array and returns its elements as a list of JsonObjects, in array order.
 * Every element of the array is converted with {@code getAsJsonObject()}.
 */
private static List<JsonObject> parseJsonObjectsFromString(String s) {
  final JsonArray parsedArray = new JsonParser().parse(s).getAsJsonArray();
  final int elementCount = parsedArray.size();
  final List<JsonObject> parsedObjects = new ArrayList<>(elementCount);
  for (int i = 0; i < elementCount; i++) {
    parsedObjects.add(parsedArray.get(i).getAsJsonObject());
  }
  return parsedObjects;
}

/**
 * Creates a user-defined function operator with the given name, return type inference and operand type checker.
 * All remaining constructor arguments of the SqlUserDefinedFunction are passed as null.
 *
 * @param functionName name of the function, positioned at SqlParserPos.ZERO
 * @param returnTypeInference return type inference handed to the new operator
 * @param operandTypeChecker operand type checker handed to the new operator
 * @return the newly created operator
 */
public static SqlOperator createOperator(String functionName, SqlReturnTypeInference returnTypeInference,
    SqlOperandTypeChecker operandTypeChecker) {
  final SqlIdentifier functionIdentifier = new SqlIdentifier(functionName, SqlParserPos.ZERO);
  return new SqlUserDefinedFunction(functionIdentifier, returnTypeInference, null, operandTypeChecker, null,
      null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ public SqlCallTransformer(SqlValidator sqlValidator) {
}

/**
* Predicate of the transformer, it’s used to determine if the SqlCall should be transformed or not
* Condition of the transformer, it’s used to determine if the SqlCall should be transformed or not
*/
protected abstract boolean predicate(SqlCall sqlCall);
protected abstract boolean condition(SqlCall sqlCall);

/**
* Implementation of the transformation, returns the transformed SqlCall
Expand All @@ -49,7 +49,7 @@ public SqlCall apply(SqlCall sqlCall) {
if (sqlCall instanceof SqlSelect) {
this.topSelectNodes.add((SqlSelect) sqlCall);
}
if (predicate(sqlCall)) {
if (condition(sqlCall)) {
return transform(sqlCall);
} else {
return sqlCall;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
public class SqlCallTransformers {
yiqiangin marked this conversation as resolved.
Show resolved Hide resolved
private final ImmutableList<SqlCallTransformer> sqlCallTransformers;

SqlCallTransformers(ImmutableList<SqlCallTransformer> sqlCallTransformers) {
public SqlCallTransformers(ImmutableList<SqlCallTransformer> sqlCallTransformers) {
yiqiangin marked this conversation as resolved.
Show resolved Hide resolved
this.sqlCallTransformers = sqlCallTransformers;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public ShiftArrayIndexTransformer(SqlValidator sqlValidator) {
}

@Override
public boolean predicate(SqlCall sqlCall) {
public boolean condition(SqlCall sqlCall) {
if (ITEM_OPERATOR.equalsIgnoreCase(sqlCall.getOperator().getName())) {
final SqlNode columnNode = sqlCall.getOperandList().get(0);
return getRelDataType(columnNode) instanceof ArraySqlType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,7 @@ public RexNode visitCall(RexCall call) {
}
}

final UDFTransformer transformer = CalciteTrinoUDFMap.getUDFTransformer(operatorName, call.operands.size());
if (transformer != null && shouldTransformOperator(operatorName)) {
return adjustReturnTypeWithCast(rexBuilder,
super.visitCall((RexCall) transformer.transformCall(rexBuilder, call.getOperands())));
}

if (operatorName.startsWith("com.linkedin") && transformer == null) {
if (operatorName.startsWith("com.linkedin")) {
return visitUnregisteredUDF(call);
}

Expand Down Expand Up @@ -482,10 +476,6 @@ private RexNode convertMapValueConstructor(RexBuilder rexBuilder, RexCall call)
return rexBuilder.makeCall(call.getOperator(), results);
}

private boolean shouldTransformOperator(String operatorName) {
return !("to_date".equalsIgnoreCase(operatorName) && configs.getOrDefault(AVOID_TRANSFORM_TO_DATE_UDF, false));
}

/**
* This method is to cast the converted call to the same return type in Hive with certain version.
* e.g. `datediff` in Hive returns int type, but the corresponding function `date_diff` in Trino returns bigint type
Expand All @@ -497,13 +487,14 @@ private RexNode adjustReturnTypeWithCast(RexBuilder rexBuilder, RexNode call) {
}
final String lowercaseOperatorName = ((RexCall) call).getOperator().getName().toLowerCase(Locale.ROOT);
final ImmutableMap<String, RelDataType> operatorsToAdjust =
ImmutableMap.of("date_diff", typeFactory.createSqlType(INTEGER), "cardinality",
ImmutableMap.of("datediff", typeFactory.createSqlType(INTEGER), "cardinality",
typeFactory.createSqlType(INTEGER), "ceil", typeFactory.createSqlType(BIGINT), "ceiling",
typeFactory.createSqlType(BIGINT), "floor", typeFactory.createSqlType(BIGINT));
if (operatorsToAdjust.containsKey(lowercaseOperatorName)) {
return rexBuilder.makeCast(operatorsToAdjust.get(lowercaseOperatorName), call);
}
if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false) && lowercaseOperatorName.equals("date_add")) {
if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false)
&& (lowercaseOperatorName.equals("date_add") || lowercaseOperatorName.equals("date_sub"))) {
return rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), call);
}
return call;
Expand Down
Loading