diff --git a/coral-common/src/main/java/com/linkedin/coral/common/transformers/OperatorBasedSqlCallTransformer.java b/coral-common/src/main/java/com/linkedin/coral/common/transformers/OperatorBasedSqlCallTransformer.java new file mode 100644 index 000000000..48ab9e0f0 --- /dev/null +++ b/coral-common/src/main/java/com/linkedin/coral/common/transformers/OperatorBasedSqlCallTransformer.java @@ -0,0 +1,49 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.common.transformers; + +import javax.annotation.Nonnull; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.parser.SqlParserPos; + +import static com.linkedin.coral.common.calcite.CalciteUtil.*; + + +/** + * This class is a subclass of SqlCallTransformer which transforms a function operator on SqlNode layer + * if the signature of the operator to be transformed, including both the name and the number of operands, + * matches the target values in the condition function. 
+ */ +public class OperatorBasedSqlCallTransformer extends SqlCallTransformer { + public final String fromOperatorName; + public final int numOperands; + public final SqlOperator targetOperator; + + public OperatorBasedSqlCallTransformer(@Nonnull String fromOperatorName, int numOperands, + @Nonnull SqlOperator targetOperator) { + this.fromOperatorName = fromOperatorName; + this.numOperands = numOperands; + this.targetOperator = targetOperator; + } + + public OperatorBasedSqlCallTransformer(@Nonnull SqlOperator coralOp, int numOperands, @Nonnull String trinoFuncName) { + this(coralOp.getName(), numOperands, createSqlUDF(trinoFuncName, coralOp.getReturnTypeInference())); + } + + @Override + protected boolean condition(SqlCall sqlCall) { + return fromOperatorName.equalsIgnoreCase(sqlCall.getOperator().getName()) + && sqlCall.getOperandList().size() == numOperands; + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + return createCall(targetOperator, sqlCall.getOperandList(), SqlParserPos.ZERO); + } + +} diff --git a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java index d98b5eacd..9cfcde4b6 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformer.java @@ -8,18 +8,48 @@ import java.util.ArrayList; import java.util.List; +import com.google.common.base.Preconditions; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import 
org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.validate.SqlUserDefinedFunction; import org.apache.calcite.sql.validate.SqlValidator; +import com.linkedin.coral.common.functions.FunctionReturnTypes; + /** * Abstract class for generic transformations on SqlCalls */ public abstract class SqlCallTransformer { + public static final SqlOperator TIMESTAMP_OPERATOR = + new SqlUserDefinedFunction(new SqlIdentifier("timestamp", SqlParserPos.ZERO), FunctionReturnTypes.TIMESTAMP, null, + OperandTypes.STRING, null, null) { + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + // for timestamp operator, we need to construct `CAST(x AS TIMESTAMP)` + Preconditions.checkState(call.operandCount() == 1); + final SqlWriter.Frame frame = writer.startFunCall("CAST"); + call.operand(0).unparse(writer, 0, 0); + writer.sep("AS"); + writer.literal("TIMESTAMP"); + writer.endFunCall(frame); + } + }; + + public static final SqlOperator DATE_OPERATOR = new SqlUserDefinedFunction( + new SqlIdentifier("date", SqlParserPos.ZERO), ReturnTypes.DATE, null, OperandTypes.STRING, null, null); + private SqlValidator sqlValidator; private final List topSelectNodes = new ArrayList<>(); @@ -32,9 +62,9 @@ public SqlCallTransformer(SqlValidator sqlValidator) { } /** - * Predicate of the transformer, it’s used to determine if the SqlCall should be transformed or not + * Condition of the transformer, it’s used to determine if the SqlCall should be transformed or not */ - protected abstract boolean predicate(SqlCall sqlCall); + protected abstract boolean condition(SqlCall sqlCall); /** * Implementation of the transformation, returns the transformed SqlCall @@ -49,7 +79,7 @@ public SqlCall apply(SqlCall sqlCall) { if (sqlCall instanceof SqlSelect) { this.topSelectNodes.add((SqlSelect) sqlCall); } - if (predicate(sqlCall)) { + if 
(condition(sqlCall)) { return transform(sqlCall); } else { return sqlCall; @@ -96,4 +126,10 @@ protected RelDataType getRelDataType(SqlNode sqlNode) { } throw new RuntimeException("Failed to derive the RelDataType for SqlNode " + sqlNode); } + + public static SqlOperator createSqlUDF(String functionName, SqlReturnTypeInference typeInference) { + SqlIdentifier sqlIdentifier = new SqlIdentifier( + com.linkedin.coral.com.google.common.collect.ImmutableList.of(functionName), SqlParserPos.ZERO); + return new SqlUserDefinedFunction(sqlIdentifier, typeInference, null, null, null, null); + } } diff --git a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformers.java b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformers.java index 07671294b..e438fd1aa 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformers.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/transformers/SqlCallTransformers.java @@ -15,7 +15,7 @@ /** * Container for SqlCallTransformer */ -public class SqlCallTransformers { +public final class SqlCallTransformers { private final ImmutableList sqlCallTransformers; SqlCallTransformers(ImmutableList sqlCallTransformers) { diff --git a/coral-hive/src/main/java/com/linkedin/coral/transformers/ShiftArrayIndexTransformer.java b/coral-hive/src/main/java/com/linkedin/coral/transformers/ShiftArrayIndexTransformer.java index 18434fff4..dbf265de0 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/transformers/ShiftArrayIndexTransformer.java +++ b/coral-hive/src/main/java/com/linkedin/coral/transformers/ShiftArrayIndexTransformer.java @@ -31,7 +31,7 @@ public ShiftArrayIndexTransformer(SqlValidator sqlValidator) { } @Override - public boolean predicate(SqlCall sqlCall) { + public boolean condition(SqlCall sqlCall) { if (ITEM_OPERATOR.equalsIgnoreCase(sqlCall.getOperator().getName())) { final SqlNode columnNode = sqlCall.getOperandList().get(0); 
return getRelDataType(columnNode) instanceof ArraySqlType; diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java index f81af2670..3ab4339f2 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java @@ -54,10 +54,10 @@ import com.linkedin.coral.com.google.common.collect.Multimap; import com.linkedin.coral.common.functions.FunctionReturnTypes; import com.linkedin.coral.common.functions.GenericProjectFunction; +import com.linkedin.coral.common.transformers.SqlCallTransformer; import com.linkedin.coral.trino.rel2trino.functions.GenericProjectToTrinoConverter; import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*; -import static com.linkedin.coral.trino.rel2trino.UDFMapUtils.createUDF; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MULTIPLY; import static org.apache.calcite.sql.type.ReturnTypes.explicit; import static org.apache.calcite.sql.type.SqlTypeName.*; @@ -248,13 +248,7 @@ public RexNode visitCall(RexCall call) { } } - final UDFTransformer transformer = CalciteTrinoUDFMap.getUDFTransformer(operatorName, call.operands.size()); - if (transformer != null && shouldTransformOperator(operatorName)) { - return adjustReturnTypeWithCast(rexBuilder, - super.visitCall((RexCall) transformer.transformCall(rexBuilder, call.getOperands()))); - } - - if (operatorName.startsWith("com.linkedin") && transformer == null) { + if (operatorName.startsWith("com.linkedin")) { return visitUnregisteredUDF(call); } @@ -301,8 +295,8 @@ private RexNode visitUnregisteredUDF(RexCall call) { private Optional visitCollectListOrSetFunction(RexCall call) { List convertedOperands = visitList(call.getOperands(), (boolean[]) null); - final SqlOperator arrayAgg = createUDF("array_agg", 
FunctionReturnTypes.ARRAY_OF_ARG0_TYPE); - final SqlOperator arrayDistinct = createUDF("array_distinct", ReturnTypes.ARG0_NULLABLE); + final SqlOperator arrayAgg = SqlCallTransformer.createSqlUDF("array_agg", FunctionReturnTypes.ARRAY_OF_ARG0_TYPE); + final SqlOperator arrayDistinct = SqlCallTransformer.createSqlUDF("array_distinct", ReturnTypes.ARG0_NULLABLE); final String operatorName = call.getOperator().getName(); if (operatorName.equalsIgnoreCase("collect_list")) { return Optional.of(rexBuilder.makeCall(arrayAgg, convertedOperands)); @@ -313,8 +307,8 @@ private Optional visitCollectListOrSetFunction(RexCall call) { private Optional visitFromUnixtime(RexCall call) { List convertedOperands = visitList(call.getOperands(), (boolean[]) null); - SqlOperator formatDatetime = createUDF("format_datetime", FunctionReturnTypes.STRING); - SqlOperator fromUnixtime = createUDF("from_unixtime", explicit(TIMESTAMP)); + SqlOperator formatDatetime = SqlCallTransformer.createSqlUDF("format_datetime", FunctionReturnTypes.STRING); + SqlOperator fromUnixtime = SqlCallTransformer.createSqlUDF("from_unixtime", explicit(TIMESTAMP)); if (convertedOperands.size() == 1) { return Optional .of(rexBuilder.makeCall(formatDatetime, rexBuilder.makeCall(fromUnixtime, call.getOperands().get(0)), @@ -338,13 +332,17 @@ private Optional visitFromUtcTimestampCall(RexCall call) { // In below definitions we should use `TIMESTATMP WITH TIME ZONE`. As calcite is lacking // this type we use `TIMESTAMP` instead. It does not have any practical implications as result syntax tree // is not type-checked, and only used for generating output SQL for a view query. 
- SqlOperator trinoAtTimeZone = createUDF("at_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); - SqlOperator trinoWithTimeZone = createUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); - SqlOperator trinoToUnixTime = createUDF("to_unixtime", explicit(DOUBLE)); + SqlOperator trinoAtTimeZone = + SqlCallTransformer.createSqlUDF("at_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); + SqlOperator trinoWithTimeZone = + SqlCallTransformer.createSqlUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); + SqlOperator trinoToUnixTime = SqlCallTransformer.createSqlUDF("to_unixtime", explicit(DOUBLE)); SqlOperator trinoFromUnixtimeNanos = - createUDF("from_unixtime_nanos", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); - SqlOperator trinoFromUnixTime = createUDF("from_unixtime", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); - SqlOperator trinoCanonicalizeHiveTimezoneId = createUDF("$canonicalize_hive_timezone_id", explicit(VARCHAR)); + SqlCallTransformer.createSqlUDF("from_unixtime_nanos", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); + SqlOperator trinoFromUnixTime = + SqlCallTransformer.createSqlUDF("from_unixtime", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); + SqlOperator trinoCanonicalizeHiveTimezoneId = + SqlCallTransformer.createSqlUDF("$canonicalize_hive_timezone_id", explicit(VARCHAR)); RelDataType bigintType = typeFactory.createSqlType(BIGINT); RelDataType doubleType = typeFactory.createSqlType(DOUBLE); @@ -420,8 +418,9 @@ private Optional visitCast(RexCall call) { // Trino does not allow for such conversion, but we can achieve the same behavior by first calling "to_unixtime" // on the TIMESTAMP and then casting it to DECIMAL after. 
if (call.getType().getSqlTypeName() == DECIMAL && leftOperand.getType().getSqlTypeName() == TIMESTAMP) { - SqlOperator trinoToUnixTime = createUDF("to_unixtime", explicit(DOUBLE)); - SqlOperator trinoWithTimeZone = createUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); + SqlOperator trinoToUnixTime = SqlCallTransformer.createSqlUDF("to_unixtime", explicit(DOUBLE)); + SqlOperator trinoWithTimeZone = + SqlCallTransformer.createSqlUDF("with_timezone", explicit(TIMESTAMP /* should be WITH TIME ZONE */)); return Optional.of(rexBuilder.makeCast(call.getType(), rexBuilder.makeCall(trinoToUnixTime, rexBuilder.makeCall(trinoWithTimeZone, leftOperand, rexBuilder.makeLiteral("UTC"))))); } @@ -431,7 +430,7 @@ private Optional visitCast(RexCall call) { if ((call.getType().getSqlTypeName() == VARCHAR || call.getType().getSqlTypeName() == CHAR) && (leftOperand.getType().getSqlTypeName() == VARBINARY || leftOperand.getType().getSqlTypeName() == BINARY)) { - SqlOperator fromUTF8 = createUDF("from_utf8", explicit(VARCHAR)); + SqlOperator fromUTF8 = SqlCallTransformer.createSqlUDF("from_utf8", explicit(VARCHAR)); return Optional.of(rexBuilder.makeCall(fromUTF8, leftOperand)); } @@ -482,10 +481,6 @@ private RexNode convertMapValueConstructor(RexBuilder rexBuilder, RexCall call) return rexBuilder.makeCall(call.getOperator(), results); } - private boolean shouldTransformOperator(String operatorName) { - return !("to_date".equalsIgnoreCase(operatorName) && configs.getOrDefault(AVOID_TRANSFORM_TO_DATE_UDF, false)); - } - /** * This method is to cast the converted call to the same return type in Hive with certain version. * e.g. 
`datediff` in Hive returns int type, but the corresponding function `date_diff` in Trino returns bigint type @@ -497,13 +492,14 @@ private RexNode adjustReturnTypeWithCast(RexBuilder rexBuilder, RexNode call) { } final String lowercaseOperatorName = ((RexCall) call).getOperator().getName().toLowerCase(Locale.ROOT); final ImmutableMap operatorsToAdjust = - ImmutableMap.of("date_diff", typeFactory.createSqlType(INTEGER), "cardinality", + ImmutableMap.of("datediff", typeFactory.createSqlType(INTEGER), "cardinality", typeFactory.createSqlType(INTEGER), "ceil", typeFactory.createSqlType(BIGINT), "ceiling", typeFactory.createSqlType(BIGINT), "floor", typeFactory.createSqlType(BIGINT)); if (operatorsToAdjust.containsKey(lowercaseOperatorName)) { return rexBuilder.makeCast(operatorsToAdjust.get(lowercaseOperatorName), call); } - if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false) && lowercaseOperatorName.equals("date_add")) { + if (configs.getOrDefault(CAST_DATEADD_TO_STRING, false) + && (lowercaseOperatorName.equals("date_add") || lowercaseOperatorName.equals("date_sub"))) { return rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), call); } return call; diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CalciteTrinoUDFMap.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CalciteTrinoUDFMap.java deleted file mode 100644 index 1877689d4..000000000 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CalciteTrinoUDFMap.java +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2017-2023 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.coral.trino.rel2trino; - -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; - -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; - -import com.linkedin.coral.com.google.common.base.CaseFormat; -import com.linkedin.coral.com.google.common.base.Converter; -import com.linkedin.coral.com.google.common.collect.ImmutableMultimap; -import com.linkedin.coral.common.functions.Function; -import com.linkedin.coral.hive.hive2rel.functions.HiveRLikeOperator; -import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry; -import com.linkedin.coral.trino.rel2trino.functions.TrinoElementAtFunction; - -import static com.linkedin.coral.trino.rel2trino.UDFMapUtils.*; - - -public class CalciteTrinoUDFMap { - private CalciteTrinoUDFMap() { - } - - private static final Map UDF_MAP = new HashMap<>(); - private static final StaticHiveFunctionRegistry HIVE_REGISTRY = new StaticHiveFunctionRegistry(); - static { - // conditional functions - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("nvl"), 2, "coalesce"); - // Array and map functions - createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.ITEM, 2, TrinoElementAtFunction.INSTANCE); - - // Math Functions - createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.RAND, 0, "RANDOM"); - createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.RAND, 1, "RANDOM", "[]", null); - createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.RAND_INTEGER, 1, "RANDOM"); - createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.RAND_INTEGER, 2, "RANDOM", "[{\"input\":2}]", null); - createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.TRUNCATE, 2, "TRUNCATE", - "[{\"op\":\"*\",\"operands\":[{\"input\":1},{\"op\":\"^\",\"operands\":[{\"value\":10},{\"input\":2}]}]}]", - "{\"op\":\"/\",\"operands\":[{\"input\":0},{\"op\":\"^\",\"operands\":[{\"value\":10},{\"input\":2}]}]}"); - - // String Functions - createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.SUBSTRING, 2, "SUBSTR"); - 
createUDFMapEntry(UDF_MAP, SqlStdOperatorTable.SUBSTRING, 3, "SUBSTR"); - createUDFMapEntry(UDF_MAP, HiveRLikeOperator.RLIKE, 2, "REGEXP_LIKE"); - createUDFMapEntry(UDF_MAP, HiveRLikeOperator.REGEXP, 2, "REGEXP_LIKE"); - - // JSON Functions - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("get_json_object"), 2, "json_extract"); - - // map various hive functions - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("pmod"), 2, "mod", - "[{\"op\":\"+\",\"operands\":[{\"op\":\"%\",\"operands\":[{\"input\":1},{\"input\":2}]},{\"input\":2}]},{\"input\":2}]", - null); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("base64"), 1, "to_base64"); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("unbase64"), 1, "from_base64"); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("hex"), 1, "to_hex"); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("unhex"), 1, "from_hex"); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("array_contains"), 2, "contains"); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("regexp_extract"), 3, "regexp_extract", - "[{\"input\": 1}, {\"op\": \"hive_pattern_to_trino\", \"operands\":[{\"input\": 2}]}, {\"input\": 3}]", null); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("instr"), 2, "strpos"); - createRuntimeUDFMapEntry(UDF_MAP, hiveToCalciteOp("decode"), 2, - "[{\"regex\":\"(?i)('utf-8')\", \"input\":2, \"name\":\"from_utf8\"}]", "[{\"input\":1}]", null); - - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("to_date"), 1, "date", - "[{\"op\": \"timestamp\", \"operands\":[{\"input\": 1}]}]", null); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("date_add"), 2, "date_add", "[{\"value\": 'day'}, {\"input\": 2}, " - + "{\"op\": \"date\", \"operands\":[{\"op\": \"timestamp\", \"operands\":[{\"input\": 1}]}]}]", null); - createUDFMapEntry(UDF_MAP, hiveToCalciteOp("date_sub"), 2, "date_add", - "[{\"value\": 'day'}, " + "{\"op\": \"*\", \"operands\":[{\"input\": 2}, {\"value\": -1}]}, " - + "{\"op\": \"date\", \"operands\":[{\"op\": \"timestamp\", \"operands\":[{\"input\": 1}]}]}]", - null); - 
createUDFMapEntry(UDF_MAP, hiveToCalciteOp("datediff"), 2, "date_diff", - "[{\"value\": 'day'}, {\"op\": \"date\", \"operands\":[{\"op\": \"timestamp\", \"operands\":[{\"input\": 2}]}]}, " - + "{\"op\": \"date\", \"operands\":[{\"op\": \"timestamp\", \"operands\":[{\"input\": 1}]}]}]", - null); - - // DALI functions - // Most "com.linkedin..." UDFs follow convention of having UDF names mapped from camel-cased name to snake-cased name. - // For example: For class name IsGuestMemberId, the conventional udf name would be is_guest_member_id. - // While this convention fits most UDFs it doesn't fit all. With the following mapping we override the conventional - // UDF name mapping behavior to a hardcoded one. - // For example instead of UserAgentParser getting mapped to user_agent_parser, we mapped it here to useragentparser - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.dali.udf.watbotcrawlerlookup.hive.WATBotCrawlerLookup"), 3, - "wat_bot_crawler_lookup"); - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.stdudfs.parsing.hive.Ip2Str"), 1, "ip2str"); - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.stdudfs.parsing.hive.Ip2Str"), 3, "ip2str"); - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.stdudfs.parsing.hive.UserAgentParser"), 2, - "useragentparser"); - - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.stdudfs.lookup.hive.BrowserLookup"), 3, "browserlookup"); - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.jobs.udf.hive.ConvertIndustryCode"), 1, - "converttoindustryv1"); - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.stdudfs.urnextractor.hive.UrnExtractorFunctionWrapper"), 1, - "urn_extractor"); - createUDFMapEntry(UDF_MAP, daliToCalciteOp("com.linkedin.stdudfs.hive.daliudfs.UrnExtractorFunctionWrapper"), 1, - "urn_extractor"); - - addDaliUDFs(); - } - - private static void addDaliUDFs() { - ImmutableMultimap registry = HIVE_REGISTRY.getRegistry(); - Converter caseConverter = 
CaseFormat.UPPER_CAMEL.converterTo(CaseFormat.LOWER_UNDERSCORE); - for (Map.Entry entry : registry.entries()) { - // we cannot use entry.getKey() as function name directly, because keys are all lowercase, which will - // fail to be converted to lowercase with underscore correctly - final String hiveFunctionName = entry.getValue().getFunctionName(); - if (!hiveFunctionName.startsWith("com.linkedin")) { - continue; - } - String[] nameSplit = hiveFunctionName.split("\\."); - // filter above guarantees we've at least 2 entries - String className = nameSplit[nameSplit.length - 1]; - String funcName = caseConverter.convert(className); - SqlOperator op = entry.getValue().getSqlOperator(); - for (int i = op.getOperandCountRange().getMin(); i <= op.getOperandCountRange().getMax(); i++) { - if (!isDaliUDFAlreadyAdded(hiveFunctionName, i)) { - createUDFMapEntry(UDF_MAP, op, i, funcName); - } - } - } - } - - /** - * Gets UDFTransformer for a given Calcite SQL Operator. - * - * @param calciteOpName Name of Calcite SQL operator - * @param numOperands Number of operands - * @return {@link UDFTransformer} object - */ - public static UDFTransformer getUDFTransformer(String calciteOpName, int numOperands) { - return UDF_MAP.get(getKey(calciteOpName, numOperands)); - } - - private static Boolean isDaliUDFAlreadyAdded(String classString, int numOperands) { - return getUDFTransformer(classString, numOperands) != null; - } - - /** - * Looks up Hive functions using functionName case-insensitively. - */ - private static SqlOperator hiveToCalciteOp(String functionName) { - Collection lookup = HIVE_REGISTRY.lookup(functionName); - // TODO: provide overloaded function resolution - return lookup.iterator().next().getSqlOperator(); - } - - /** - * Looks up Dali functions using className case-insensitively. 
- */ - private static SqlOperator daliToCalciteOp(String className) { - return HIVE_REGISTRY.lookup(className).iterator().next().getSqlOperator(); - } -} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CoralToTrinoSqlCallConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CoralToTrinoSqlCallConverter.java new file mode 100644 index 000000000..c69065a01 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/CoralToTrinoSqlCallConverter.java @@ -0,0 +1,32 @@ +/** + * Copyright 2017-2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino; + +import java.util.Map; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.util.SqlShuttle; + +import com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil; + + +/** + * This class extends SqlShuttle and calls CoralToTrinoSqlCallTransformersUtil to get a list of SqlCallTransformers + * to traverse the hierarchy and convert UDF operators in all SqlCalls where required + */ +public class CoralToTrinoSqlCallConverter extends SqlShuttle { + private final Map configs; + public CoralToTrinoSqlCallConverter(Map configs) { + this.configs = configs; + } + + @Override + public SqlNode visit(SqlCall call) { + SqlCall transformedCall = CoralToTrinoSqlCallTransformersUtil.getTransformers(configs).apply(call); + return super.visit(transformedCall); + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java index c4ba32562..872b9f8c0 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java +++ 
b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java @@ -77,7 +77,10 @@ public RelToTrinoConverter(Map configs) { */ public String convert(RelNode relNode) { RelNode rel = convertRel(relNode, configs); - return convertToSqlNode(rel).accept(new TrinoSqlRewriter()).toSqlString(TrinoSqlDialect.INSTANCE).toString(); + SqlNode sqlNode = convertToSqlNode(rel); + SqlNode sqlNodeWithUDFOperatorConverted = sqlNode.accept(new CoralToTrinoSqlCallConverter(configs)); + return sqlNodeWithUDFOperatorConverted.accept(new TrinoSqlRewriter()).toSqlString(TrinoSqlDialect.INSTANCE) + .toString(); } /** diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/UDFMapUtils.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/UDFMapUtils.java deleted file mode 100644 index 59d5e579f..000000000 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/UDFMapUtils.java +++ /dev/null @@ -1,116 +0,0 @@ -/** - * Copyright 2017-2023 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.coral.trino.rel2trino; - -import java.util.Map; - -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.validate.SqlUserDefinedFunction; - -import com.linkedin.coral.com.google.common.collect.ImmutableList; - - -public class UDFMapUtils { - private UDFMapUtils() { - } - - /** - * Creates a mapping for Calcite SQL operator to Trino UDF. 
- * - * @param udfMap Map to store the result - * @param calciteOp Calcite SQL operator - * @param numOperands Number of operands - * @param trinoUDFName Name of Trino UDF - */ - static void createUDFMapEntry(Map udfMap, SqlOperator calciteOp, int numOperands, - String trinoUDFName) { - createUDFMapEntry(udfMap, calciteOp, numOperands, trinoUDFName, null, null); - } - - /** - * Creates a mapping from Calcite SQL operator to Trino UDF with Trino SqlOperator, operands transformer, and result transformers. - * - * @param udfMap Map to store the result - * @param calciteOp Calcite SQL operator - * @param numOperands Number of operands - * @param trinoSqlOperator The Trino Sql Operator that is used as the target operator in the map - * @param operandTransformer Operand transformers, null for identity transformation - * @param resultTransformer Result transformer, null for identity transformation - */ - static void createUDFMapEntry(Map udfMap, SqlOperator calciteOp, int numOperands, - SqlOperator trinoSqlOperator, String operandTransformer, String resultTransformer) { - - udfMap.put(getKey(calciteOp.getName(), numOperands), - UDFTransformer.of(calciteOp.getName(), trinoSqlOperator, operandTransformer, resultTransformer, null)); - } - - /** - * Creates a mapping from Calcite SQL operator to Trino UDF with Trino SqlOperator. - * - * @param udfMap Map to store the result - * @param calciteOp Calcite SQL operator - * @param numOperands Number of operands - * @param trinoSqlOperator The Trino Sql Operator that is used as the target operator in the map - */ - static void createUDFMapEntry(Map udfMap, SqlOperator calciteOp, int numOperands, - SqlOperator trinoSqlOperator) { - createUDFMapEntry(udfMap, calciteOp, numOperands, trinoSqlOperator, null, null); - } - - /** - * Creates a mapping from Calcite SQL operator to Trino UDF with Trino UDF name, operands transformer, and result transformers. 
- * To construct Trino SqlOperator from Trino UDF name, this method reuses the return type inference from calciteOp, - * assuming equivalence. - * - * @param udfMap Map to store the result - * @param calciteOp Calcite SQL operator - * @param numOperands Number of operands - * @param trinoUDFName Name of Trino UDF - * @param operandTransformer Operand transformers, null for identity transformation - * @param resultTransformer Result transformer, null for identity transformation - */ - static void createUDFMapEntry(Map udfMap, SqlOperator calciteOp, int numOperands, - String trinoUDFName, String operandTransformer, String resultTransformer) { - createUDFMapEntry(udfMap, calciteOp, numOperands, createUDF(trinoUDFName, calciteOp.getReturnTypeInference()), - operandTransformer, resultTransformer); - } - - /** - * Creates a mapping from a Calcite SQL operator to a Trino UDF determined at runtime - * by the values of input parameters with operand and result transformers. - * - * @param udfMap Map to store the result - * @param calciteOp Calcite SQL operator - * @param numOperands Number of operands - * @param operatorTransformers Operator transformers as a JSON string. - * @param operandTransformer Operand transformers, null for identity transformation - * @param resultTransformer Result transformer, null for identity transformation - */ - static void createRuntimeUDFMapEntry(Map udfMap, SqlOperator calciteOp, int numOperands, - String operatorTransformers, String operandTransformer, String resultTransformer) { - createUDFMapEntry(udfMap, calciteOp, numOperands, createUDF("", calciteOp.getReturnTypeInference()), - operandTransformer, resultTransformer); - } - - /** - * Creates Trino UDF for a given Trino UDF name and return type inference. 
- * - * @param udfName udf name - * @param typeInference {@link SqlReturnTypeInference} of return type - * @return SQL operator - */ - public static SqlOperator createUDF(String udfName, SqlReturnTypeInference typeInference) { - return new SqlUserDefinedFunction(new SqlIdentifier(ImmutableList.of(udfName), SqlParserPos.ZERO), typeInference, - null, null, null, null); - } - - static String getKey(String calciteOpName, int numOperands) { - return calciteOpName + "_" + numOperands; - } -} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/UDFTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/UDFTransformer.java deleted file mode 100644 index a920441d3..000000000 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/UDFTransformer.java +++ /dev/null @@ -1,376 +0,0 @@ -/** - * Copyright 2017-2023 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. 
- */ -package com.linkedin.coral.trino.rel2trino; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.regex.Pattern; -import java.util.stream.Collectors; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.google.gson.JsonPrimitive; - -import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlCall; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlWriter; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.validate.SqlUserDefinedFunction; - -import com.linkedin.coral.com.google.common.base.Preconditions; -import com.linkedin.coral.common.functions.FunctionReturnTypes; - - -/** - * Object for transforming UDF from one SQL language to another SQL language at the RexNode layer. - * - * Suppose f1(a1, a2, ..., an) in the first language can be computed by - * f2(b1, b2, ..., bm) in the second language as follows: - * (b1, b2, ..., bm) = g(a1, a2, ..., an) - * f1(a1, a2, ..., an) = h(f2(g(a1, a2, ..., an))) - * - * We need to define two transformation functions: - * - A vector function g for transforming all operands - * - A function h for transforming the result. - * - * This class will represent g and h as expressions in JSON format as follows: - * - Operators: +, -, *, /, and ^ - * - Operands: source operands and literal values - * - * There may also be situations where a function in one language can map to more than one functions in the other - * language depending on the set of input parameters. 
- * We define a set of matching functions to determine what function name is used. - * Currently, there is no use-case more complicated than matching a parameter string to a static regex. - * - * Example 1: - * In Calcite SQL, TRUNCATE(aDouble, numDigitAfterDot) truncates aDouble by removing - * any digit from the position numDigitAfterDot after the dot, like truncate(11.45, 0) = 11, - * truncate(11.45, 1) = 11.4 - * - * In Trino, TRUNCATE(aDouble) only takes one argument and removes all digits after the dot, - * like truncate(11.45) = 11. - * - * The transformation from Calcite TRUNCATE to Trino TRUNCATE is represented as follows: - * 1. Trino name: TRUNCATE - * - * 2. Operand transformers: - * g(b1) = a1 * 10 ^ a2, with JSON format: - * [ - * { "op":"*", - * "operands":[ - * {"input":1}, // input 0 is reserved for result transformer. source inputs start from 1 - * { "op":"^", - * "operands":[ - * {"value":10}, - * {"input":2}]}]}] - * - * 3. Result transformer: - * h(result) = result / 10 ^ a2 - * { "op":"/", - * "operands":[ - * {"input":0}, // input 0 is for result transformer - * { "op":"^", - * "operands":[ - * {"value":10}, - * {"input":2}]}]}] - * - * - * 4. Operator transformers: - * none - * - * Example 2: - * In Calcite, there exists a hive-derived function to decode binary data given a format, DECODE(binary, scheme). - * In Trino, there is no generic decoding function that takes a decoding-scheme. - * Instead, there exist specific decoding functions that are first-class functions like FROM_UTF8(binary). - * Consequently, we would need to know the operands in the function in order to determine the corresponding call. - * - * The transformation from Calcite DECODE to a Trino equivalent is represented as follows: - * 1. Trino name: There is no function name determined at compile time. - * null - * - * 2. Operand transformers: We want to retain column 1 and drop column 2: - * [{"input":1}] - * - * 3. 
Result transformer: No transformation is performed on output. - * null - * - * 4. Operator transformers: Check the second parameter (scheme) matches 'utf-8' with any casing using Java Regex. - * [ { - * "regex" : "^.*(?i)(utf-8).*$", - * "input" : 2, - * "name" : "from_utf8" - * } - * ] - */ -public class UDFTransformer { - private static final Map OP_MAP = new HashMap<>(); - - // Operators allowed in the transformation - static { - OP_MAP.put("+", SqlStdOperatorTable.PLUS); - OP_MAP.put("-", SqlStdOperatorTable.MINUS); - OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY); - OP_MAP.put("/", SqlStdOperatorTable.DIVIDE); - OP_MAP.put("^", SqlStdOperatorTable.POWER); - OP_MAP.put("%", SqlStdOperatorTable.MOD); - OP_MAP.put("date", new SqlUserDefinedFunction(new SqlIdentifier("date", SqlParserPos.ZERO), ReturnTypes.DATE, null, - OperandTypes.STRING, null, null)); - OP_MAP.put("timestamp", new SqlUserDefinedFunction(new SqlIdentifier("timestamp", SqlParserPos.ZERO), - FunctionReturnTypes.TIMESTAMP, null, OperandTypes.STRING, null, null) { - @Override - public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { - // for timestamp operator, we need to construct `CAST(x AS TIMESTAMP)` - Preconditions.checkState(call.operandCount() == 1); - final SqlWriter.Frame frame = writer.startFunCall("CAST"); - call.operand(0).unparse(writer, 0, 0); - writer.sep("AS"); - writer.literal("TIMESTAMP"); - writer.endFunCall(frame); - } - }); - OP_MAP.put("hive_pattern_to_trino", - new SqlUserDefinedFunction(new SqlIdentifier("hive_pattern_to_trino", SqlParserPos.ZERO), - FunctionReturnTypes.STRING, null, OperandTypes.STRING, null, null)); - } - - public static final String OPERATOR = "op"; - public static final String OPERANDS = "operands"; - /** - * For input node: - * - input equals 0 refers to the result - * - input great than 0 refers to the index of source operand (starting from 1) - */ - public static final String INPUT = "input"; - public static final String VALUE 
= "value"; - public static final String REGEX = "regex"; - public static final String NAME = "name"; - - public final String calciteOperatorName; - public final SqlOperator targetOperator; - public final List operandTransformers; - public final JsonObject resultTransformer; - public final List operatorTransformers; - - private UDFTransformer(String calciteOperatorName, SqlOperator targetOperator, List operandTransformers, - JsonObject resultTransformer, List operatorTransformers) { - this.calciteOperatorName = calciteOperatorName; - this.targetOperator = targetOperator; - this.operandTransformers = operandTransformers; - this.resultTransformer = resultTransformer; - this.operatorTransformers = operatorTransformers; - } - - /** - * Creates a new transformer. - * - * @param calciteOperatorName Name of the Calcite function associated with this UDF - * @param targetOperator Target operator (a UDF in the target language) - * @param operandTransformers JSON string representing the operand transformations, - * null for identity transformations - * @param resultTransformer JSON string representing the result transformation, - * null for identity transformation - * @param operatorTransformers JSON string representing an array of transformers that can vary the name of the target - * operator based on runtime parameter values. - * In the order of the JSON Array, the first transformer that matches the JSON string will - * have its given operator named selected as the target operator name. - * Operands are indexed beginning at index 1. - * An operatorTransformer has the following serialized JSON string format: - * "[ - * { - * \"name\" : \"{Name of function if this matches}\", - * \"input\" : {Index of the parameter starting at index 1 that is evaluated }, - * \"regex\" : \"{Java Regex string matching the parameter at given input}\" - * }, - * ... 
- * ]" - * For example, a transformer for a operator named "foo" when parameter 2 matches exactly - * "bar" is specified as: - * "[ - * { - * \"name\" : \"foo\", - * \"input\" : 2, - * \"regex\" : \"'bar'\" - * } - * ]" - * NOTE: A string literal is represented exactly as ['STRING_LITERAL'] with the single - * quotation marks INCLUDED. - * As seen in the example above, the single quotation marks are also present in the - * regex matcher. - * - * @return {@link UDFTransformer} object - */ - - public static UDFTransformer of(@Nonnull String calciteOperatorName, @Nonnull SqlOperator targetOperator, - @Nullable String operandTransformers, @Nullable String resultTransformer, @Nullable String operatorTransformers) { - List operands = null; - JsonObject result = null; - List operators = null; - if (operandTransformers != null) { - operands = parseJsonObjectsFromString(operandTransformers); - } - if (resultTransformer != null) { - result = new JsonParser().parse(resultTransformer).getAsJsonObject(); - } - if (operatorTransformers != null) { - operators = parseJsonObjectsFromString(operatorTransformers); - } - return new UDFTransformer(calciteOperatorName, targetOperator, operands, result, operators); - } - - /** - * Transforms a call to the source operator. 
- * - * @param rexBuilder Rex Builder - * @param sourceOperands Source operands - * @return An expression calling the target operator that is equivalent to the source operator call - */ - public RexNode transformCall(RexBuilder rexBuilder, List sourceOperands) { - final SqlOperator newTargetOperator = transformTargetOperator(targetOperator, sourceOperands); - if (newTargetOperator == null || newTargetOperator.getName().isEmpty()) { - String operands = sourceOperands.stream().map(RexNode::toString).collect(Collectors.joining(",")); - throw new IllegalArgumentException(String.format( - "An equivalent Trino operator was not found for the function call: %s(%s)", calciteOperatorName, operands)); - } - final List newOperands = transformOperands(rexBuilder, sourceOperands); - final RexNode newCall = rexBuilder.makeCall(newTargetOperator, newOperands); - return transformResult(rexBuilder, newCall, sourceOperands); - } - - private List transformOperands(RexBuilder rexBuilder, List sourceOperands) { - if (operandTransformers == null) { - return sourceOperands; - } - final List sources = new ArrayList<>(); - // Add a dummy expression for input 0 - sources.add(null); - sources.addAll(sourceOperands); - final List results = new ArrayList<>(); - for (JsonObject operandTransformer : operandTransformers) { - results.add(transformExpression(rexBuilder, operandTransformer, sources)); - } - return results; - } - - private RexNode transformResult(RexBuilder rexBuilder, RexNode result, List sourceOperands) { - if (resultTransformer == null) { - return result; - } - final List sources = new ArrayList<>(); - // Result will be input 0 - sources.add(result); - sources.addAll(sourceOperands); - return transformExpression(rexBuilder, resultTransformer, sources); - } - - /** - * Performs a single transformer. 
- */ - private RexNode transformExpression(RexBuilder rexBuilder, JsonObject transformer, List sourceOperands) { - if (transformer.get(OPERATOR) != null) { - final List inputOperands = new ArrayList<>(); - for (JsonElement inputOperand : transformer.getAsJsonArray(OPERANDS)) { - if (inputOperand.isJsonObject()) { - inputOperands.add(transformExpression(rexBuilder, inputOperand.getAsJsonObject(), sourceOperands)); - } - } - final String operatorName = transformer.get(OPERATOR).getAsString(); - final SqlOperator op = OP_MAP.get(operatorName); - if (op == null) { - throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation"); - } - return rexBuilder.makeCall(op, inputOperands); - } - if (transformer.get(INPUT) != null) { - int index = transformer.get(INPUT).getAsInt(); - if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) { - throw new IllegalArgumentException( - "Invalid input value: " + index + ". Number of source operands: " + sourceOperands.size()); - } - return sourceOperands.get(index); - } - final JsonElement value = transformer.get(VALUE); - if (value == null) { - throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value"); - } - if (!value.isJsonPrimitive()) { - throw new IllegalArgumentException("Value should be of primitive type: " + value); - } - - final JsonPrimitive primitive = value.getAsJsonPrimitive(); - if (primitive.isString()) { - return rexBuilder.makeLiteral(primitive.getAsString()); - } - if (primitive.isBoolean()) { - return rexBuilder.makeLiteral(primitive.getAsBoolean()); - } - if (primitive.isNumber()) { - return rexBuilder.makeBigintLiteral(value.getAsBigDecimal()); - } - - throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive); - } - - /** - * Returns a SqlOperator with a function name based on the value of the source operands. 
- */ - private SqlOperator transformTargetOperator(SqlOperator operator, List sourceOperands) { - if (operatorTransformers == null) { - return operator; - } - - for (JsonObject operatorTransformer : operatorTransformers) { - if (!operatorTransformer.has(REGEX) || !operatorTransformer.has(INPUT) || !operatorTransformer.has(NAME)) { - throw new IllegalArgumentException( - "JSON node for target operator transformer must have a matcher, input and name"); - } - // We use the same convention as operand and result transformers. - // Therefore, we start source index values at index 1 instead of index 0. - // Acceptable index values are set to be [1, size] - int index = operatorTransformer.get(INPUT).getAsInt() - 1; - if (index < 0 || index >= sourceOperands.size()) { - throw new IllegalArgumentException( - String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size())); - } - String functionName = operatorTransformer.get(NAME).getAsString(); - if (functionName.isEmpty()) { - throw new IllegalArgumentException("JSON node for transformation must have a non-empty name"); - } - String matcher = operatorTransformer.get(REGEX).getAsString(); - - if (Pattern.matches(matcher, sourceOperands.get(index).toString())) { - return UDFMapUtils.createUDF(functionName, operator.getReturnTypeInference()); - } - } - return operator; - } - - /** - * TODO(ralam): Add this as a general utility in coral-calcite or look for other base libraries with this function. - * Creates an ArrayList of JsonObjects from a string input. - * The input string must be a serialized JSON array. 
- */ - private static List parseJsonObjectsFromString(String s) { - List objects = new ArrayList<>(); - JsonArray transformerArray = new JsonParser().parse(s).getAsJsonArray(); - for (JsonElement object : transformerArray) { - objects.add(object.getAsJsonObject()); - } - return objects; - } -} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateAddOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateAddOperatorTransformer.java new file mode 100644 index 000000000..3ed8602d6 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateAddOperatorTransformer.java @@ -0,0 +1,56 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer; + +import static com.linkedin.coral.common.calcite.CalciteUtil.*; +import static com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil.*; + + +/** + * This class transforms a Coral SqlCall of "date_add" operator with 2 operands into a Trino SqlCall of an operator + * named "date_add" + */ +public class DateAddOperatorTransformer extends OperatorBasedSqlCallTransformer { + private static final String FROM_OPERATOR_NAME = "date_add"; + private static final int OPERAND_NUM = 2; + private static final SqlOperator TARGET_OPERATOR = + createSqlUDF("date_add", hiveToCoralSqlOperator("date_add").getReturnTypeInference()); + + public DateAddOperatorTransformer() { + 
super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List sourceOperands = sqlCall.getOperandList(); + List newOperands = new ArrayList<>(); + newOperands.add(createStringLiteral("day", SqlParserPos.ZERO)); + newOperands.add(sourceOperands.get(1)); + + List timestampOperatorOperands = new ArrayList<>(); + timestampOperatorOperands.add(sourceOperands.get(0)); + SqlCall timestampSqlCall = + TIMESTAMP_OPERATOR.createCall(new SqlNodeList(timestampOperatorOperands, SqlParserPos.ZERO)); + + List dateOperatorOperands = new ArrayList<>(); + dateOperatorOperands.add(timestampSqlCall); + SqlCall dateOpSqlCall = DATE_OPERATOR.createCall(new SqlNodeList(dateOperatorOperands, SqlParserPos.ZERO)); + newOperands.add(dateOpSqlCall); + + return TARGET_OPERATOR.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO)); + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateDiffOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateDiffOperatorTransformer.java new file mode 100644 index 000000000..92c661bd8 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateDiffOperatorTransformer.java @@ -0,0 +1,57 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer; + +import static com.linkedin.coral.common.calcite.CalciteUtil.*; +import static com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil.*; + + +/** + * This class transforms a Coral SqlCall of "datediff" operator with 2 operands into a Trino SqlCall of an operator + * named "date_diff" + */ +public class DateDiffOperatorTransformer extends OperatorBasedSqlCallTransformer { + private static final String FROM_OPERATOR_NAME = "datediff"; + private static final int OPERAND_NUM = 2; + private static final SqlOperator TARGET_OPERATOR = + createSqlUDF("date_diff", hiveToCoralSqlOperator("datediff").getReturnTypeInference()); + + public DateDiffOperatorTransformer() { + super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List sourceOperands = sqlCall.getOperandList(); + List newOperands = new ArrayList<>(); + newOperands.add(createStringLiteral("day", SqlParserPos.ZERO)); + newOperands.add(createDateTimestampSqlCall(sourceOperands, 1)); + newOperands.add(createDateTimestampSqlCall(sourceOperands, 0)); + return TARGET_OPERATOR.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO)); + } + + private SqlCall createDateTimestampSqlCall(List sourceOperands, int idx) { + List timestampOperatorOperands = new ArrayList<>(); + timestampOperatorOperands.add(sourceOperands.get(idx)); + SqlCall timestampSqlCall = + TIMESTAMP_OPERATOR.createCall(new SqlNodeList(timestampOperatorOperands, SqlParserPos.ZERO)); + + List dateOperatorOperands = new ArrayList<>(); + 
dateOperatorOperands.add(timestampSqlCall); + return DATE_OPERATOR.createCall(new SqlNodeList(dateOperatorOperands, SqlParserPos.ZERO)); + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateSubOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateSubOperatorTransformer.java new file mode 100644 index 000000000..3f9366c62 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DateSubOperatorTransformer.java @@ -0,0 +1,63 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer; + +import static com.linkedin.coral.common.calcite.CalciteUtil.*; +import static com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil.*; + + +/** + * This class transforms a Coral SqlCall of "date_sub" operator with 2 operands into a Trino SqlCall of an operator + * named "date_add" + */ +public class DateSubOperatorTransformer extends OperatorBasedSqlCallTransformer { + private static final String FROM_OPERATOR_NAME = "date_sub"; + private static final int OPERAND_NUM = 2; + private static final SqlOperator TARGET_OPERATOR = + createSqlUDF("date_add", hiveToCoralSqlOperator("date_sub").getReturnTypeInference()); + + public DateSubOperatorTransformer() { + super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR); + } + + @Override + protected SqlCall 
transform(SqlCall sqlCall) { + List sourceOperands = sqlCall.getOperandList(); + List newOperands = new ArrayList<>(); + newOperands.add(createStringLiteral("day", SqlParserPos.ZERO)); + + List multiplyOperands = new ArrayList<>(); + multiplyOperands.add(sourceOperands.get(1)); + multiplyOperands.add(createLiteralNumber(-1, SqlParserPos.ZERO)); + SqlCall multiplySqlCall = + SqlStdOperatorTable.MULTIPLY.createCall(new SqlNodeList(multiplyOperands, SqlParserPos.ZERO)); + newOperands.add(multiplySqlCall); + + List timestampOperatorOperands = new ArrayList<>(); + timestampOperatorOperands.add(sourceOperands.get(0)); + SqlCall timestampSqlCall = + TIMESTAMP_OPERATOR.createCall(new SqlNodeList(timestampOperatorOperands, SqlParserPos.ZERO)); + + List dateOperatorOperands = new ArrayList<>(); + dateOperatorOperands.add(timestampSqlCall); + SqlCall dateOpSqlCall = DATE_OPERATOR.createCall(new SqlNodeList(dateOperatorOperands, SqlParserPos.ZERO)); + newOperands.add(dateOpSqlCall); + + return TARGET_OPERATOR.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO)); + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DecodeOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DecodeOperatorTransformer.java new file mode 100644 index 000000000..bd64e9dbe --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/DecodeOperatorTransformer.java @@ -0,0 +1,44 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer; + +import static com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil.*; + + +/** + * This class transforms a Coral SqlCall of "decode" operator with 2 operands into a Trino SqlCall of an operator + * named "[{\"regex\":\"(?i)('utf-8')\", \"input\":2, \"name\":\"from_utf8\"}]" + */ +public class DecodeOperatorTransformer extends OperatorBasedSqlCallTransformer { + private static final String FROM_OPERATOR_NAME = "decode"; + private static final int OPERAND_NUM = 2; + private static final SqlOperator TARGET_OPERATOR = + createSqlUDF("[{\"regex\":\"(?i)('utf-8')\", \"input\":2, \"name\":\"from_utf8\"}]", + hiveToCoralSqlOperator("decode").getReturnTypeInference()); + + public DecodeOperatorTransformer() { + super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List sourceOperands = sqlCall.getOperandList(); + List newOperands = new ArrayList<>(); + newOperands.add(sourceOperands.get(0)); + return TARGET_OPERATOR.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO)); + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/MapStructAccessOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/MapStructAccessOperatorTransformer.java new file mode 100644 index 000000000..d36babd28 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/MapStructAccessOperatorTransformer.java @@ -0,0 +1,50 @@ +/** + * Copyright 2023 LinkedIn Corporation. 
All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIdentifier; + +import com.linkedin.coral.com.google.common.collect.ImmutableList; +import com.linkedin.coral.common.transformers.SqlCallTransformer; + + +/** + * This class is an ad-hoc SqlCallTransformer which converts the map struct access operator "[]" defined + * from Calcite in a SqlIdentifier into a UDF operator of "element_at", + * e.g. from col["field"] to element_at(col, "field") + */ +public class MapStructAccessOperatorTransformer extends SqlCallTransformer { + private static final String AS_OPERATOR_NAME = "AS"; + private static final Pattern MAP_STRUCT_ACCESS_PATTERN = Pattern.compile("\\\".+\\\"\\[\\\".+\\\"\\]"); + private static final String ELEMENT_AT = "element_at(%s, %s)"; + + @Override + protected boolean condition(SqlCall sqlCall) { + if (AS_OPERATOR_NAME.equalsIgnoreCase(sqlCall.getOperator().getName())) { + if (sqlCall.getOperandList().get(0) instanceof SqlIdentifier) { + SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlCall.getOperandList().get(0); + if (sqlIdentifier.names.size() == 2) { + Matcher matcher = MAP_STRUCT_ACCESS_PATTERN.matcher(sqlIdentifier.names.get(0)); + return matcher.find(); + } + } + } + return false; + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlCall.getOperandList().get(0); + String[] names = sqlIdentifier.names.get(0).split("\\["); + String newName = String.format(ELEMENT_AT, names[0], names[1].substring(0, names[1].length() - 1)); + sqlIdentifier.names = ImmutableList.of(newName, sqlIdentifier.names.get(1)); + return sqlCall; + } +} diff --git 
a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/ModOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/ModOperatorTransformer.java new file mode 100644 index 000000000..de5da85c5 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/ModOperatorTransformer.java @@ -0,0 +1,57 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer; + +import static com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil.*; + + +/** + * This class transforms a Coral SqlCall of "pmod" operator with 2 operands into a Trino SqlCall of an operator + * named "mod" + */ +public class ModOperatorTransformer extends OperatorBasedSqlCallTransformer { + private static final String FROM_OPERATOR_NAME = "pmod"; + private static final int OPERAND_NUM = 2; + private static final SqlOperator TARGET_OPERATOR = + createSqlUDF("mod", hiveToCoralSqlOperator(FROM_OPERATOR_NAME).getReturnTypeInference()); + + public ModOperatorTransformer() { + super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List sourceOperands = sqlCall.getOperandList(); + List newOperands = transformOperands(sourceOperands); + return TARGET_OPERATOR.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO)); + } + + private List 
transformOperands(List sourceOperands) { + List newTopLevelOperands = new ArrayList<>(); + + SqlNode modOpSqlNode = SqlStdOperatorTable.MOD.createCall(new SqlNodeList(sourceOperands, SqlParserPos.ZERO)); + List operandsOfPlusOp = new ArrayList<>(); + operandsOfPlusOp.add(modOpSqlNode); + operandsOfPlusOp.add(sourceOperands.get(1)); + SqlNode plusOpSqlNode = SqlStdOperatorTable.PLUS.createCall(new SqlNodeList(operandsOfPlusOp, SqlParserPos.ZERO)); + + newTopLevelOperands.add(plusOpSqlNode); + newTopLevelOperands.add(sourceOperands.get(1)); + return newTopLevelOperands; + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RandomIntegerOperatorWithTwoOperandsTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RandomIntegerOperatorWithTwoOperandsTransformer.java new file mode 100644 index 000000000..887d33e4b --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RandomIntegerOperatorWithTwoOperandsTransformer.java @@ -0,0 +1,42 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer; + +import static com.linkedin.coral.common.calcite.CalciteUtil.*; + + +/** + * This class transforms a Coral SqlCall of "RAND_INTEGER" operator with 2 operands into a Trino SqlCall of an operator + * named "RANDOM" + */ +public class RandomIntegerOperatorWithTwoOperandsTransformer extends OperatorBasedSqlCallTransformer { + private static final String FROM_OPERATOR_NAME = "RAND_INTEGER"; + private static final int OPERAND_NUM = 2; + private static final SqlOperator TARGET_OPERATOR = + createSqlUDF("RANDOM", SqlStdOperatorTable.RAND_INTEGER.getReturnTypeInference()); + + public RandomIntegerOperatorWithTwoOperandsTransformer() { + super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List newOperands = new ArrayList<>(); + newOperands.add(sqlCall.getOperandList().get(1)); + return createCall(TARGET_OPERATOR, newOperands, SqlParserPos.ZERO); + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RandomOperatorWithOneOperandTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RandomOperatorWithOneOperandTransformer.java new file mode 100644 index 000000000..a01392774 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RandomOperatorWithOneOperandTransformer.java @@ -0,0 +1,38 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.coral.trino.rel2trino.transfomers; + +import java.util.Collections; + +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer; + +import static com.linkedin.coral.common.calcite.CalciteUtil.*; + + +/** + * This class transforms a Coral SqlCall of "RAND" operator with 1 operand into a Trino SqlCall of an operator + * named "RANDOM" + */ +public class RandomOperatorWithOneOperandTransformer extends OperatorBasedSqlCallTransformer { + private static final String FROM_OPERATOR_NAME = "RAND"; + private static final int OPERAND_NUM = 1; + private static final SqlOperator TARGET_OPERATOR = + createSqlUDF("RANDOM", SqlStdOperatorTable.RAND.getReturnTypeInference()); + + public RandomOperatorWithOneOperandTransformer() { + super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + return createCall(TARGET_OPERATOR, Collections.emptyList(), SqlParserPos.ZERO); + } +} diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RegexpExtractOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RegexpExtractOperatorTransformer.java new file mode 100644 index 000000000..dd1852863 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/RegexpExtractOperatorTransformer.java @@ -0,0 +1,58 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */
+package com.linkedin.coral.trino.rel2trino.transfomers;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
+
+import com.linkedin.coral.common.functions.FunctionReturnTypes;
+import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer;
+
+import static com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil.*;
+
+
+/**
+ * This class transforms a Coral SqlCall of "regexp_extract" operator with 3 operands into a Trino SqlCall of an
+ * operator named "regexp_extract" whose second operand (the pattern) is wrapped in a call of
+ * "hive_pattern_to_trino"; the first and third operands are passed through unchanged.
+ */
+public class RegexpExtractOperatorTransformer extends OperatorBasedSqlCallTransformer {
+  private static final String FROM_OPERATOR_NAME = "regexp_extract";
+  private static final int OPERAND_NUM = 3;
+  private static final SqlOperator TARGET_OPERATOR =
+      createSqlUDF("regexp_extract", hiveToCoralSqlOperator("regexp_extract").getReturnTypeInference());
+
+  // Single-string-operand function returning a string, applied to the pattern operand in transform()
+  private static final SqlOperator HIVE_PATTERN_TO_TRINO_OPERATOR =
+      new SqlUserDefinedFunction(new SqlIdentifier("hive_pattern_to_trino", SqlParserPos.ZERO),
+          FunctionReturnTypes.STRING, null, OperandTypes.STRING, null, null);
+
+  public RegexpExtractOperatorTransformer() {
+    super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR);
+  }
+
+  /**
+   * Rebuilds the call as regexp_extract(operand0, hive_pattern_to_trino(operand1), operand2).
+   *
+   * @param sqlCall a call already accepted by the inherited condition, i.e. "regexp_extract" with 3 operands
+   * @return the rewritten call of the target operator
+   */
+  @Override
+  protected SqlCall transform(SqlCall sqlCall) {
+    List<SqlNode> sourceOperands = sqlCall.getOperandList();
+    List<SqlNode> newOperands = new ArrayList<>();
+    newOperands.add(sourceOperands.get(0));
+
+    // Wrap the pattern operand in hive_pattern_to_trino(...)
+    List<SqlNode> hivePatternToTrinoOperands = new ArrayList<>();
+    hivePatternToTrinoOperands.add(sourceOperands.get(1));
+    newOperands
+        .add(HIVE_PATTERN_TO_TRINO_OPERATOR.createCall(new SqlNodeList(hivePatternToTrinoOperands, SqlParserPos.ZERO)));
+
+    newOperands.add(sourceOperands.get(2));
+    return TARGET_OPERATOR.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO));
+  }
+}
diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/ToDateOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/ToDateOperatorTransformer.java
new file mode 100644
index 000000000..408bbaa0f
--- /dev/null
+++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/ToDateOperatorTransformer.java
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.trino.rel2trino.transfomers;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.parser.SqlParserPos;
+
+import com.linkedin.coral.common.transformers.SqlCallTransformer;
+import com.linkedin.coral.trino.rel2trino.utils.CoralToTrinoSqlCallTransformersUtil;
+
+import static com.linkedin.coral.common.calcite.CalciteUtil.*;
+
+
+/**
+ * This class implements the transformation of the "to_date" operator with 1 operand into a Trino SqlCall of an
+ * operator named "date" applied to the source operand wrapped in a call of TIMESTAMP_OPERATOR.
+ * For example, "to_date('2023-01-01')" is transformed into "date(CAST('2023-01-01' AS TIMESTAMP))".
+ * The transformation never applies when {@code avoidTransformToDateUDF} is true.
+ */
+public class ToDateOperatorTransformer extends SqlCallTransformer {
+  private static final String FROM_OPERATOR_NAME = "to_date";
+  private static final String TO_OPERATOR_NAME = "date";
+  private static final int NUM_OPERANDS = 1;
+  private static final SqlOperator TRINO_OPERATOR = createSqlUDF(TO_OPERATOR_NAME,
+      CoralToTrinoSqlCallTransformersUtil.hiveToCoralSqlOperator(FROM_OPERATOR_NAME).getReturnTypeInference());
+
+  private final boolean avoidTransformToDateUDF;
+
+  /**
+   * @param avoidTransformToDateUDF when true, {@link #condition(SqlCall)} always returns false
+   */
+  public ToDateOperatorTransformer(boolean avoidTransformToDateUDF) {
+    this.avoidTransformToDateUDF = avoidTransformToDateUDF;
+  }
+
+  /** Matches "to_date" (case-insensitive) calls with exactly 1 operand, unless the transformation is disabled. */
+  @Override
+  protected boolean condition(SqlCall sqlCall) {
+    return !avoidTransformToDateUDF && FROM_OPERATOR_NAME.equalsIgnoreCase(sqlCall.getOperator().getName())
+        && sqlCall.getOperandList().size() == NUM_OPERANDS;
+  }
+
+  /** Wraps the source operands in a TIMESTAMP_OPERATOR call and passes that call to the "date" operator. */
+  @Override
+  protected SqlCall transform(SqlCall sqlCall) {
+    List<SqlNode> sourceOperands = sqlCall.getOperandList();
+    List<SqlNode> newOperands = new ArrayList<>();
+    SqlNode timestampSqlCall = createCall(TIMESTAMP_OPERATOR, sourceOperands, SqlParserPos.ZERO);
+    newOperands.add(timestampSqlCall);
+    return createCall(TRINO_OPERATOR, newOperands, SqlParserPos.ZERO);
+  }
+}
diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/TruncateOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/TruncateOperatorTransformer.java
new file mode 100644
index 000000000..7119a9c1c
--- /dev/null
+++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transfomers/TruncateOperatorTransformer.java
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.trino.rel2trino.transfomers;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+
+import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer;
+
+
+/**
+ * This class transforms a Coral SqlCall of "TRUNCATE" operator with 2 operands into a Trino SqlCall built around
+ * a single-operand operator named "TRUNCATE":
+ * "TRUNCATE(x, d)" is transformed into "TRUNCATE(x * POWER(10, d)) / POWER(10, d)".
+ */
+public class TruncateOperatorTransformer extends OperatorBasedSqlCallTransformer {
+  private static final String FROM_OPERATOR_NAME = "TRUNCATE";
+  private static final int OPERAND_NUM = 2;
+  private static final SqlOperator TARGET_OPERATOR =
+      createSqlUDF("TRUNCATE", SqlStdOperatorTable.TRUNCATE.getReturnTypeInference());
+
+  public TruncateOperatorTransformer() {
+    super(FROM_OPERATOR_NAME, OPERAND_NUM, TARGET_OPERATOR);
+  }
+
+  /**
+   * Scales the value up, truncates it with the one-operand target operator, then scales the result back down.
+   *
+   * @param sqlCall a call already accepted by the inherited condition, i.e. "TRUNCATE" with 2 operands
+   * @return a DIVIDE call whose operands are the new TRUNCATE call and POWER(10, second source operand)
+   */
+  @Override
+  protected SqlCall transform(SqlCall sqlCall) {
+    List<SqlNode> sourceOperands = sqlCall.getOperandList();
+    List<SqlNode> newOperands = transformOperands(sourceOperands);
+    SqlCall newSqlCall = TARGET_OPERATOR.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO));
+    return transformResult(newSqlCall, sourceOperands);
+  }
+
+  /** Returns the single operand of the new TRUNCATE call: operand0 * POWER(10, operand1). */
+  private List<SqlNode> transformOperands(List<SqlNode> sourceOperands) {
+    SqlNode powerOpSqlNode = createPowerOpSqlNode(sourceOperands);
+
+    List<SqlNode> operandsOfMultiplyOp = new ArrayList<>();
+    operandsOfMultiplyOp.add(sourceOperands.get(0));
+    operandsOfMultiplyOp.add(powerOpSqlNode);
+    SqlNode multiplyOpSqlNode =
+        SqlStdOperatorTable.MULTIPLY.createCall(new SqlNodeList(operandsOfMultiplyOp, SqlParserPos.ZERO));
+
+    List<SqlNode> topLevelOperands = new ArrayList<>();
+    topLevelOperands.add(multiplyOpSqlNode);
+    return topLevelOperands;
+  }
+
+  /** Divides {@code result} by POWER(10, operand1) to undo the scaling applied by transformOperands. */
+  private SqlCall transformResult(SqlNode result, List<SqlNode> sourceOperands) {
+    List<SqlNode> newOperands = new ArrayList<>();
+    newOperands.add(result);
+    // A fresh POWER node is created here rather than sharing the one used inside the TRUNCATE operand
+    SqlNode powerOpSqlNode = createPowerOpSqlNode(sourceOperands);
+    newOperands.add(powerOpSqlNode);
+    return SqlStdOperatorTable.DIVIDE.createCall(new SqlNodeList(newOperands, SqlParserPos.ZERO));
+  }
+
+  /** Creates the call POWER(10, operand1), where operand1 is the second source operand. */
+  private SqlCall createPowerOpSqlNode(List<SqlNode> sourceOperands) {
+    List<SqlNode> operandsOfPowerOp = new ArrayList<>();
+    operandsOfPowerOp.add(SqlLiteral.createExactNumeric("10", SqlParserPos.ZERO));
+    operandsOfPowerOp.add(sourceOperands.get(1));
+    return SqlStdOperatorTable.POWER.createCall(new SqlNodeList(operandsOfPowerOp, SqlParserPos.ZERO));
+  }
+}
diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/utils/CoralToTrinoSqlCallTransformersUtil.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/utils/CoralToTrinoSqlCallTransformersUtil.java
new file mode 100644
index 000000000..e5c31f6df
--- /dev/null
+++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/utils/CoralToTrinoSqlCallTransformersUtil.java
@@ -0,0 +1,190 @@
+/**
+ * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.trino.rel2trino.utils;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+
+import com.google.common.collect.ImmutableList;
+
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+
+import com.linkedin.coral.com.google.common.base.CaseFormat;
+import com.linkedin.coral.com.google.common.base.Converter;
+import com.linkedin.coral.com.google.common.collect.ImmutableMultimap;
+import com.linkedin.coral.common.functions.Function;
+import com.linkedin.coral.common.transformers.OperatorBasedSqlCallTransformer;
+import com.linkedin.coral.common.transformers.SqlCallTransformer;
+import com.linkedin.coral.common.transformers.SqlCallTransformers;
+import com.linkedin.coral.hive.hive2rel.functions.HiveRLikeOperator;
+import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry;
+import com.linkedin.coral.trino.rel2trino.functions.TrinoElementAtFunction;
+import com.linkedin.coral.trino.rel2trino.transfomers.DateAddOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.DateDiffOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.DateSubOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.DecodeOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.MapStructAccessOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.ModOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.RandomIntegerOperatorWithTwoOperandsTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.RandomOperatorWithOneOperandTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.RegexpExtractOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.ToDateOperatorTransformer;
+import com.linkedin.coral.trino.rel2trino.transfomers.TruncateOperatorTransformer;
+
+import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
+
+
+/**
+ * This utility class initializes a list of SqlCallTransformers which convert the function operators defined in
+ * SqlCalls from Coral to Trino on SqlNode layer
+ */
+public final class CoralToTrinoSqlCallTransformersUtil {
+
+  private static final StaticHiveFunctionRegistry HIVE_FUNCTION_REGISTRY = new StaticHiveFunctionRegistry();
+  // Configuration-independent transformers, populated once by the static initializer below
+  private static final List<SqlCallTransformer> DEFAULT_SQL_CALL_TRANSFORMER_LIST = new ArrayList<>();
+
+  static {
+    addCommonSignatureBasedConditionTransformers();
+    addAdHocTransformers();
+    addLinkedInFunctionTransformers();
+  }
+
+  /**
+   * Returns the default transformers followed by the transformers whose behavior depends on {@code configs}.
+   *
+   * @param configs configuration flags; AVOID_TRANSFORM_TO_DATE_UDF is read with a default of false
+   * @return a SqlCallTransformers built from an immutable copy of the resulting list
+   */
+  public static SqlCallTransformers getTransformers(Map<String, Boolean> configs) {
+    List<SqlCallTransformer> sqlCallTransformerList = new ArrayList<>(DEFAULT_SQL_CALL_TRANSFORMER_LIST);
+    // initialize SqlCallTransformer affected by the configuration and add them to the list
+    sqlCallTransformerList.add(new ToDateOperatorTransformer(configs.getOrDefault(AVOID_TRANSFORM_TO_DATE_UDF, false)));
+    return SqlCallTransformers.of(ImmutableList.copyOf(sqlCallTransformerList));
+  }
+
+  /**
+   * Looks up {@code functionName} in the static Hive function registry and returns the SqlOperator of the first
+   * function found.
+   *
+   * @param functionName name under which the function is registered
+   * @return the SqlOperator of the first registered function with that name
+   * @throws NoSuchElementException if no function is registered under {@code functionName}
+   */
+  public static SqlOperator hiveToCoralSqlOperator(String functionName) {
+    Collection<Function> lookup = HIVE_FUNCTION_REGISTRY.lookup(functionName);
+    if (lookup.isEmpty()) {
+      // Same exception type that iterator().next() would raise, but with the offending name in the message
+      throw new NoSuchElementException("No function found in the Hive function registry for name: " + functionName);
+    }
+    // TODO: provide overloaded function resolution
+    return lookup.iterator().next().getSqlOperator();
+  }
+
+  private static void addCommonSignatureBasedConditionTransformers() {
+    // conditional functions
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("nvl"), 2, "coalesce"));
+    // Array and map functions
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(
+        new OperatorBasedSqlCallTransformer(SqlStdOperatorTable.ITEM.getName(), 2, TrinoElementAtFunction.INSTANCE));
+
+    // Math Functions
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new OperatorBasedSqlCallTransformer(SqlStdOperatorTable.RAND, 0, "RANDOM"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new RandomOperatorWithOneOperandTransformer());
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(SqlStdOperatorTable.RAND_INTEGER, 1, "RANDOM"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new RandomIntegerOperatorWithTwoOperandsTransformer());
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new TruncateOperatorTransformer());
+
+    // String Functions
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(SqlStdOperatorTable.SUBSTRING, 2, "SUBSTR"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(SqlStdOperatorTable.SUBSTRING, 3, "SUBSTR"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(HiveRLikeOperator.RLIKE, 2, "REGEXP_LIKE"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(HiveRLikeOperator.REGEXP, 2, "REGEXP_LIKE"));
+
+    // JSON Functions
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("get_json_object"), 2, "json_extract"));
+
+    // map various hive functions
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new ModOperatorTransformer());
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("base64"), 1, "to_base64"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("unbase64"), 1, "from_base64"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("hex"), 1, "to_hex"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("unhex"), 1, "from_hex"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("array_contains"), 2, "contains"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new RegexpExtractOperatorTransformer());
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST
+        .add(new OperatorBasedSqlCallTransformer(hiveToCoralSqlOperator("instr"), 2, "strpos"));
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new DecodeOperatorTransformer());
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new DateAddOperatorTransformer());
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new DateSubOperatorTransformer());
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new DateDiffOperatorTransformer());
+  }
+
+  private static void addLinkedInFunctionTransformers() {
+    // Most "com.linkedin..." UDFs follow convention of having UDF names mapped from camel-cased name to snake-cased name.
+    // For example: For class name IsGuestMemberId, the conventional udf name would be is_guest_member_id.
+    // While this convention fits most UDFs it doesn't fit all. With the following mapping we override the conventional
+    // UDF name mapping behavior to a hardcoded one.
+    // For example instead of UserAgentParser getting mapped to user_agent_parser, we mapped it here to useragentparser
+    Set<String> linkedInFunctionSignatureSet = new HashSet<>();
+    addLinkedInFunctionTransformer("com.linkedin.dali.udf.watbotcrawlerlookup.hive.WATBotCrawlerLookup", 3,
+        "wat_bot_crawler_lookup", linkedInFunctionSignatureSet);
+    addLinkedInFunctionTransformer("com.linkedin.stdudfs.parsing.hive.Ip2Str", 1, "ip2str",
+        linkedInFunctionSignatureSet);
+    addLinkedInFunctionTransformer("com.linkedin.stdudfs.parsing.hive.Ip2Str", 3, "ip2str",
+        linkedInFunctionSignatureSet);
+    addLinkedInFunctionTransformer("com.linkedin.stdudfs.parsing.hive.UserAgentParser", 2, "useragentparser",
+        linkedInFunctionSignatureSet);
+    addLinkedInFunctionTransformer("com.linkedin.stdudfs.lookup.hive.BrowserLookup", 3, "browserlookup",
+        linkedInFunctionSignatureSet);
+    addLinkedInFunctionTransformer("com.linkedin.jobs.udf.hive.ConvertIndustryCode", 1, "converttoindustryv1",
+        linkedInFunctionSignatureSet);
+    addLinkedInFunctionTransformer("com.linkedin.stdudfs.urnextractor.hive.UrnExtractorFunctionWrapper", 1,
+        "urn_extractor", linkedInFunctionSignatureSet);
+    addLinkedInFunctionTransformer("com.linkedin.stdudfs.hive.daliudfs.UrnExtractorFunctionWrapper", 1, "urn_extractor",
+        linkedInFunctionSignatureSet);
+
+    addLinkedInFunctionTransformerFromHiveRegistry(DEFAULT_SQL_CALL_TRANSFORMER_LIST, linkedInFunctionSignatureSet);
+  }
+
+  /**
+   * Registers a hardcoded LinkedIn UDF mapping and records its "name_numOperands" signature in
+   * {@code linkedInFunctionSignatureSet} so that the registry-driven pass does not add a second mapping for it.
+   */
+  private static void addLinkedInFunctionTransformer(String linkedInFuncName, int numOperands, String trinoFuncName,
+      Set<String> linkedInFunctionSignatureSet) {
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new OperatorBasedSqlCallTransformer(
+        linkedInFunctionToCoralSqlOperator(linkedInFuncName), numOperands, trinoFuncName));
+    linkedInFunctionSignatureSet.add(linkedInFuncName + "_" + numOperands);
+  }
+
+  private static void addAdHocTransformers() {
+    DEFAULT_SQL_CALL_TRANSFORMER_LIST.add(new MapStructAccessOperatorTransformer());
+  }
+
+  private static SqlOperator linkedInFunctionToCoralSqlOperator(String className) {
+    // Same registry lookup as for Hive built-ins; the class name is the registry key
+    return hiveToCoralSqlOperator(className);
+  }
+
+  /**
+   * Adds a conventionally named (UpperCamel class name to lower_underscore) transformer for every
+   * "com.linkedin" function in the registry and every operand count in its range, skipping the signatures
+   * already present in {@code linkedInFunctionSignatureSet}.
+   */
+  private static void addLinkedInFunctionTransformerFromHiveRegistry(List<SqlCallTransformer> sqlCallTransformerList,
+      Set<String> linkedInFunctionSignatureSet) {
+    ImmutableMultimap<String, Function> registry = HIVE_FUNCTION_REGISTRY.getRegistry();
+    Converter<String, String> caseConverter = CaseFormat.UPPER_CAMEL.converterTo(CaseFormat.LOWER_UNDERSCORE);
+    for (Map.Entry<String, Function> entry : registry.entries()) {
+      // we cannot use entry.getKey() as function name directly, because keys are all lowercase, which will
+      // fail to be converted to lowercase with underscore correctly
+      final String hiveFunctionName = entry.getValue().getFunctionName();
+      if (!hiveFunctionName.startsWith("com.linkedin")) {
+        continue;
+      }
+      String[] nameSplit = hiveFunctionName.split("\\.");
+      // filter above guarantees we've at least 2 entries
+      String className = nameSplit[nameSplit.length - 1];
+      String funcName = caseConverter.convert(className);
+      SqlOperator op = entry.getValue().getSqlOperator();
+      // NOTE(review): if getMax() reports an unbounded range as a negative value, this loop adds nothing for
+      // variadic functions - confirm against the Calcite SqlOperandCountRange contract.
+      for (int i = op.getOperandCountRange().getMin(); i <= op.getOperandCountRange().getMax(); i++) {
+        if (!linkedInFunctionSignatureSet.contains(hiveFunctionName + "_" + i)) {
+          sqlCallTransformerList.add(new OperatorBasedSqlCallTransformer(op, i, funcName));
+        }
+      }
+    }
+  }
+}
diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/UDFTransformerTest.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/UDFTransformerTest.java
deleted file mode 100644
index d7267b28a..000000000
--- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/UDFTransformerTest.java
+++ /dev/null
@@ -1,244 +0,0 @@
-/**
- * Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
- * Licensed under the BSD-2 Clause license.
- * See LICENSE in the project root for license information.
- */
-package com.linkedin.coral.trino.rel2trino;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import com.google.gson.JsonSyntaxException;
-
-import org.apache.calcite.config.NullCollation;
-import org.apache.calcite.rel.core.Project;
-import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
-import org.apache.calcite.rel.rel2sql.SqlImplementor;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.sql.SqlDialect;
-import org.apache.calcite.sql.SqlNode;
-import org.apache.calcite.sql.SqlOperator;
-import org.apache.calcite.sql.type.ReturnTypes;
-import org.apache.calcite.tools.FrameworkConfig;
-import org.testng.annotations.BeforeTest;
-import org.testng.annotations.Test;
-
-import static com.linkedin.coral.trino.rel2trino.TestTable.*;
-import static org.testng.Assert.*;
-
-
-public class UDFTransformerTest {
-  static FrameworkConfig tableOneConfig;
-  static String tableOneQuery;
-  static FrameworkConfig tableThreeConfig;
-  static String tableThreeQuery;
-  static final SqlOperator targetUDF = UDFMapUtils.createUDF("targetFunc", ReturnTypes.DOUBLE);
-  static final SqlOperator defaultUDF = UDFMapUtils.createUDF("", ReturnTypes.DOUBLE);
-  static
final SqlDialect sqlDialect = - new SqlDialect(SqlDialect.DatabaseProduct.CALCITE, "Calcite", "", NullCollation.HIGH); - - @BeforeTest - public static void beforeTest() { - tableOneConfig = TestUtils.createFrameworkConfig(TABLE_ONE); - tableThreeConfig = TestUtils.createFrameworkConfig(TABLE_THREE); - tableOneQuery = "select scol, icol, dcol from " + TABLE_ONE.getTableName(); - tableThreeQuery = "select binaryfield, 'literal' from " + TABLE_THREE.getTableName(); - } - - private SqlNode applyTransformation(Project project, SqlOperator operator, String operandTransformer, - String resultTransformer, String operatorTransformer) { - UDFTransformer udfTransformer = - UDFTransformer.of("", operator, operandTransformer, resultTransformer, operatorTransformer); - RexBuilder rexBuilder = project.getCluster().getRexBuilder(); - List sourceOperands = new ArrayList<>(); - List projectOperands = project.getChildExps(); - for (int i = 0; i < projectOperands.size(); ++i) { - // If the parameter is a reference to a column, we make it a RexInputRef. - // We make a new reference because the RexBuilder refers to a column based on the output of the projection. - // We cannot pass the the source operands from the Project operator directly because they are references to the - // columns of the table. - // We need to create a new input reference to each projection output because it is the new input to the UDF. - // Unlike input references, other primitives such as RexLiteral can be added as a sourceOperand directly since - // they are not normally projected as outputs and are not usually input references. 
- if (projectOperands.get(i) instanceof RexInputRef) { - sourceOperands.add(rexBuilder.makeInputRef(project, i)); - } else { - sourceOperands.add(projectOperands.get(i)); - } - } - RelToSqlConverter sqlConverter = new RelToSqlConverter(sqlDialect); - SqlImplementor.Result result = sqlConverter.visit(project); - RexNode rexNode = udfTransformer.transformCall(rexBuilder, sourceOperands); - return result.builder(project, SqlImplementor.Clause.SELECT).context.toSql(null, rexNode); - } - - private void testTransformation(String query, FrameworkConfig config, SqlOperator operator, String operandTransformer, - String resultTransformer, String operatorTransformer, String expected) { - Project project = (Project) TestUtils.toRel(query, config); - SqlNode result = applyTransformation(project, operator, operandTransformer, resultTransformer, operatorTransformer); - assertEquals(result.toSqlString(sqlDialect).getSql(), expected); - } - - private void testFailedTransformation(String query, FrameworkConfig config, SqlOperator operator, - String operandTransformer, String resultTransformer, String operatorTransformer, Class exceptionClass) { - try { - Project project = (Project) TestUtils.toRel(query, config); - applyTransformation(project, operator, operandTransformer, resultTransformer, operatorTransformer); - fail(); - } catch (Exception e) { - assertTrue(exceptionClass.isInstance(e)); - } - } - - @Test - public void testWrongOperandSyntax() { - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "", null, null, IllegalStateException.class); - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "{}", null, null, IllegalStateException.class); - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{input}]", null, null, - JsonSyntaxException.class); - } - - @Test - public void testWrongResultSyntax() { - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, null, "", null, IllegalStateException.class); - 
testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, null, "[]", null, IllegalStateException.class); - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, null, "{", null, JsonSyntaxException.class); - } - - @Test - public void testInvalidInput() { - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"input\":0}]", null, null, - IllegalArgumentException.class); - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"input\":4}]", null, null, - IllegalArgumentException.class); - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"input\":-1}]", null, null, - IllegalArgumentException.class); - } - - @Test - public void testInvalidJsonNode() { - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"dummy\":0}]", null, null, - IllegalArgumentException.class); - } - - @Test - public void testLiteral() { - testTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"value\":'astring'}]", null, null, - "targetFunc('astring')"); - testTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"value\":true}]", null, null, "targetFunc(TRUE)"); - testTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"value\":5}]", null, null, "targetFunc(5)"); - testTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"value\":2.5}]", null, null, "targetFunc(2.5)"); - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"value\":[1, 2]}]", null, null, - IllegalArgumentException.class); - } - - @Test - public void testResizeTransformation() { - testTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"input\":1}, {\"input\":3}]", null, null, - "targetFunc(scol, dcol)"); - testTransformation(tableOneQuery, tableOneConfig, targetUDF, "[]", null, null, "targetFunc()"); - testTransformation(tableOneQuery, tableOneConfig, targetUDF, - "[{\"input\":2}, {\"input\":1}, {\"input\":2}, {\"input\":3}]", null, null, - "targetFunc(icol, scol, icol, 
dcol)"); - } - - @Test - public void testIdentityTransformation() { - testTransformation(tableOneQuery, tableOneConfig, targetUDF, null, null, null, "targetFunc(scol, icol, dcol)"); - testTransformation(tableOneQuery, tableOneConfig, targetUDF, "[{\"input\":1}, {\"input\":2}, {\"input\":3}]", - "{\"input\":0}", null, "targetFunc(scol, icol, dcol)"); - } - - @Test - public void testNormalTransformation() { - testTransformation(tableOneQuery, tableOneConfig, targetUDF, - "[{\"op\":\"*\",\"operands\":[{\"input\":2},{\"input\":3}]}, {\"input\":1}]", null, null, - "targetFunc(icol * dcol, scol)"); - testTransformation(tableOneQuery, tableOneConfig, targetUDF, - "[{\"op\":\"*\",\"operands\":[{\"input\":2},{\"input\":3}]}, {\"input\":1}]", - "{\"op\":\"+\",\"operands\":[{\"input\":0},{\"input\":3}]}", null, "targetFunc(icol * dcol, scol) + dcol"); - - testFailedTransformation(tableOneQuery, tableOneConfig, targetUDF, - "[{\"op\":\"@\",\"operands\":[{\"input\":2},{\"input\":3}]}, {\"input\":1}]", null, null, - UnsupportedOperationException.class); - } - - @Test - public void testInputOperatorTransformer() { - // Verifies that an operator transformer that has an exact match uses its target function. - testTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"'literal'\", \"input\":2, \"name\":\"newFunc\"}]", "newFunc(binaryfield, 'literal')"); - - // Verifies that an operator transformer that has a substring match uses its target function. - testTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(liter).*$\", \"input\":2, \"name\":\"newFunc\"}]", "newFunc(binaryfield, 'literal')"); - - // Verifies that an operator transformer that has no match uses the default trivial function and throws error. 
- testFailedTransformation(tableThreeQuery, tableThreeConfig, defaultUDF, null, null, - "[{\"regex\":\"^.*(?i)(noMatch).*$\", \"input\":2, \"name\":\"newFunc\"}]", IllegalArgumentException.class); - } - - @Test - public void testMultipleOperatorTransformers() { - // Verifies that all operator transformers in the operator transformers array are tested. - testTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(noMatch).*$\", \"input\":2, \"name\":\"unmatchFunc\"}," - + "{\"regex\":\"^.*(?i)(literal).*$\", \"input\":2, \"name\":\"matchFunc\"}]", - "matchFunc(binaryfield, 'literal')"); - - // Verifies that the first target function matcher to match is selected has its target function selected. - testTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(literal).*$\", \"input\":2, \"name\":\"firstFunc\"}," - + "{\"regex\":\"^.*(?i)(literal).*$\", \"input\":2, \"name\":\"secondFunc\"}]", - "firstFunc(binaryfield, 'literal')"); - } - - @Test - public void testNoMatchOperatorTransformer() { - // Verifies that a target function with no default target UDF throws an error. - testFailedTransformation(tableThreeQuery, tableThreeConfig, null, null, null, - "[{\"regex\":\"^.*(?i)(noMatch).*$\", \"input\":2, \"name\":\"newFunc\"}]", IllegalArgumentException.class); - } - - @Test - public void testInvalidArgumentsOperatorTransformer() { - // Verifies that an empty function name throws an error. - testFailedTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(literal).*$\", \"input\":2, \"name\":\"\"}]", IllegalArgumentException.class); - - // Verifies that an input parameter position <= 0 throws an error. 
- testFailedTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(literal).*$\", \"input\":0, \"name\":\"newFunc\"}]", IllegalArgumentException.class); - - // Verifies that an input parameter position > the input size throws an error. - testFailedTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(literal).*$\", \"input\":3, \"name\":\"newFunc\"}]", IllegalArgumentException.class); - } - - @Test - public void testSufficientArgumentsOperatorTransformer() { - // Verifies that an operator transformer that does not define a matcher throws an error. - testFailedTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"input\":2, \"name\":\"newFunc\"}]", IllegalArgumentException.class); - - // Verifies that an operator transformer that does not define a parameter position throws an error. - testFailedTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(literal).*$\", \"name\":\"newFunc\"}]", IllegalArgumentException.class); - - // Verifies that an operator transformer that does not define a replacement function name throws an error. - testFailedTransformation(tableThreeQuery, tableThreeConfig, targetUDF, null, null, - "[{\"regex\":\"^.*(?i)(literal).*$\", \"input\":2}]", IllegalArgumentException.class); - } - - @Test - public void testOperandTransformerAndOperatorTransformer() { - // Verifies that an operator transformer used in conjunction with an operand transformer works correctly. - testTransformation(tableThreeQuery, tableThreeConfig, targetUDF, "[{\"input\":1}]", null, - "[{\"regex\":\"'literal'\", \"input\":2, \"name\":\"newFunc\"}]", "newFunc(binaryfield)"); - testTransformation(tableThreeQuery, tableThreeConfig, targetUDF, "[{\"input\":2}]", null, - "[{\"regex\":\"'literal'\", \"input\":2, \"name\":\"newFunc\"}]", "newFunc('literal')"); - } - -}