diff --git a/coral-common/src/main/java/com/linkedin/coral/common/ToRelConverter.java b/coral-common/src/main/java/com/linkedin/coral/common/ToRelConverter.java index 69fa751aa..40321704b 100644 --- a/coral-common/src/main/java/com/linkedin/coral/common/ToRelConverter.java +++ b/coral-common/src/main/java/com/linkedin/coral/common/ToRelConverter.java @@ -133,8 +133,6 @@ public RelNode convertView(String hiveDbName, String hiveViewName) { return toRel(sqlNode); } - // TODO change back to protected once the relevant tests move to the common package - @VisibleForTesting public SqlNode toSqlNode(String sql) { return toSqlNode(sql, null); } @@ -161,9 +159,9 @@ public SqlNode processView(String dbName, String tableName) { return toSqlNode(stringViewExpandedText, table); } - @VisibleForTesting - protected RelNode toRel(SqlNode sqlNode) { + public RelNode toRel(SqlNode sqlNode) { RelRoot root = getSqlToRelConverter().convertQuery(sqlNode, true, true); + return standardizeRel(root.rel); } diff --git a/coral-common/src/main/java/com/linkedin/coral/common/calcite/DdlSqlValidator.java b/coral-common/src/main/java/com/linkedin/coral/common/calcite/DdlSqlValidator.java new file mode 100644 index 000000000..ffc8576e4 --- /dev/null +++ b/coral-common/src/main/java/com/linkedin/coral/common/calcite/DdlSqlValidator.java @@ -0,0 +1,14 @@ +/** + * Copyright 2022 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.coral.common.calcite; + +import org.apache.calcite.sql.SqlNode; + + +public interface DdlSqlValidator { + + void validate(SqlNode ddlSqlNode); +} diff --git a/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/SqlCommand.java b/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/SqlCommand.java new file mode 100644 index 000000000..6158533dc --- /dev/null +++ b/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/SqlCommand.java @@ -0,0 +1,19 @@ +/** + * Copyright 2022 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.common.calcite.sql; + +import org.apache.calcite.sql.SqlNode; + + +/** + * Interface for SqlNodes containing select statements as a child node. Ex: CTAS queries + */ +public interface SqlCommand { + + SqlNode getSelectQuery(); + + void setSelectQuery(SqlNode selectQuery); +} diff --git a/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/SqlCreateTable.java b/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/SqlCreateTable.java new file mode 100644 index 000000000..8402dbcfc --- /dev/null +++ b/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/SqlCreateTable.java @@ -0,0 +1,131 @@ +/** + * Copyright 2022 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.common.calcite.sql; + +import java.util.List; + +import org.apache.calcite.sql.*; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.util.ImmutableNullableList; + +import com.linkedin.coral.javax.annotation.Nonnull; +import com.linkedin.coral.javax.annotation.Nullable; + + +/** + * SQL parse tree node to represent {@code CREATE} statements, + *

Supported Syntax: + * + *

+ * CREATE TABLE [ IF NOT EXISTS ] name + * [ROW FORMAT SERDE serde] + * [ROW FORMAT DELIMITED FIELDS TERMINATED BY rowFormat] + * [STORED AS fileFormat] + * [STORED AS INPUTFORMAT inputFormat STORED AS OUTPUTFORMAT outputFormat] + * [ AS query ] + * + *
+ * + *

Examples: + * + *

+ */ +public class SqlCreateTable extends SqlCreate implements SqlCommand { + // name of the table to be created + private final SqlIdentifier name; + // column details like column name, data type, etc. This may be null, like in case of CTAS + private final @Nullable SqlNodeList columnList; + // select query node in case of "CREATE TABLE ... AS query"; else may be null + private @Nullable SqlNode selectQuery; + // specifying serde property + private final @Nullable SqlNode serDe; + // specifying file format such as Parquet, ORC, etc. + private final @Nullable SqlNodeList fileFormat; + // specifying delimiter fields for row format + private final @Nullable SqlCharStringLiteral rowFormat; + + private static final SqlOperator OPERATOR = new SqlSpecialOperator("CREATE TABLE", SqlKind.CREATE_TABLE); + + /** Creates a SqlCreateTable. */ + public SqlCreateTable(SqlParserPos pos, boolean replace, boolean ifNotExists, @Nonnull SqlIdentifier name, + @Nullable SqlNodeList columnList, @Nullable SqlNode selectQuery, @Nullable SqlNode serDe, + @Nullable SqlNodeList fileFormat, @Nullable SqlCharStringLiteral rowFormat) { + super(OPERATOR, pos, replace, ifNotExists); + this.name = name; + this.columnList = columnList; + this.selectQuery = selectQuery; + this.serDe = serDe; + this.fileFormat = fileFormat; + this.rowFormat = rowFormat; + } + + @SuppressWarnings("nullness") + @Override + public List getOperandList() { + return ImmutableNullableList.of(name, columnList, selectQuery, serDe, fileFormat, rowFormat); + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("CREATE"); + writer.keyword("TABLE"); + if (ifNotExists) { + writer.keyword("IF NOT EXISTS"); + } + name.unparse(writer, leftPrec, rightPrec); + if (columnList != null) { + SqlWriter.Frame frame = writer.startList("(", ")"); + for (SqlNode c : columnList) { + writer.sep(","); + c.unparse(writer, 0, 0); + } + writer.endList(frame); + } + if (serDe != null) { + 
writer.keyword("ROW FORMAT SERDE"); + serDe.unparse(writer, 0, 0); + writer.newlineAndIndent(); + } + if (rowFormat != null) { + writer.keyword("ROW FORMAT DELIMITED FIELDS TERMINATED BY"); + rowFormat.unparse(writer, 0, 0); + writer.newlineAndIndent(); + } + if (fileFormat != null) { + if (fileFormat.size() == 1) { + writer.keyword("STORED AS"); + fileFormat.get(0).unparse(writer, 0, 0); + } else { + writer.keyword("STORED AS INPUTFORMAT"); + fileFormat.get(0).unparse(writer, 0, 0); + writer.keyword("OUTPUTFORMAT"); + fileFormat.get(1).unparse(writer, 0, 0); + } + writer.newlineAndIndent(); + } + if (selectQuery != null) { + writer.keyword("AS"); + writer.newlineAndIndent(); + selectQuery.unparse(writer, 0, 0); + } + } + + @Override + public SqlNode getSelectQuery() { + return selectQuery; + } + + @Override + public void setSelectQuery(SqlNode query) { + this.selectQuery = query; + } +} diff --git a/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/util/SqlDdlNodes.java b/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/util/SqlDdlNodes.java new file mode 100644 index 000000000..f5997bee9 --- /dev/null +++ b/coral-common/src/main/java/com/linkedin/coral/common/calcite/sql/util/SqlDdlNodes.java @@ -0,0 +1,26 @@ +/** + * Copyright 2022 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.common.calcite.sql.util; + +import org.apache.calcite.sql.SqlCharStringLiteral; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.parser.SqlParserPos; + +import com.linkedin.coral.common.calcite.sql.SqlCreateTable; + + +public class SqlDdlNodes { + + /** Creates a CREATE TABLE. 
*/ + public static SqlCreateTable createTable(SqlParserPos pos, boolean replace, boolean ifNotExists, SqlIdentifier name, + SqlNodeList columnList, SqlNode query, SqlNode tableSerializer, SqlNodeList tableFileFormat, + SqlCharStringLiteral tableRowFormat) { + return new SqlCreateTable(pos, replace, ifNotExists, name, columnList, query, tableSerializer, tableFileFormat, + tableRowFormat); + } +} diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java index ad1aaf081..dbe8bb29a 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/HiveToRelConverter.java @@ -24,9 +24,11 @@ import com.linkedin.coral.common.HiveMetastoreClient; import com.linkedin.coral.common.ToRelConverter; +import com.linkedin.coral.common.calcite.DdlSqlValidator; import com.linkedin.coral.hive.hive2rel.functions.HiveFunctionResolver; import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry; import com.linkedin.coral.hive.hive2rel.parsetree.ParseTreeBuilder; +import com.linkedin.coral.hive.hive2rel.validators.HiveDdlSqlValidator; import static com.linkedin.coral.hive.hive2rel.HiveSqlConformance.HIVE_SQL; @@ -52,6 +54,7 @@ public class HiveToRelConverter extends ToRelConverter { // The validator must be reused SqlValidator sqlValidator = new HiveSqlValidator(getOperatorTable(), getCalciteCatalogReader(), ((JavaTypeFactory) getRelBuilder().getTypeFactory()), HIVE_SQL); + DdlSqlValidator ddlSqlValidator = new HiveDdlSqlValidator(); public HiveToRelConverter(HiveMetastoreClient hiveMetastoreClient) { super(hiveMetastoreClient); @@ -92,7 +95,9 @@ protected SqlToRelConverter getSqlToRelConverter() { @Override protected SqlNode toSqlNode(String sql, Table hiveView) { - return parseTreeBuilder.process(trimParenthesis(sql), hiveView); + SqlNode sqlNode = 
parseTreeBuilder.process(trimParenthesis(sql), hiveView); + ddlSqlValidator.validate(sqlNode); + return sqlNode; } @Override diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java index 8f71bf0cb..068874844 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/AbstractASTVisitor.java @@ -302,12 +302,114 @@ protected R visit(ASTNode node, C ctx) { case HiveParser.KW_CURRENT: return visitCurrentRow(node, ctx); + case HiveParser.TOK_CREATETABLE: + return visitCreateTable(node, ctx); + case HiveParser.TOK_LIKETABLE: + return visitLikeTable(node, ctx); + case HiveParser.TOK_IFNOTEXISTS: + return visitIfNotExists(node, ctx); + case HiveParser.TOK_TABCOLLIST: + return visitColumnList(node, ctx); + case HiveParser.TOK_TABCOL: + return visitColumn(node, ctx); + case HiveParser.TOK_FILEFORMAT_GENERIC: + return visitFileFormatGeneric(node, ctx); + case HiveParser.TOK_TABLEFILEFORMAT: + return visitTableFileFormat(node, ctx); + case HiveParser.TOK_TABLESERIALIZER: + return visitTableSerializer(node, ctx); + case HiveParser.TOK_SERDENAME: + return visitSerdeName(node, ctx); + case HiveParser.TOK_TABLEROWFORMAT: + return visitTableRowFormat(node, ctx); + case HiveParser.TOK_SERDEPROPS: + return visitSerdeProps(node, ctx); + case HiveParser.TOK_TABLEROWFORMATFIELD: + return visitTableRowFormatField(node, ctx); default: // return visitChildren(node, ctx); throw new UnhandledASTTokenException(node); } } + protected R visitTableRowFormatField(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitSerdeProps(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + 
+ protected R visitTableRowFormat(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitSerdeName(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitTableSerializer(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitTableFileFormat(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitFileFormatGeneric(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitColumn(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitColumnList(ASTNode node, C ctx) { + return visitChildren(node, ctx).get(0); + } + + protected R visitIfNotExists(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitLikeTable(ASTNode node, C ctx) { + if (node.getChildren() != null) { + return visitChildren(node, ctx).get(0); + } + return null; + } + + protected R visitCreateTable(ASTNode node, C ctx) { + return visitChildren(node, ctx).get(0); + } + protected R visitKeywordLiteral(ASTNode node, C ctx) { return visitChildren(node, ctx).get(0); } diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java index 2eda31f4e..750b298f0 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilder.java @@ -6,6 +6,7 
@@ package com.linkedin.coral.hive.hive2rel.parsetree; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -20,6 +21,7 @@ import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlBasicTypeNameSpec; import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCharStringLiteral; import org.apache.calcite.sql.SqlDataTypeSpec; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlIntervalQualifier; @@ -43,6 +45,7 @@ import com.linkedin.coral.com.google.common.collect.ImmutableList; import com.linkedin.coral.com.google.common.collect.Iterables; +import com.linkedin.coral.common.calcite.sql.util.SqlDdlNodes; import com.linkedin.coral.common.functions.CoralSqlUnnestOperator; import com.linkedin.coral.common.functions.Function; import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator; @@ -608,6 +611,96 @@ protected SqlNode visitSelect(ASTNode node, ParseContext ctx) { return ctx.selects; } + @Override + protected SqlNode visitCreateTable(ASTNode node, ParseContext ctx) { + CreateTableOptions ctOptions = new CreateTableOptions(); + for (Node child : node.getChildren()) { + ASTNode ast = (ASTNode) child; + switch (ast.getType()) { + case HiveParser.TOK_TABNAME: + ctOptions.name = (SqlIdentifier) visitTabnameNode(ast, ctx); + break; + case HiveParser.TOK_IFNOTEXISTS: + ctOptions.ifNotExists = true; + break; + case HiveParser.TOK_TABCOLLIST: + ctOptions.columnList = (SqlNodeList) visitColumnList(ast, ctx); + break; + case HiveParser.TOK_QUERY: + ctOptions.query = visitQueryNode(ast, ctx); + break; + case HiveParser.TOK_TABLESERIALIZER: + ctOptions.tableSerializer = visitTableSerializer(ast, ctx); + break; + case HiveParser.TOK_TABLEFILEFORMAT: + ctOptions.tableFileFormat = (SqlNodeList) visitTableFileFormat(ast, ctx); + break; + case HiveParser.TOK_FILEFORMAT_GENERIC: + ctOptions.tableFileFormat = (SqlNodeList) 
visitFileFormatGeneric(ast, ctx); + break; + case HiveParser.TOK_TABLEROWFORMAT: + ctOptions.tableRowFormat = (SqlCharStringLiteral) visitTableRowFormat(ast, ctx); + break; + default: + break; + } + } + return SqlDdlNodes.createTable(ZERO, false, ctOptions.ifNotExists, ctOptions.name, ctOptions.columnList, + ctOptions.query, ctOptions.tableSerializer, ctOptions.tableFileFormat, ctOptions.tableRowFormat); + } + + @Override + protected SqlNode visitColumnList(ASTNode node, ParseContext ctx) { + List sqlNodeList = visitChildren(node, ctx); + return new SqlNodeList(sqlNodeList, ZERO); + } + + @Override + protected SqlNode visitColumn(ASTNode node, ParseContext ctx) { + return visitChildren(node, ctx).get(0); + } + + @Override + protected SqlNode visitIfNotExists(ASTNode node, ParseContext ctx) { + return SqlLiteral.createBoolean(true, ZERO); + } + + @Override + protected SqlNode visitTableRowFormat(ASTNode node, ParseContext ctx) { + return visitChildren(node, ctx).get(0); + } + + @Override + protected SqlNode visitSerdeName(ASTNode node, ParseContext ctx) { + return visit((ASTNode) node.getChildren().get(0), ctx); + } + + @Override + protected SqlNode visitTableSerializer(ASTNode node, ParseContext ctx) { + return visitChildren(node, ctx).get(0); + } + + @Override + protected SqlNode visitTableFileFormat(ASTNode node, ParseContext ctx) { + List sqlNodeList = visitChildren(node, ctx); + return new SqlNodeList(sqlNodeList, ZERO); + } + + @Override + protected SqlNode visitFileFormatGeneric(ASTNode node, ParseContext ctx) { + return new SqlNodeList(Arrays.asList(visitChildren(node, ctx).get(0)), ZERO); + } + + @Override + protected SqlNode visitSerdeProps(ASTNode node, ParseContext ctx) { + return visitChildren(node, ctx).get(0); + } + + @Override + protected SqlNode visitTableRowFormatField(ASTNode node, ParseContext ctx) { + return visitChildren(node, ctx).get(0); + } + @Override protected SqlNode visitTabRefNode(ASTNode node, ParseContext ctx) { List sqlNodes = 
visitChildren(node, ctx); @@ -1059,4 +1152,14 @@ Optional getHiveTable() { return hiveTable; } } + + static class CreateTableOptions { + SqlIdentifier name; + SqlNodeList columnList; + SqlNode query; + boolean ifNotExists; + SqlNode tableSerializer; + SqlNodeList tableFileFormat; + SqlCharStringLiteral tableRowFormat; + } } diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/validators/HiveDdlSqlValidator.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/validators/HiveDdlSqlValidator.java new file mode 100644 index 000000000..285f19bf1 --- /dev/null +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/validators/HiveDdlSqlValidator.java @@ -0,0 +1,25 @@ +/** + * Copyright 2022 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.hive.hive2rel.validators; + +import org.apache.calcite.sql.SqlNode; + +import com.linkedin.coral.common.calcite.DdlSqlValidator; + + +public class HiveDdlSqlValidator implements DdlSqlValidator { + @Override + public void validate(SqlNode ddlSqlNode) { + switch (ddlSqlNode.getKind()) { + case CREATE_TABLE: + validateCreateTable(ddlSqlNode); + } + } + + private void validateCreateTable(SqlNode sqlNode) { + //Todo need to add appropriate validations + } +} diff --git a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java index e332a0d3d..b94e4d9ed 100644 --- a/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java +++ b/coral-hive/src/test/java/com/linkedin/coral/hive/hive2rel/parsetree/ParseTreeBuilderTest.java @@ -147,6 +147,7 @@ public Iterator getConvertSql() { "SELECT LAST_VALUE(c) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS min_c FROM foo", "SELECT STDDEV(c) 
OVER (PARTITION BY a ORDER BY b RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) AS min_c FROM foo", "SELECT VARIANCE(c) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS min_c FROM foo"); + // We wrap the SQL to be tested here rather than wrap each SQL statement in its own array in the constant return convertSql.stream().map(x -> new Object[] { x }).iterator(); } @@ -193,7 +194,28 @@ public Iterator getValidateSql() { "SELECT CASE WHEN `a` THEN 10 WHEN `b` THEN 20 ELSE 30 END FROM `foo`"), ImmutableList.of("SELECT named_struct('abc', 123, 'def', 234.23) FROM foo", "SELECT `named_struct`('abc', 123, 'def', 234.23) FROM `foo`"), - ImmutableList.of("SELECT 0L FROM foo", "SELECT 0 FROM `foo`")); + ImmutableList.of("SELECT 0L FROM foo", "SELECT 0 FROM `foo`"), + + //Basic CTAS query + ImmutableList.of("CREATE TABLE sample AS select * from tmp", "CREATE TABLE `sample` AS select * from `tmp`"), + //CTAS query with IF NOT EXISTS keyword + ImmutableList.of("CREATE TABLE IF NOT EXISTS sample AS SELECT * FROM tmp", + "CREATE TABLE IF NOT EXISTS `sample` AS select * from `tmp`"), + //CTAS query with storage format + ImmutableList.of("CREATE TABLE sample STORED AS ORC AS SELECT * FROM tmp", + "CREATE TABLE `sample` STORED AS `ORC` AS select * from `tmp`"), + //CTAS query with input and output formats + ImmutableList.of( + "CREATE TABLE sample STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM tmp", + "CREATE TABLE `sample` STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM `tmp`"), + //CTAS query with serde + ImmutableList.of( + "CREATE TABLE sample ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 
'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM tmp", + "CREATE TABLE `sample` ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM `tmp`"), + //CTAS query with row format delimiter fields + ImmutableList.of( + "CREATE TABLE sample ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM tmp", + "CREATE TABLE `sample` ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM `tmp`")); return convertAndValidateSql.stream().map(x -> new Object[] { x.get(0), x.get(1) }).iterator(); } diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/TranslationUtils.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/TranslationUtils.java index 1828bcd25..a12f35ea5 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/TranslationUtils.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/TranslationUtils.java @@ -6,7 +6,9 @@ package com.linkedin.coral.coralservice.utils; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.sql.SqlNode; +import com.linkedin.coral.common.calcite.sql.SqlCommand; import com.linkedin.coral.hive.hive2rel.HiveToRelConverter; import com.linkedin.coral.spark.CoralSpark; import com.linkedin.coral.trino.rel2trino.RelToTrinoConverter; @@ -29,8 +31,21 @@ public static String translateHiveToTrino(String query) { } public static String translateHiveToSpark(String query) { - RelNode relNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query); - 
CoralSpark coralSpark = CoralSpark.create(relNode); - return coralSpark.getSparkSql(); + HiveToRelConverter hiveToRelConverter = new HiveToRelConverter(hiveMetastoreClient); + SqlNode sqlNode = hiveToRelConverter.toSqlNode(query); + if (sqlNode instanceof SqlCommand) { + SqlNode selectNode = ((SqlCommand) sqlNode).getSelectQuery(); + SqlNode selectSparkNode = convertHiveSqlNodeToCoralNode(hiveToRelConverter, selectNode); + ((SqlCommand) sqlNode).setSelectQuery(selectSparkNode); + } else { + sqlNode = convertHiveSqlNodeToCoralNode(hiveToRelConverter, sqlNode); + } + return CoralSpark.constructSparkSQL(sqlNode); + } + + private static SqlNode convertHiveSqlNodeToCoralNode(HiveToRelConverter hiveToRelConverter, SqlNode sqlNode) { + RelNode relNode = hiveToRelConverter.toRel(sqlNode); + SqlNode coralSqlNode = CoralSpark.getCoralSqlNode(relNode); + return coralSqlNode; } } diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java index 056edb535..bdae06085 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java @@ -70,6 +70,22 @@ public static CoralSpark create(RelNode irRelNode) { return new CoralSpark(baseTables, sparkUDFInfos, sparkSQL); } + /** + * Users use this function to get CoralSqlNode from CoralRelNode + * This should be used when user need to get CoralSqlNode from CoralRelNode by applying + * spark specific transformations on CoralRelNode + * with Coral-schema output schema + * + * @return [[SqlNode]] + */ + public static SqlNode getCoralSqlNode(RelNode irRelNode) { + SparkRelInfo sparkRelInfo = IRRelToSparkRelTransformer.transform(irRelNode); + RelNode sparkRelNode = sparkRelInfo.getSparkRelNode(); + CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter(); + SqlNode coralSqlNode = rel2sql.convert(sparkRelNode); + return coralSqlNode; + } + /** * Users use 
this function as the main API for getting CoralSpark instance. * This should be used when user need to align the Coral-spark translated SQL @@ -115,6 +131,24 @@ private static String constructSparkSQL(RelNode sparkRelNode) { return rewrittenSparkSqlNode.toSqlString(SparkSqlDialect.INSTANCE).getSql(); } + /** + * This function returns a completely expanded SQL statement in Spark SQL Dialect. + * + * A SQL statement is 'completely expanded' if it doesn't depend + * on (or selects from) Hive views, but instead, just on base tables. + + * Converts CoralSqlNode to Spark SQL + * + * @param sqlNode CoralSqlNode which will be translated to SparkSql + * + * @return SQL String in Spark SQL dialect which is 'completely expanded' + */ + public static String constructSparkSQL(SqlNode sqlNode) { + SqlNode sparkSqlNode = sqlNode.accept(new CoralSqlNodeToSparkSqlNodeConverter()); + SqlNode rewrittenSparkSqlNode = sparkSqlNode.accept(new SparkSqlRewriter()); + return rewrittenSparkSqlNode.toSqlString(SparkSqlDialect.INSTANCE).getSql(); + } + private static String constructSparkSQLWithExplicitAlias(RelNode sparkRelNode, List aliases) { CoralRelToSqlNodeConverter rel2sql = new CoralRelToSqlNodeConverter(); // Create temporary objects r and rewritten to make debugging easier diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java index b2431f774..c5cde16b5 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java @@ -14,6 +14,7 @@ import org.apache.avro.Schema; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; @@ -26,6 +27,8 @@ import org.testng.annotations.Test; import 
com.linkedin.coral.com.google.common.collect.ImmutableList; +import com.linkedin.coral.common.calcite.sql.SqlCommand; +import com.linkedin.coral.hive.hive2rel.HiveToRelConverter; import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry; import com.linkedin.coral.spark.containers.SparkUDFInfo; import com.linkedin.coral.spark.exceptions.UnsupportedUDFException; @@ -702,6 +705,32 @@ public void testCastDecimal() { assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql); } + @Test + public void testCreateTableAsSelectWithUnionExtractUDF() { + String query = "CREATE TABLE foo_bar as SELECT extract_union(foo) from union_table"; + String targetSql = "CREATE TABLE foo_bar as SELECT coalesce_struct(foo) FROM default.union_table"; + assertEquals(translateHiveToSpark(query).toLowerCase().replaceAll("\n", " "), + targetSql.toLowerCase().replaceAll("\n", " ")); + } + + @Test + public void testCreateTableAsSelect() { + String query = "CREATE TABLE foo_bar as SELECT CAST(a AS DECIMAL(10, 0)) casted_decimal FROM default.foo"; + String targetSql = "CREATE TABLE foo_bar as SELECT CAST(a AS DECIMAL(10, 0)) casted_decimal FROM default.foo"; + assertEquals(translateHiveToSpark(query).toLowerCase().replaceAll("\n", " "), + targetSql.toLowerCase().replaceAll("\n", " ")); + } + + @Test + public void testCreateTableAsSelectWithTableProperties() { + String query = + "CREATE TABLE sample ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM default.foo"; + String targetSql = + "CREATE TABLE sample ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleInputFormat' OUTPUTFORMAT 'com.ly.spark.example.serde.io.SerDeExampleOutputFormat' AS SELECT * FROM default.foo"; + assertEquals(translateHiveToSpark(query).toLowerCase().replaceAll("\n", " "), + 
targetSql.toLowerCase().replaceAll("\n", " ")); + } + @Test public void testCastDecimalDefault() { RelNode relNode = TestUtils.toRelNode("SELECT CAST(a as DECIMAL) as casted_decimal FROM default.foo"); @@ -856,4 +885,22 @@ private static String getCoralSparkTranslatedSqlWithAliasFromCoralSchema(String return coralSpark.getSparkSql(); } + private static SqlNode convertHiveSqlNodeToCoralNode(HiveToRelConverter hiveToRelConverter, SqlNode sqlNode) { + RelNode relNode = hiveToRelConverter.toRel(sqlNode); + SqlNode coralSqlNode = CoralSpark.getCoralSqlNode(relNode); + return coralSqlNode; + } + + private static String translateHiveToSpark(String query) { + HiveToRelConverter hiveToRelConverter = TestUtils.hiveToRelConverter; + SqlNode sqlNode = hiveToRelConverter.toSqlNode(query); + if (sqlNode instanceof SqlCommand) { + SqlNode selectNode = ((SqlCommand) sqlNode).getSelectQuery(); + SqlNode selectSparkNode = convertHiveSqlNodeToCoralNode(hiveToRelConverter, selectNode); + ((SqlCommand) sqlNode).setSelectQuery(selectSparkNode); + } else { + sqlNode = convertHiveSqlNodeToCoralNode(hiveToRelConverter, sqlNode); + } + return CoralSpark.constructSparkSQL(sqlNode); + } }