-
Notifications
You must be signed in to change notification settings - Fork 196
Coral-Trino: Migrate function operator transformers defined in CalciteTrinoUDFMap from RelNode layer to SqlNode layer #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
6f0d543
Migrate standard UDF operator transformers based on JSON infra from R…
13ba070
address comments
630862a
fixing a typo of the class name
67e713d
address comments
5d50c45
Merge branch 'master' into yiqiangin/transformation-migration
05439cc
address comments
6a4f453
address comments
521cb6e
fix a typo
e7db21b
address comments
effb365
adding another constructor in LinkedInOperatorBasedSqlCallTransformer
8ea2c19
Simplify coral-trino transformations
wmoustafa d6947f0
fix the regression test failures
f8c4c2b
address comments
279e5d3
add link of a class in comments
440aa1c
fix the regression test failures
91771b8
Revert "fix the regression test failures"
4035bb5
fix the regression test failures
a05c0a2
rename a function and add some javadoc
ab8f649
add {@link} in javadoc of a function
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
262 changes: 262 additions & 0 deletions
262
...ava/com/linkedin/coral/common/transformers/SignatureBasedConditionSqlCallTransformer.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
/** | ||
* Copyright 2023 LinkedIn Corporation. All rights reserved. | ||
* Licensed under the BSD-2 Clause license. | ||
* See LICENSE in the project root for license information. | ||
*/ | ||
package com.linkedin.coral.common.transformers; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.regex.Pattern; | ||
import java.util.stream.Collectors; | ||
|
||
import javax.annotation.Nonnull; | ||
import javax.annotation.Nullable; | ||
|
||
import com.google.gson.JsonArray; | ||
import com.google.gson.JsonElement; | ||
import com.google.gson.JsonObject; | ||
import com.google.gson.JsonParser; | ||
import com.google.gson.JsonPrimitive; | ||
|
||
import org.apache.calcite.sql.SqlCall; | ||
import org.apache.calcite.sql.SqlIdentifier; | ||
import org.apache.calcite.sql.SqlNode; | ||
import org.apache.calcite.sql.SqlOperator; | ||
import org.apache.calcite.sql.SqlWriter; | ||
import org.apache.calcite.sql.fun.SqlStdOperatorTable; | ||
import org.apache.calcite.sql.parser.SqlParserPos; | ||
import org.apache.calcite.sql.type.OperandTypes; | ||
import org.apache.calcite.sql.type.ReturnTypes; | ||
import org.apache.calcite.sql.type.SqlOperandTypeChecker; | ||
import org.apache.calcite.sql.type.SqlReturnTypeInference; | ||
import org.apache.calcite.sql.validate.SqlUserDefinedFunction; | ||
|
||
import com.linkedin.coral.com.google.common.base.Preconditions; | ||
import com.linkedin.coral.common.functions.FunctionReturnTypes; | ||
|
||
import static com.linkedin.coral.common.calcite.CalciteUtil.*; | ||
|
||
|
||
/** | ||
* This class is a subclass of SqlCallTransformer which transforms a function operator on SqlNode layer | ||
* if the signature of the operator to be transformed, including both the name and the number of operands, | ||
* matches the target values in the condition function. | ||
*/ | ||
public class SignatureBasedConditionSqlCallTransformer extends SqlCallTransformer { | ||
yiqiangin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
private static final Map<String, SqlOperator> OP_MAP = new HashMap<>(); | ||
|
||
// Operators allowed in the transformation | ||
static { | ||
OP_MAP.put("+", SqlStdOperatorTable.PLUS); | ||
OP_MAP.put("-", SqlStdOperatorTable.MINUS); | ||
OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY); | ||
OP_MAP.put("/", SqlStdOperatorTable.DIVIDE); | ||
OP_MAP.put("^", SqlStdOperatorTable.POWER); | ||
OP_MAP.put("%", SqlStdOperatorTable.MOD); | ||
OP_MAP.put("date", new SqlUserDefinedFunction(new SqlIdentifier("date", SqlParserPos.ZERO), ReturnTypes.DATE, null, | ||
OperandTypes.STRING, null, null)); | ||
OP_MAP.put("timestamp", new SqlUserDefinedFunction(new SqlIdentifier("timestamp", SqlParserPos.ZERO), | ||
FunctionReturnTypes.TIMESTAMP, null, OperandTypes.STRING, null, null) { | ||
@Override | ||
public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { | ||
// for timestamp operator, we need to construct `CAST(x AS TIMESTAMP)` | ||
Preconditions.checkState(call.operandCount() == 1); | ||
final SqlWriter.Frame frame = writer.startFunCall("CAST"); | ||
call.operand(0).unparse(writer, 0, 0); | ||
writer.sep("AS"); | ||
writer.literal("TIMESTAMP"); | ||
writer.endFunCall(frame); | ||
} | ||
}); | ||
OP_MAP.put("hive_pattern_to_trino", | ||
new SqlUserDefinedFunction(new SqlIdentifier("hive_pattern_to_trino", SqlParserPos.ZERO), | ||
FunctionReturnTypes.STRING, null, OperandTypes.STRING, null, null)); | ||
} | ||
|
||
public static final String OPERATOR = "op"; | ||
public static final String OPERANDS = "operands"; | ||
/** | ||
* For input node: | ||
* - input equals 0 refers to the result | ||
* - input great than 0 refers to the index of source operand (starting from 1) | ||
*/ | ||
public static final String INPUT = "input"; | ||
public static final String VALUE = "value"; | ||
public static final String REGEX = "regex"; | ||
public static final String NAME = "name"; | ||
|
||
public final String fromOperatorName; | ||
public final int numOperands; | ||
public final SqlOperator targetOperator; | ||
public List<JsonObject> operandTransformers; | ||
public JsonObject resultTransformer; | ||
public List<JsonObject> operatorTransformers; | ||
|
||
public SignatureBasedConditionSqlCallTransformer(@Nonnull String fromOperatorName, int numOperands, | ||
@Nonnull SqlOperator targetOperator, @Nullable String operandTransformers, @Nullable String resultTransformer, | ||
@Nullable String operatorTransformers) { | ||
this.fromOperatorName = fromOperatorName; | ||
this.numOperands = numOperands; | ||
this.targetOperator = targetOperator; | ||
if (operandTransformers != null) { | ||
this.operandTransformers = parseJsonObjectsFromString(operandTransformers); | ||
} | ||
if (resultTransformer != null) { | ||
this.resultTransformer = new JsonParser().parse(resultTransformer).getAsJsonObject(); | ||
} | ||
if (operatorTransformers != null) { | ||
this.operatorTransformers = parseJsonObjectsFromString(operatorTransformers); | ||
} | ||
} | ||
|
||
@Override | ||
protected boolean condition(SqlCall sqlCall) { | ||
return fromOperatorName.equalsIgnoreCase(sqlCall.getOperator().getName()) | ||
&& sqlCall.getOperandList().size() == numOperands; | ||
} | ||
|
||
@Override | ||
protected SqlCall transform(SqlCall sqlCall) { | ||
List<SqlNode> sourceOperands = sqlCall.getOperandList(); | ||
final SqlOperator newTargetOperator = transformTargetOperator(targetOperator, sourceOperands); | ||
if (newTargetOperator == null || newTargetOperator.getName().isEmpty()) { | ||
String operands = sourceOperands.stream().map(SqlNode::toString).collect(Collectors.joining(",")); | ||
throw new IllegalArgumentException( | ||
String.format("An equivalent operator in the target IR was not found for the function call: %s(%s)", | ||
fromOperatorName, operands)); | ||
} | ||
final List<SqlNode> newOperands = transformOperands(sourceOperands); | ||
final SqlCall newCall = createCall(newTargetOperator, newOperands, SqlParserPos.ZERO); | ||
return (SqlCall) transformResult(newCall, sourceOperands); | ||
} | ||
|
||
private List<SqlNode> transformOperands(List<SqlNode> sourceOperands) { | ||
if (operandTransformers == null) { | ||
return sourceOperands; | ||
} | ||
final List<SqlNode> sources = new ArrayList<>(); | ||
// Add a dummy expression for input 0 | ||
sources.add(null); | ||
sources.addAll(sourceOperands); | ||
final List<SqlNode> results = new ArrayList<>(); | ||
for (JsonObject operandTransformer : operandTransformers) { | ||
results.add(transformExpression(operandTransformer, sources)); | ||
} | ||
return results; | ||
} | ||
|
||
private SqlNode transformResult(SqlNode result, List<SqlNode> sourceOperands) { | ||
if (resultTransformer == null) { | ||
return result; | ||
} | ||
final List<SqlNode> sources = new ArrayList<>(); | ||
// Result will be input 0 | ||
sources.add(result); | ||
sources.addAll(sourceOperands); | ||
return transformExpression(resultTransformer, sources); | ||
} | ||
|
||
/** | ||
* Performs a single transformer. | ||
*/ | ||
private SqlNode transformExpression(JsonObject transformer, List<SqlNode> sourceOperands) { | ||
if (transformer.get(OPERATOR) != null) { | ||
final List<SqlNode> inputOperands = new ArrayList<>(); | ||
for (JsonElement inputOperand : transformer.getAsJsonArray(OPERANDS)) { | ||
if (inputOperand.isJsonObject()) { | ||
inputOperands.add(transformExpression(inputOperand.getAsJsonObject(), sourceOperands)); | ||
} | ||
} | ||
final String operatorName = transformer.get(OPERATOR).getAsString(); | ||
final SqlOperator op = OP_MAP.get(operatorName); | ||
if (op == null) { | ||
throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation"); | ||
} | ||
return createCall(op, inputOperands, SqlParserPos.ZERO); | ||
} | ||
if (transformer.get(INPUT) != null) { | ||
int index = transformer.get(INPUT).getAsInt(); | ||
if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) { | ||
throw new IllegalArgumentException( | ||
"Invalid input value: " + index + ". Number of source operands: " + sourceOperands.size()); | ||
} | ||
return sourceOperands.get(index); | ||
} | ||
final JsonElement value = transformer.get(VALUE); | ||
if (value == null) { | ||
throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value"); | ||
} | ||
if (!value.isJsonPrimitive()) { | ||
throw new IllegalArgumentException("Value should be of primitive type: " + value); | ||
} | ||
|
||
final JsonPrimitive primitive = value.getAsJsonPrimitive(); | ||
if (primitive.isString()) { | ||
return createStringLiteral(primitive.getAsString(), SqlParserPos.ZERO); | ||
} | ||
if (primitive.isBoolean()) { | ||
return createLiteralBoolean(primitive.getAsBoolean(), SqlParserPos.ZERO); | ||
} | ||
if (primitive.isNumber()) { | ||
return createLiteralNumber(value.getAsBigDecimal().longValue(), SqlParserPos.ZERO); | ||
} | ||
|
||
throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive); | ||
} | ||
|
||
/** | ||
* Returns a SqlOperator with a function name based on the value of the source operands. | ||
*/ | ||
private SqlOperator transformTargetOperator(SqlOperator operator, List<SqlNode> sourceOperands) { | ||
if (operatorTransformers == null) { | ||
return operator; | ||
} | ||
|
||
for (JsonObject operatorTransformer : operatorTransformers) { | ||
if (!operatorTransformer.has(REGEX) || !operatorTransformer.has(INPUT) || !operatorTransformer.has(NAME)) { | ||
throw new IllegalArgumentException( | ||
"JSON node for target operator transformer must have a matcher, input and name"); | ||
} | ||
// We use the same convention as operand and result transformers. | ||
// Therefore, we start source index values at index 1 instead of index 0. | ||
// Acceptable index values are set to be [1, size] | ||
int index = operatorTransformer.get(INPUT).getAsInt() - 1; | ||
if (index < 0 || index >= sourceOperands.size()) { | ||
throw new IllegalArgumentException( | ||
String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size())); | ||
} | ||
String functionName = operatorTransformer.get(NAME).getAsString(); | ||
if (functionName.isEmpty()) { | ||
throw new IllegalArgumentException("JSON node for transformation must have a non-empty name"); | ||
} | ||
String matcher = operatorTransformer.get(REGEX).getAsString(); | ||
|
||
if (Pattern.matches(matcher, sourceOperands.get(index).toString())) { | ||
return createOperator(functionName, operator.getReturnTypeInference(), null); | ||
} | ||
} | ||
return operator; | ||
} | ||
|
||
/** | ||
* Creates an ArrayList of JsonObjects from a string input. | ||
* The input string must be a serialized JSON array. | ||
*/ | ||
private static List<JsonObject> parseJsonObjectsFromString(String s) { | ||
List<JsonObject> objects = new ArrayList<>(); | ||
JsonArray transformerArray = new JsonParser().parse(s).getAsJsonArray(); | ||
for (JsonElement object : transformerArray) { | ||
objects.add(object.getAsJsonObject()); | ||
} | ||
return objects; | ||
} | ||
|
||
public static SqlOperator createOperator(String functionName, SqlReturnTypeInference returnTypeInference, | ||
SqlOperandTypeChecker operandTypeChecker) { | ||
return new SqlUserDefinedFunction(new SqlIdentifier(functionName, SqlParserPos.ZERO), returnTypeInference, null, | ||
operandTypeChecker, null, null); | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.