diff --git a/coral-common/src/main/java/com/linkedin/coral/common/functions/FunctionFieldReferenceOperator.java b/coral-common/src/main/java/com/linkedin/coral/common/functions/FunctionFieldReferenceOperator.java deleted file mode 100644 index f947da078..000000000 --- a/coral-common/src/main/java/com/linkedin/coral/common/functions/FunctionFieldReferenceOperator.java +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. - */ -package com.linkedin.coral.common.functions; - -import com.google.common.base.Preconditions; - -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.sql.SqlBasicCall; -import org.apache.calcite.sql.SqlBinaryOperator; -import org.apache.calcite.sql.SqlCall; -import org.apache.calcite.sql.SqlCharStringLiteral; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlLiteral; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlWriter; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.util.SqlBasicVisitor; -import org.apache.calcite.sql.util.SqlVisitor; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql.validate.SqlValidatorScope; - - -/** - * Operator to reference fields of structs returned by SQL functions. - * This supports following SQL: - * {@code - * SELECT f(col_1, col_2).field_a FROM myTable - * } - * where {@code f} is a function that returns a ROW type containing {@code field_a}. - * - * TODO: Fix calcite and fold this into Calcite DOT operator - * - */ -public class FunctionFieldReferenceOperator extends SqlBinaryOperator { - public static final FunctionFieldReferenceOperator DOT = new FunctionFieldReferenceOperator(); - - public FunctionFieldReferenceOperator() { - super(".", SqlKind.DOT, 80, true, null, null, OperandTypes.ANY_ANY); - } - - @Override - public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { - Preconditions.checkState(operands.length == 2); - SqlCharStringLiteral fieldName = SqlLiteral.createCharString(fieldNameStripQuotes(operands[1]), SqlParserPos.ZERO); - return super.createCall(functionQualifier, pos, operands[0], fieldName); - } - - @Override - public void acceptCall(SqlVisitor visitor, SqlCall call, boolean onlyExpressions, - SqlBasicVisitor.ArgHandler argHandler) { - argHandler.visitChild(visitor, call, 0, call.operand(0)); - } - - @Override - public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { - call.operand(0).unparse(writer, getLeftPrec(), getRightPrec()); - writer.literal("."); - writer.setNeedWhitespace(false); - // strip quotes from fieldName - String fieldName = fieldNameStripQuotes(call.operand(1)); - writer.identifier(fieldName, true); - } - - @Override - public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { - SqlNode firstOperand = call.operand(0); - if (firstOperand instanceof SqlBasicCall) { - RelDataType funcType = validator.deriveType(scope, firstOperand); - if (funcType.isStruct()) { - return funcType.getField(fieldNameStripQuotes(call.operand(1)), false, false).getType(); - } - } - return super.deriveType(validator, scope, call); - } - - @Override - public void validateCall(SqlCall call, SqlValidator validator, SqlValidatorScope scope, - SqlValidatorScope operandScope) { - call.operand(0).validateExpr(validator, operandScope); - } - - public static String fieldNameStripQuotes(SqlNode node) { - return stripQuotes(fieldName(node)); - } - - public static String fieldName(SqlNode node) { - switch (node.getKind()) { - case IDENTIFIER: - return ((SqlIdentifier) node).getSimple(); - case LITERAL: - return ((SqlLiteral) node).toValue(); - default: - throw new IllegalStateException( - String.format("Unknown operand type %s to reference a field, operand: %s", node.getKind(), node)); - } - } - - private static String stripQuotes(String id) { - if ((id.startsWith("'") && id.endsWith("'")) || (id.startsWith("\"") && id.endsWith("\""))) { - return id.substring(1, id.length() - 1); - } - return id; - } - -} diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/CoralConvertletTable.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/CoralConvertletTable.java index 774e81d0f..6ed200cba 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/CoralConvertletTable.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/CoralConvertletTable.java @@ -16,8 +16,6 @@ import org.apache.calcite.sql2rel.SqlRexConvertlet; import org.apache.calcite.sql2rel.StandardConvertletTable; -import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator; - /** * ConvertletTable for transformations only relevant to Coral's Intermediate Representation, not specific @@ -26,14 +24,6 @@ */ public class CoralConvertletTable extends ReflectiveConvertletTable { - @SuppressWarnings("unused") - public RexNode convertFunctionFieldReferenceOperator(SqlRexContext cx, FunctionFieldReferenceOperator op, - SqlCall call) { - RexNode funcExpr = cx.convertExpression(call.operand(0)); - String fieldName = FunctionFieldReferenceOperator.fieldNameStripQuotes(call.operand(1)); - return cx.getRexBuilder().makeFieldAccess(funcExpr, fieldName, false); - } - /** * Override {@link StandardConvertletTable#convertCast} to avoid cast optimizations that remove the cast. */ diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlValidator.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlValidator.java index 66ce83bb5..3736c268c 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlValidator.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveSqlValidator.java @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2017-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ @@ -9,7 +9,6 @@ import org.apache.calcite.config.NullCollation; import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlInsert; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperatorTable; @@ -19,8 +18,6 @@ import org.apache.calcite.sql.validate.SqlValidatorImpl; import org.apache.calcite.sql.validate.SqlValidatorScope; -import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator; - public class HiveSqlValidator extends SqlValidatorImpl { @@ -51,15 +48,4 @@ protected void inferUnknownTypes(RelDataType inferredType, SqlValidatorScope sco super.inferUnknownTypes(inferredType, scope, node); } - @Override - public SqlNode expand(SqlNode expr, SqlValidatorScope scope) { - if (expr instanceof SqlBasicCall - && ((SqlBasicCall) expr).getOperator().equals(FunctionFieldReferenceOperator.DOT)) { - SqlBasicCall dotCall = (SqlBasicCall) expr; - if (dotCall.operand(0) instanceof SqlBasicCall) { - return expr; - } - } - return super.expand(expr, scope); - } } diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java index 8bf1ac4a0..079442f36 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java @@ -48,7 +48,6 @@ import com.linkedin.coral.com.google.common.collect.Iterables; import com.linkedin.coral.common.functions.CoralSqlUnnestOperator; import com.linkedin.coral.common.functions.Function; -import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator; import com.linkedin.coral.hive.hive2rel.functions.HiveFunctionResolver; import com.linkedin.coral.hive.hive2rel.functions.HiveJsonTupleOperator; import com.linkedin.coral.hive.hive2rel.functions.HiveRLikeOperator; @@ -523,7 +522,7 @@ protected SqlNode visitDotOperator(ASTNode node, ParseContext ctx) { Iterable names = Iterables.concat(left.names, right.names); return new SqlIdentifier(ImmutableList.copyOf(names), ZERO); } else { - return FunctionFieldReferenceOperator.DOT.createCall(ZERO, sqlNodes); + return SqlStdOperatorTable.DOT.createCall(ZERO, sqlNodes); } } diff --git a/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java b/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java index 111bfcaa7..4e4a2a2ec 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java +++ b/coral-hive/src/main/java/com/linkedin/coral/transformers/CoralRelToSqlNodeConverter.java @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2017-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ @@ -49,7 +49,6 @@ import com.linkedin.coral.com.google.common.collect.ImmutableList; import com.linkedin.coral.com.google.common.collect.ImmutableMap; import com.linkedin.coral.common.functions.CoralSqlUnnestOperator; -import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator; /** @@ -474,7 +473,7 @@ private SqlNode generateRightChildForSqlJoinWithLateralViews(BiRel e, Result rig * Calcite converts it to a {@link SqlIdentifier} with {@link SqlIdentifier#names} as ["f(x)", "y"] where "f(x)" and "y" are String, * which is opaque and not aligned with our expectation, since we want to apply transformations on `f(x)` with * {@link com.linkedin.coral.common.transformers.SqlCallTransformer}. Therefore, we override this - * method to convert `f(x)` to {@link SqlCall} and `.` to {@link com.linkedin.coral.common.functions.FunctionFieldReferenceOperator#DOT}. + * method to convert `f(x)` to {@link SqlCall} and `.` to {@link SqlStdOperatorTable#DOT}. * * With this override, the converted CoralSqlNode matches the previous SqlNode handed over to Calcite for validation and conversion * in `HiveSqlToRelConverter#convertQuery`. @@ -500,7 +499,7 @@ public SqlNode toSql(RexProgram program, RexNode rex) { SqlNode functionCall = toSql(program, referencedExpr); Collections.reverse(accessNames); for (String accessName : accessNames) { - functionCall = FunctionFieldReferenceOperator.DOT.createCall(SqlParserPos.ZERO, functionCall, + functionCall = SqlStdOperatorTable.DOT.createCall(SqlParserPos.ZERO, functionCall, new SqlIdentifier(accessName, POS)); } return functionCall; diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java index 454ea0ac7..a6732e4d5 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java @@ -440,7 +440,7 @@ public void testUnionExtractUDFOnSingleTypeUnions() { RelNode relNode3 = TestUtils.toRelNode("SELECT extract_union(baz).single.tag_0 from union_table"); String targetSql4 = - "SELECT (coalesce_struct(union_table.baz, 'struct>>').single).tag_0\n" + "SELECT coalesce_struct(union_table.baz, 'struct>>').single.tag_0\n" + "FROM default.union_table union_table"; assertEquals(createCoralSpark(relNode3).getSparkSql(), targetSql4); } @@ -876,9 +876,17 @@ public void testSelectStar() { @Test public void testConvertFieldAccessOnFunctionCall() { + RelNode relNode = TestUtils.toRelNode("SELECT named_struct('a', 1).a"); + + String targetSql = "SELECT named_struct('a', 1).a\n" + "FROM (VALUES (0)) t (ZERO)"; + assertEquals(createCoralSpark(relNode).getSparkSql(), targetSql); + } + + @Test + public void testConvertNestedFieldAccessOnFunctionCall() { RelNode relNode = TestUtils.toRelNode("SELECT named_struct('a', named_struct('b', 1)).a.b"); - String targetSql = "SELECT (named_struct('a', named_struct('b', 1)).a).b\n" + "FROM (VALUES (0)) t (ZERO)"; + String targetSql = "SELECT named_struct('a', named_struct('b', 1)).a.b\n" + "FROM (VALUES (0)) t (ZERO)"; assertEquals(createCoralSpark(relNode).getSparkSql(), targetSql); } diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java index a2b3c6145..723487853 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/RelToTrinoConverter.java @@ -31,7 +31,6 @@ import com.linkedin.coral.com.google.common.collect.ImmutableList; import com.linkedin.coral.common.HiveMetastoreClient; import com.linkedin.coral.common.functions.CoralSqlUnnestOperator; -import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator; import com.linkedin.coral.transformers.CoralRelToSqlNodeConverter; import static com.google.common.base.Preconditions.*; @@ -441,7 +440,7 @@ public SqlNode toSql(RexProgram program, RexNode rex) { SqlNode functionCall = toSql(program, referencedExpr); Collections.reverse(accessNames); for (String accessName : accessNames) { - functionCall = FunctionFieldReferenceOperator.DOT.createCall(SqlParserPos.ZERO, functionCall, + functionCall = SqlStdOperatorTable.DOT.createCall(SqlParserPos.ZERO, functionCall, new SqlIdentifier(accessName, POS)); } return functionCall;