Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public abstract class ToRelConverter {

protected abstract SqlRexConvertletTable getConvertletTable();

protected abstract SqlValidator getSqlValidator();
public abstract SqlValidator getSqlValidator();

protected abstract SqlOperatorTable getOperatorTable();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.hive.hive2rel;

import java.util.ArrayList;
import java.util.List;

import com.google.common.base.Preconditions;
Expand All @@ -17,7 +16,6 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql2rel.ReflectiveConvertletTable;
import org.apache.calcite.sql2rel.SqlRexContext;
import org.apache.calcite.sql2rel.SqlRexConvertlet;
Expand All @@ -26,7 +24,6 @@
import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveInOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;


/**
Expand All @@ -35,17 +32,6 @@
*/
public class HiveConvertletTable extends ReflectiveConvertletTable {

@SuppressWarnings("unused")
public RexNode convertNamedStruct(SqlRexContext cx, HiveNamedStructFunction func, SqlCall call) {
List<RexNode> operandExpressions = new ArrayList<>(call.operandCount() / 2);
for (int i = 0; i < call.operandCount(); i += 2) {
operandExpressions.add(cx.convertExpression(call.operand(i + 1)));
}
RelDataType retType = cx.getValidator().getValidatedNodeType(call);
RexNode rowNode = cx.getRexBuilder().makeCall(retType, SqlStdOperatorTable.ROW, operandExpressions);
return cx.getRexBuilder().makeCast(retType, rowNode);
}

@SuppressWarnings("unused")
public RexNode convertHiveInOperator(SqlRexContext cx, HiveInOperator operator, SqlCall call) {
List<SqlNode> operandList = call.getOperandList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ protected SqlRexConvertletTable getConvertletTable() {
}

@Override
protected SqlValidator getSqlValidator() {
public SqlValidator getSqlValidator() {
return sqlValidator;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -497,12 +497,11 @@ public void testStructPeekDisallowed() {
public void testStructReturnFieldAccess() {
final String sql = "select named_struct('field_a', 10, 'field_b', 'abc').field_b";
RelNode rel = toRel(sql);
final String expectedRel = "LogicalProject(EXPR$0=[CAST(ROW(10, 'abc')):"
+ "RecordType(INTEGER NOT NULL field_a, CHAR(3) NOT NULL field_b) NOT NULL.field_b])\n"
final String expectedRel = "LogicalProject(EXPR$0=[named_struct('field_a', 10, 'field_b', 'abc').field_b])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(relToStr(rel), expectedRel);
final String expectedSql = "SELECT CAST(ROW(10, 'abc') AS ROW(field_a INTEGER, field_b CHAR(3))).field_b\n"
+ "FROM (VALUES (0)) t (ZERO)";
final String expectedSql =
"SELECT named_struct('field_a', 10, 'field_b', 'abc').field_b\n" + "FROM (VALUES (0)) t (ZERO)";
assertEquals(relToHql(rel), expectedSql);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -44,8 +44,7 @@ public void testMixedTypes() {
final String sql = "SELECT named_struct('abc', 123, 'def', 'xyz')";
RelNode rel = toRel(sql);
final String generated = relToStr(rel);
final String expected = ""
+ "LogicalProject(EXPR$0=[CAST(ROW(123, 'xyz')):RecordType(INTEGER NOT NULL abc, CHAR(3) NOT NULL def) NOT NULL])\n"
final String expected = "" + "LogicalProject(EXPR$0=[named_struct('abc', 123, 'def', 'xyz')])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
Expand All @@ -54,9 +53,8 @@ public void testMixedTypes() {
public void testNullFieldValue() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', 150)";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, 150)):RecordType(INTEGER abc, INTEGER NOT NULL def) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected = "LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', 150)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}

Expand All @@ -65,7 +63,7 @@ public void testAllNullValues() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', cast(NULL as double))";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, CAST(null:NULL):DOUBLE)):RecordType(INTEGER abc, DOUBLE def) NOT NULL])\n"
"LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', CAST(null:NULL):DOUBLE)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
Expand All @@ -74,10 +72,9 @@ public void testAllNullValues() {
public void testNestedComplexTypes() {
final String sql = "SELECT named_struct('arr', array(10, 15), 's', named_struct('f1', 123, 'f2', array(20.5)))";
final String generated = sqlToRelStr(sql);
final String expected = "LogicalProject(EXPR$0=[CAST(ROW(ARRAY(10, 15), CAST(ROW(123, ARRAY(20.5:DECIMAL(3, 1)))):"
+ "RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL)):"
+ "RecordType(INTEGER NOT NULL ARRAY NOT NULL arr, RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL s) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected =
"LogicalProject(EXPR$0=[named_struct('arr', ARRAY(10, 15), 's', named_struct('f1', 123, 'f2', ARRAY(20.5:DECIMAL(3, 1))))])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
// verified by human that expected string is correct and retained here to protect from future changes
assertEquals(generated, expected);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,14 @@ public void testAvoidCastToRow() {
assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql);
}

@Test
public void testSimpleCast() {
  // A plain integer-to-bigint cast should round-trip unchanged into Spark SQL.
  String expectedSql = "SELECT CAST(1 AS BIGINT)\n" + "FROM (VALUES (0)) t (ZERO)";

  RelNode rel = TestUtils.toRelNode("SELECT cast(1 as bigint)");
  assertEquals(CoralSpark.create(rel).getSparkSql(), expectedSql);
}

@Test
public void testCastOnString() {
RelNode relNode = TestUtils.toRelNode("SELECT CAST('99999999999' AS BIGINT) > 0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.validate.SqlValidator;

import com.linkedin.coral.common.functions.Function;
import com.linkedin.coral.common.transformers.JsonTransformSqlCallTransformer;
Expand All @@ -28,6 +29,7 @@
import com.linkedin.coral.trino.rel2trino.transformers.CurrentTimestampTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.GenericCoralRegistryOperatorRenameSqlCallTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.MapValueConstructorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructOperandTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.ReturnTypeAdjustmentTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.SqlSelectAliasAppenderTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.ToDateOperatorTransformer;
Expand All @@ -44,12 +46,14 @@ public class CoralToTrinoSqlCallConverter extends SqlShuttle {
private static final StaticHiveFunctionRegistry HIVE_FUNCTION_REGISTRY = new StaticHiveFunctionRegistry();
private final SqlCallTransformers sqlCallTransformers;

public CoralToTrinoSqlCallConverter(Map<String, Boolean> configs) {
public CoralToTrinoSqlCallConverter(Map<String, Boolean> configs, SqlValidator sqlValidator) {
this.sqlCallTransformers = SqlCallTransformers.of(new SqlSelectAliasAppenderTransformer(),
// conditional functions
new CoralRegistryOperatorRenameSqlCallTransformer("nvl", 2, "coalesce"),
// array and map functions
new MapValueConstructorTransformer(),
// named_struct to cast as row
new NamedStructOperandTransformer(sqlValidator),
new OperatorRenameSqlCallTransformer(SqlStdOperatorTable.SUBSTRING, 3, "SUBSTR"),
new SourceOperatorMatchSqlCallTransformer("item", 2) {
@Override
Expand Down Expand Up @@ -130,6 +134,6 @@ private SqlOperator hiveToCoralSqlOperator(String functionName) {

@Override
public SqlNode visit(SqlCall call) {
return sqlCallTransformers.apply((SqlCall) super.visit(call));
return super.visit(sqlCallTransformers.apply(call));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;
Expand Down Expand Up @@ -81,7 +82,20 @@ public RelToTrinoConverter(Map<String, Boolean> configs) {
public String convert(RelNode relNode) {
RelNode rel = convertRel(relNode, configs);
SqlNode sqlNode = convertToSqlNode(rel);
SqlNode sqlNodeWithUDFOperatorConverted = sqlNode.accept(new CoralToTrinoSqlCallConverter(configs));
SqlNode sqlNodeWithUDFOperatorConverted = sqlNode.accept(new CoralToTrinoSqlCallConverter(configs, null));
return sqlNodeWithUDFOperatorConverted.accept(new TrinoSqlRewriter()).toSqlString(TrinoSqlDialect.INSTANCE)
.toString();
}

/**
 * Convert relational algebra to Trino's SQL.
 * @param relNode calcite relational algebra representation of SQL
 * @param sqlValidator validator passed to {@link CoralToTrinoSqlCallConverter} so that
 *                     type-dependent transformers (e.g. the named_struct rewrite) can
 *                     derive operand types
 * @return SQL string
 */
public String convert(RelNode relNode, SqlValidator sqlValidator) {
RelNode rel = convertRel(relNode, configs);
SqlNode sqlNode = convertToSqlNode(rel);
SqlNode sqlNodeWithUDFOperatorConverted = sqlNode.accept(new CoralToTrinoSqlCallConverter(configs, sqlValidator));
return sqlNodeWithUDFOperatorConverted.accept(new TrinoSqlRewriter()).toSqlString(TrinoSqlDialect.INSTANCE)
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.HashSet;
import java.util.Set;

import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
Expand All @@ -23,14 +26,21 @@
public class CurrentTimestampTransformer extends SqlCallTransformer {

private static final String CURRENT_TIMESTAMP_FUNCTION_NAME = "CURRENT_TIMESTAMP";
private final Set<SqlCall> visited;

public CurrentTimestampTransformer() {
// Tracks calls already transformed so condition() rejects them on a later
// visit and the same call is not wrapped in CAST(... AS TIMESTAMP) twice.
visited = new HashSet<>();
}

@Override
protected boolean condition(SqlCall sqlCall) {
return sqlCall.getOperator().getName().equalsIgnoreCase(CURRENT_TIMESTAMP_FUNCTION_NAME);
return sqlCall.getOperator().getName().equalsIgnoreCase(CURRENT_TIMESTAMP_FUNCTION_NAME)
&& !visited.contains(sqlCall);
}

@Override
protected SqlCall transform(SqlCall sqlCall) {
visited.add(sqlCall);
SqlDataTypeSpec timestampType =
new SqlDataTypeSpec(new SqlBasicTypeNameSpec(SqlTypeName.TIMESTAMP, 3, SqlParserPos.ZERO), SqlParserPos.ZERO);
return SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO, sqlCall, timestampType);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.ArrayList;
import java.util.List;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlRowTypeSpec;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;

import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;


/**
 * Converts Coral's named_struct function call into an equivalent
 * CAST(ROW(v1, v2, ...) AS ROW(name1 type1, name2 type2, ...)) call, since
 * target dialects such as Trino do not support a named_struct function.
 *
 * <p>named_struct operands alternate between a field-name literal and the
 * corresponding field value: named_struct('f1', v1, 'f2', v2, ...).
 */
public class NamedStructOperandTransformer extends SqlCallTransformer {

  public NamedStructOperandTransformer(SqlValidator sqlValidator) {
    // The validator is required by getRelDataType(...) to derive each value
    // operand's type when building the ROW type spec in transform(...).
    super(sqlValidator);
  }

  @Override
  protected boolean condition(SqlCall sqlCall) {
    return sqlCall.getOperator().equals(HiveNamedStructFunction.NAMED_STRUCT);
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    List<SqlNode> inputOperands = sqlCall.getOperandList();
    int fieldCount = inputOperands.size() / 2;

    List<String> fieldNames = new ArrayList<>(fieldCount);
    List<SqlDataTypeSpec> fieldTypes = new ArrayList<>(fieldCount);
    List<SqlNode> rowCallOperands = new ArrayList<>(fieldCount);

    // Operands come in (name literal, value expression) pairs; collect names,
    // values, and derived value types in a single pass.
    for (int i = 0; i + 1 < inputOperands.size(); i += 2) {
      SqlNode nameOperand = inputOperands.get(i);
      // Explicit precondition instead of a bare `assert`: assertions are
      // disabled by default at runtime, which would let a malformed call fail
      // later with an uninformative ClassCastException.
      if (!(nameOperand instanceof SqlLiteral)) {
        throw new IllegalArgumentException(
            "named_struct field name at operand " + i + " must be a literal, got: " + nameOperand);
      }
      fieldNames.add(((SqlLiteral) nameOperand).getStringValue());

      SqlNode valueOperand = inputOperands.get(i + 1);
      rowCallOperands.add(valueOperand);
      RelDataType valueType = getRelDataType(valueOperand);
      fieldTypes.add(SqlTypeUtil.convertTypeToSpec(valueType));
    }

    SqlNode rowCall = SqlStdOperatorTable.ROW.createCall(ZERO, rowCallOperands);
    return SqlStdOperatorTable.CAST.createCall(ZERO, rowCall, new SqlRowTypeSpec(fieldNames, fieldTypes, ZERO));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import com.google.common.collect.ImmutableMap;

Expand Down Expand Up @@ -36,6 +38,7 @@ public class ReturnTypeAdjustmentTransformer extends SqlCallTransformer {
.put("date_diff", SqlTypeName.INTEGER).put("cardinality", SqlTypeName.INTEGER).put("ceil", SqlTypeName.BIGINT)
.put("ceiling", SqlTypeName.BIGINT).put("floor", SqlTypeName.BIGINT).put("date_add", SqlTypeName.VARCHAR).build();
private final Map<String, Boolean> configs;
private final Set<SqlCall> visited = new HashSet<>();

public ReturnTypeAdjustmentTransformer(Map<String, Boolean> configs) {
// Feature flags (e.g. CAST_DATEADD_TO_STRING) consulted in condition() to
// decide which operators get a return-type-adjusting cast.
this.configs = configs;
}
if ("date_add".equals(lowercaseOperatorName) && !configs.getOrDefault(CAST_DATEADD_TO_STRING, false)) {
return false;
}
return OPERATORS_TO_ADJUST.containsKey(lowercaseOperatorName);
return OPERATORS_TO_ADJUST.containsKey(lowercaseOperatorName) && !visited.contains(sqlCall);
}

@Override
protected SqlCall transform(SqlCall sqlCall) {
visited.add(sqlCall);
String lowercaseOperatorName = sqlCall.getOperator().getName().toLowerCase(Locale.ROOT);
SqlTypeName targetType = OPERATORS_TO_ADJUST.get(lowercaseOperatorName);
if (targetType != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ protected SqlRexConvertletTable getConvertletTable() {
}

@Override
protected SqlValidator getSqlValidator() {
public SqlValidator getSqlValidator() {
return sqlValidator;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ public void testIfWithNullAsSecondParameter() {
"SELECT \"if\"(FALSE, NULL, CAST(ROW('') AS ROW(\"a\" CHAR(0))))\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")";

RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter();
String expandedSql = relToTrinoConverter.convert(relNode);
String expandedSql = relToTrinoConverter.convert(relNode, TestUtils.getHiveToRelConverter().getSqlValidator());
assertEquals(expandedSql, targetSql);
}

Expand All @@ -361,7 +361,7 @@ public void testIfWithNullAsThirdParameter() {
"SELECT \"if\"(FALSE, CAST(ROW('') AS ROW(\"a\" CHAR(0))), NULL)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")";

RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter();
String expandedSql = relToTrinoConverter.convert(relNode);
String expandedSql = relToTrinoConverter.convert(relNode, TestUtils.getHiveToRelConverter().getSqlValidator());
assertEquals(expandedSql, targetSql);
}

Expand Down