diff --git a/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeGenerationTransformer.java b/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeGenerationTransformer.java new file mode 100644 index 000000000..d97bff411 --- /dev/null +++ b/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeGenerationTransformer.java @@ -0,0 +1,668 @@ +/** + * Copyright 2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.incremental; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.calcite.plan.RelOptSchema; +import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.prepare.RelOptTableImpl; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.RelShuttleImpl; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalTableScan; +import org.apache.calcite.rel.logical.LogicalUnion; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; + + +public class RelNodeGenerationTransformer { + private final String TABLE_NAME_PREFIX = "Table#"; + private final String DELTA_SUFFIX = "_delta"; + + private final String PREV_SUFFIX = "_prev"; + + private RelOptSchema relOptSchema; + private Map snapshotRelNodes; + private Map deltaRelNodes; + private RelNode tempLastRelNode; + private Set needsProj; + + public RelNodeGenerationTransformer() { + relOptSchema = null; + snapshotRelNodes = new LinkedHashMap<>(); + deltaRelNodes = new LinkedHashMap<>(); + tempLastRelNode = null; + needsProj = new HashSet<>(); + } + + /** + * Generates incremental RelNodes for the given RelNode. The incremental RelNodes are generated by: + * - Identifying the LogicalJoin nodes that may need a projection and adding them to the needsProj set. + * - Uniformly formatting the RelNode by recursively processing its nodes. + * - When converting the RelNode into its incremental version, populating snapshotRelNodes and deltaRelNodes with the generated RelNodes. + * - Generating a list of lists of RelNodes that represent the incremental RelNodes in different combinations. + *

+ * @param relNode input RelNode to generate incremental RelNodes for + * @return a list of lists of RelNodes that represent the incremental RelNodes in different combinations + */ + public List> generateIncrementalRelNodes(RelNode relNode) { + findJoinNeedsProject(relNode); + relNode = uniformFormat(relNode); + convertRelIncremental(relNode); + Map snapshotRelNodes = getSnapshotRelNodes(); + Map deltaRelNodes = getDeltaRelNodes(); + List> combinedLists = generateCombinedLists(deltaRelNodes, snapshotRelNodes); + return combinedLists; + } + + /** + * Generates a list of lists of RelNodes that represent the incremental RelNodes in different combinations. + * The formula used to generate the combinations is as follows: + * - For n subquery, there are n combinations. + * - For each combination, the first i tables are delta tables and the rest are snapshot tables. + * - The combinations are generated by iterating over the deltaRelNodes and snapshotRelNodes maps and adding the delta + * tables to the combination until the index i is reached, and then adding the snapshot tables to the combination. + * That means each generated plan would be a combination of incremental plan and batch plan, consisting of + * a List of RelNodes, denoting each sub-query will be incremental executed or batch executed. + *

+ * Take the following three-tables Join as an example (table names are enclosed in parentheses of the RelNodes): + *

+   *            LogicalProject#8
+   *                  |
+   *            LogicalJoin#7
+   *             /        \
+   *    LogicalProject#4   TableScan#5
+   *            |
+   *      LogicalJoin#3
+   *          /        \
+   * TableScan(Table_A)  TableScan(Table_B)
+   * 
+ * + * LogicalProject#4 and LogicalProject#8 are two sub-queries, and each sub-query will be materialized and replaced with a TableScan. + * LogicalProject#4 will be replaced with Table0 and LogicalProject#8 will be replaced with Table1 like below. + *
+   *    LogicalProject#4
+   *            |
+   *      LogicalJoin#3            =>          TableScan(Table0)
+   *          /        \
+   * TableScan(Table_A)  TableScan(Table_B)
+   * 
+ *
+   *            LogicalProject#8
+   *                  |
+   *            LogicalJoin#7
+   *             /        \
+   *    LogicalProject#4   TableScan#5
+   *            |
+   *      LogicalJoin#3
+   *          /         \
+   * TableScan(Table_A)  TableScan(Table_B)
+   * 
+ * will be replaced with + *
+   *            LogicalProject#8
+   *                  |
+   *            LogicalJoin#7      =>        TableScan(Table1)
+   *             /        \
+   *    TableScan(Table0)   TableScan#5
+   * 
+ * + *

+ * There will be 3 combinations: + *

+ * Incremental: [Table0_delta, Table1_delta], which means both joins are executed incrementally. + *

+ * Part-Batch, Part-Incremental: [Table0_delta, Table1], which means The first join is executed incrementally, and the second join is executed in batch mode. + *

+ * Batch: [Table0, Table1], which means both joins are executed in batch mode. + *

+ * @param deltaRelNodes map of delta RelNodes + * @param snapshotRelNodes map of snapshot RelNodes + * @return a list of lists of RelNodes that represent the incremental RelNodes in different combinations + */ + private List> generateCombinedLists(Map deltaRelNodes, + Map snapshotRelNodes) { + List> resultList = new ArrayList<>(); + assert (deltaRelNodes.size() == snapshotRelNodes.size()); + int n = deltaRelNodes.size(); + + for (int i = -1; i < n; i++) { + List tempList = new ArrayList<>(); + for (int j = 0; j < n; j++) { + if (j <= i) { + tempList.add(deltaRelNodes.get("Table#" + j + "_delta")); + } else { + tempList.add(snapshotRelNodes.get("Table#" + j)); + } + } + + resultList.add(tempList); + } + + return resultList; + } + + /** + * Returns snapshotRelNodes with deterministic keys. + */ + public Map getSnapshotRelNodes() { + Map deterministicSnapshotRelNodes = new LinkedHashMap<>(); + for (String description : snapshotRelNodes.keySet()) { + deterministicSnapshotRelNodes.put(getDeterministicDescriptionFromDescription(description, false), + snapshotRelNodes.get(description)); + } + return deterministicSnapshotRelNodes; + } + + /** + * Returns deltaRelNodes with deterministic keys. + */ + public Map getDeltaRelNodes() { + Map deterministicDeltaRelNodes = new LinkedHashMap<>(); + for (String description : deltaRelNodes.keySet()) { + deterministicDeltaRelNodes.put(getDeterministicDescriptionFromDescription(description, true), + deltaRelNodes.get(description)); + } + return deterministicDeltaRelNodes; + } + + /** + * Traverses the relational algebra tree starting from the given RelNode. + * Identifies LogicalJoin nodes that needs a projection and adds them to the needsProj set. + * The traversal uses a custom RelShuttleImpl visitor that: + * - Recursively processing when the RelNode is not a LogicalProject and check its inputs. + * - If one input node is LogicalJoin, the LogicalJoin node is added to the needsProj set. + *

+ * For example, consider the following queries for a two-tables Join + *

+ * Input1: + *

+   *               LogicalProject
+   *                     |
+   *               LogicalJoin
+   *                 /        \
+   *         LogicalJoin(*)    TableScan(Table_C)
+   *           /          \
+   *  TableScan(Table_A)  TableScan(Table_B)
+   * 
+ * LogicalJoin(*) is a LogicalJoin node that doesn't have a LogicalProject parent, so it is added to the needsProj set. + *

+ * Input2: + *

+   *               LogicalProject
+   *                     |
+   *               LogicalJoin
+   *                 /        \
+   *      LogicalProject  TableScan(Table_C)
+   *                |
+   *          LogicalJoin
+   *          /          \
+   *  TableScan(Table_A)  TableScan(Table_B)
+   * 
+ * In this case, all LogicalJoin nodes have a LogicalProject parent, so no one is added to the needsProj set. + *

+ * @param relNode input RelNode to traverse + */ + private void findJoinNeedsProject(RelNode relNode) { + RelShuttle converter = new RelShuttleImpl() { + + @Override + public RelNode visit(LogicalJoin join) { + RelNode left = join.getLeft(); + RelNode right = join.getRight(); + if (left instanceof LogicalJoin) { + needsProj.add(left); + } + if (right instanceof LogicalJoin) { + needsProj.add(right); + } + + findJoinNeedsProject(left); + findJoinNeedsProject(right); + + return join; + } + + @Override + public RelNode visit(LogicalFilter filter) { + if (filter.getInput() instanceof LogicalJoin) { + needsProj.add(filter.getInput()); + } + findJoinNeedsProject(filter.getInput()); + + return filter; + } + + @Override + public RelNode visit(LogicalProject project) { + findJoinNeedsProject(project.getInput()); + return project; + } + + @Override + public RelNode visit(LogicalUnion union) { + List children = union.getInputs(); + for (RelNode child : children) { + if (child instanceof LogicalJoin) { + needsProj.add(child); + } + findJoinNeedsProject(child); + } + + return union; + } + + @Override + public RelNode visit(LogicalAggregate aggregate) { + if (aggregate.getInput() instanceof LogicalJoin) { + needsProj.add(aggregate.getInput()); + } + findJoinNeedsProject(aggregate.getInput()); + + return aggregate; + } + }; + relNode.accept(converter); + } + + /** + * Converts the given relational algebra tree into its "previous" version by modifying TableScan nodes + * and transforming the structure of other relational nodes (such as LogicalJoin, LogicalFilter, LogicalProject, + * LogicalUnion, and LogicalAggregate). + * Specifically: + * - TableScan nodes are modified to point to a "_prev" version of the table. + * - Other RelNodes are recursively transformed to operate on their "previous" versions of their inputs. + *

+ * For example the following query for a two tables Join (table names are enclosed in parentheses of the RelNodes): + * Input: + *

+   *            LogicalProject
+   *                  |
+   *            LogicalJoin
+   *             /        \
+   *   TableScan(Table_A)    TableScan(Table_B)
+   * 
+ * + * Output: + *
+   *            LogicalProject
+   *                  |
+   *            LogicalJoin
+   *             /        \
+   * TableScan(Table_A_prev)     TableScan(Table_B_prev)
+   * 
+ * In SQL view, the transformation is: + *

+ * {@code + * SELECT * FROM test.bar1 JOIN test.bar2 ON test.bar1.x = test.bar2.x} + *

+ * to + *

+ * {@code + * SELECT * FROM test.bar1_prev JOIN test.bar2_prev ON test.bar1_prev.x = test.bar2_prev.x} + * + *

+ * @param originalNode input RelNode to transform + *

+ */ + public RelNode convertRelPrev(RelNode originalNode) { + RelShuttle converter = new RelShuttleImpl() { + @Override + public RelNode visit(TableScan scan) { + RelOptTable originalTable = scan.getTable(); + List incrementalNames = new ArrayList<>(originalTable.getQualifiedName()); + String deltaTableName = incrementalNames.remove(incrementalNames.size() - 1) + PREV_SUFFIX; + incrementalNames.add(deltaTableName); + RelOptTable incrementalTable = + RelOptTableImpl.create(originalTable.getRelOptSchema(), originalTable.getRowType(), incrementalNames, null); + return LogicalTableScan.create(scan.getCluster(), incrementalTable); + } + + @Override + public RelNode visit(LogicalJoin join) { + RelNode left = join.getLeft(); + RelNode right = join.getRight(); + RelNode prevLeft = convertRelPrev(left); + RelNode prevRight = convertRelPrev(right); + RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + + LogicalProject p3 = createProjectOverJoin(join, prevLeft, prevRight, rexBuilder); + + return p3; + } + + @Override + public RelNode visit(LogicalFilter filter) { + RelNode transformedChild = convertRelPrev(filter.getInput()); + + return LogicalFilter.create(transformedChild, filter.getCondition()); + } + + @Override + public RelNode visit(LogicalProject project) { + RelNode transformedChild = convertRelPrev(project.getInput()); + return LogicalProject.create(transformedChild, project.getProjects(), project.getRowType()); + } + + @Override + public RelNode visit(LogicalUnion union) { + List children = union.getInputs(); + List transformedChildren = + children.stream().map(child -> convertRelPrev(child)).collect(Collectors.toList()); + return LogicalUnion.create(transformedChildren, union.all); + } + + @Override + public RelNode visit(LogicalAggregate aggregate) { + RelNode transformedChild = convertRelPrev(aggregate.getInput()); + return LogicalAggregate.create(transformedChild, aggregate.getGroupSet(), aggregate.getGroupSets(), + aggregate.getAggCallList()); + } + }; + return originalNode.accept(converter); + } + + /** + * Transforms the given relational algebra tree to ensure a uniform format that each LogicalJoin has a LogicalProject as its parent. + * This transformation involves: + * - For LogicalJoin nodes: recursively processing their children, and optionally creating a projection over the join + * if the join is in the needsProj set. (when the Join don't have a LogicalProject as its parent) + * - For other type RelNode: recursively processing its inputs, and using the transformed children as its new inputs. + *

+ * Here is an example of how the uniformFormat method works for a three-tables join query, (table names are enclosed in parentheses of the RelNodes): + *

+ * Input: + *

+   *               LogicalProject
+   *                     |
+   *               LogicalJoin
+   *                 /        \
+   *           LogicalJoin    TableScan(Table_C)
+   *           /          \
+   *  TableScan(Table_A)  TableScan(Table_B)
+   * 
+ * + * + * Output: + *
+   *               LogicalProject
+   *                     |
+   *               LogicalJoin
+   *                 /        \
+   *      LogicalProject(*)   TableScan(Table_C)
+   *                |
+   *          LogicalJoin
+   *          /          \
+   *  TableScan(Table_A)  TableScan(Table_B)
+   * 
+ * The LogicalProject(*) is added to ensure the uniform format. + * @param originalNode input RelNode to transform + */ + private RelNode uniformFormat(RelNode originalNode) { + RelShuttle converter = new RelShuttleImpl() { + + @Override + public RelNode visit(LogicalJoin join) { + RelNode left = join.getLeft(); + RelNode right = join.getRight(); + RelNode uniLeft = uniformFormat(left); + RelNode uniRight = uniformFormat(right); + RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + if (needsProj.contains(join)) { + LogicalProject p1 = createProjectOverJoin(join, uniLeft, uniRight, rexBuilder); + return p1; + } + return LogicalJoin.create(uniLeft, uniRight, join.getCondition(), join.getVariablesSet(), join.getJoinType()); + } + + @Override + public RelNode visit(LogicalFilter filter) { + RelNode transformedChild = uniformFormat(filter.getInput()); + + return LogicalFilter.create(transformedChild, filter.getCondition()); + } + + @Override + public RelNode visit(LogicalProject project) { + RelNode transformedChild = uniformFormat(project.getInput()); + return LogicalProject.create(transformedChild, project.getProjects(), project.getRowType()); + } + + @Override + public RelNode visit(LogicalUnion union) { + List children = union.getInputs(); + List transformedChildren = + children.stream().map(child -> uniformFormat(child)).collect(Collectors.toList()); + return LogicalUnion.create(transformedChildren, union.all); + } + + @Override + public RelNode visit(LogicalAggregate aggregate) { + RelNode transformedChild = uniformFormat(aggregate.getInput()); + return LogicalAggregate.create(transformedChild, aggregate.getGroupSet(), aggregate.getGroupSets(), + aggregate.getAggCallList()); + } + }; + return originalNode.accept(converter); + } + + /** + * Converts an input {@link RelNode} to its incremental version. This method traverses the input + * {@link RelNode} tree and transforms it by creating incremental version of the node. + * It also populates the {@code snapshotRelNodes} and {@code deltaRelNodes} collections, which + * are used to keep track of materialized views and incremental transformations, respectively. + * {@code snapshotRelNodes} and {@code deltaRelNodes} are used to generate the all incremental + * combinations of the input {@link RelNode}. + * + *

+ * The {@code snapshotRelNodes} collection is populated with the materialized versions of the + * nodes, which are intermediate snapshots of the original nodes. The {@code deltaRelNodes} + * collection is populated with the incremental versions of the nodes, which represent the + * deltas that need to be applied to the original nodes. + *

+ * + *

+ * The method specifically adds new elements to {@code snapshotRelNodes} and {@code deltaRelNodes} + * during the transformation of {@link LogicalProject} nodes: + *

+ *
    + *
  • {@code snapshotRelNodes}: This collection is updated with the original {@link LogicalProject} + * node or its materialized version, which is obtained from {@code getTempLastRelNode()} if available.
  • + *
  • {@code deltaRelNodes}: This collection is updated with the incremental version of the {@link LogicalProject} + * node, which is created by transforming the child node and preserving the project's expressions and row type.
  • + *
+ * + * @param originalNode the input {@link RelNode} to generate an incremental version for. + * @return the incremental version of the input {@link RelNode}. + */ + public RelNode convertRelIncremental(RelNode originalNode) { + RelShuttle converter = new RelShuttleImpl() { + @Override + public RelNode visit(TableScan scan) { + RelOptTable originalTable = scan.getTable(); + + // Set RelNodeIncrementalTransformer class relOptSchema if not already set + if (relOptSchema == null) { + relOptSchema = originalTable.getRelOptSchema(); + } + + // Create delta scan + List incrementalNames = new ArrayList<>(originalTable.getQualifiedName()); + String deltaTableName = incrementalNames.remove(incrementalNames.size() - 1) + DELTA_SUFFIX; + incrementalNames.add(deltaTableName); + RelOptTable incrementalTable = + RelOptTableImpl.create(originalTable.getRelOptSchema(), originalTable.getRowType(), incrementalNames, null); + return LogicalTableScan.create(scan.getCluster(), incrementalTable); + } + + @Override + public RelNode visit(LogicalJoin join) { + RelNode left = join.getLeft(); + RelNode right = join.getRight(); + RelNode incrementalLeft = convertRelIncremental(left); + RelNode incrementalRight = convertRelIncremental(right); + + RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + + // Check if we can replace the left and right nodes with a scan of a materialized table + String leftDescription = getDescriptionFromRelNode(left, false); + String leftIncrementalDescription = getDescriptionFromRelNode(left, true); + if (snapshotRelNodes.containsKey(leftDescription)) { + left = + susbstituteWithMaterializedView(getDeterministicDescriptionFromDescription(leftDescription, false), left); + incrementalLeft = susbstituteWithMaterializedView( + getDeterministicDescriptionFromDescription(leftIncrementalDescription, true), incrementalLeft); + } + String rightDescription = getDescriptionFromRelNode(right, false); + String rightIncrementalDescription = getDescriptionFromRelNode(right, true); + if (snapshotRelNodes.containsKey(rightDescription)) { + right = susbstituteWithMaterializedView(getDeterministicDescriptionFromDescription(rightDescription, false), + right); + incrementalRight = susbstituteWithMaterializedView( + getDeterministicDescriptionFromDescription(rightIncrementalDescription, true), incrementalRight); + } + RelNode prevLeft = convertRelPrev(left); + RelNode prevRight = convertRelPrev(right); + + // We need to do this in the join to get potentially updated left and right nodes + tempLastRelNode = createProjectOverJoin(join, left, right, rexBuilder); + + LogicalProject p1 = createProjectOverJoin(join, prevLeft, incrementalRight, rexBuilder); + LogicalProject p2 = createProjectOverJoin(join, incrementalLeft, prevRight, rexBuilder); + LogicalProject p3 = createProjectOverJoin(join, incrementalLeft, incrementalRight, rexBuilder); + + LogicalUnion unionAllJoins = + LogicalUnion.create(Arrays.asList(LogicalUnion.create(Arrays.asList(p1, p2), true), p3), true); + + return unionAllJoins; + } + + @Override + public RelNode visit(LogicalFilter filter) { + RelNode transformedChild = convertRelIncremental(filter.getInput()); + return LogicalFilter.create(transformedChild, filter.getCondition()); + } + + @Override + public RelNode visit(LogicalProject project) { + RelNode transformedChild = convertRelIncremental(project.getInput()); + RelNode materializedProject = getTempLastRelNode(); + if (materializedProject != null) { + snapshotRelNodes.put(getDescriptionFromRelNode(project, false), materializedProject); + } else { + snapshotRelNodes.put(getDescriptionFromRelNode(project, false), project); + } + LogicalProject transformedProject = + LogicalProject.create(transformedChild, project.getProjects(), project.getRowType()); + deltaRelNodes.put(getDescriptionFromRelNode(project, true), transformedProject); + return transformedProject; + } + + @Override + public RelNode visit(LogicalUnion union) { + List children = union.getInputs(); + List transformedChildren = + children.stream().map(child -> convertRelIncremental(child)).collect(Collectors.toList()); + return LogicalUnion.create(transformedChildren, union.all); + } + + @Override + public RelNode visit(LogicalAggregate aggregate) { + RelNode transformedChild = convertRelIncremental(aggregate.getInput()); + return LogicalAggregate.create(transformedChild, aggregate.getGroupSet(), aggregate.getGroupSets(), + aggregate.getAggCallList()); + } + }; + return originalNode.accept(converter); + } + + /** + * Returns the tempLastRelNode and sets the variable back to null. Should only be called once for each retrieval + * instance since subsequent consecutive calls will yield null. + */ + private RelNode getTempLastRelNode() { + RelNode currentTempLastRelNode = tempLastRelNode; + tempLastRelNode = null; + return currentTempLastRelNode; + } + + /** + * Returns the corresponding description for a given RelNode by extracting the identifier (ex. the identifier for + * LogicalProject#22 is 22) and prepending the TABLE_NAME_PREFIX. Depending on the delta value, a delta suffix may be + * appended. + * @param relNode RelNode from which the identifier will be retrieved. + * @param delta configure whether to get the delta name + */ + private String getDescriptionFromRelNode(RelNode relNode, boolean delta) { + String identifier = relNode.getDescription().split("#")[1]; + String description = TABLE_NAME_PREFIX + identifier; + if (delta) { + return description + DELTA_SUFFIX; + } + return description; + } + + /** + * Returns a description based on mapping index order that will stay the same across different runs of the same + * query. The description consists of the table prefix, the index, and optionally, the delta suffix. + * @param description output from calling getDescriptionFromRelNode() + * @param delta configure whether to get the delta name + */ + private String getDeterministicDescriptionFromDescription(String description, boolean delta) { + if (delta) { + List deltaKeyOrdering = new ArrayList<>(deltaRelNodes.keySet()); + return TABLE_NAME_PREFIX + deltaKeyOrdering.indexOf(description) + DELTA_SUFFIX; + } else { + List snapshotKeyOrdering = new ArrayList<>(snapshotRelNodes.keySet()); + return TABLE_NAME_PREFIX + snapshotKeyOrdering.indexOf(description); + } + } + + /** + * Accepts a table name and RelNode and creates a TableScan over the RelNode using the class relOptSchema. + * @param relOptTableName table name corresponding to table to scan over + * @param relNode top-level RelNode that will be replaced with the TableScan + */ + private TableScan susbstituteWithMaterializedView(String relOptTableName, RelNode relNode) { + RelOptTable table = + RelOptTableImpl.create(relOptSchema, relNode.getRowType(), Collections.singletonList(relOptTableName), null); + return LogicalTableScan.create(relNode.getCluster(), table); + } + + /** Creates a LogicalProject whose input is an incremental LogicalJoin node that is constructed from a left and right + * RelNode and LogicalJoin. + * @param join LogicalJoin to create the incremental join from + * @param left left RelNode child of the incremental join + * @param right right RelNode child of the incremental join + * @param rexBuilder RexBuilder for LogicalProject creation + */ + private LogicalProject createProjectOverJoin(LogicalJoin join, RelNode left, RelNode right, RexBuilder rexBuilder) { + LogicalJoin incrementalJoin = + LogicalJoin.create(left, right, join.getCondition(), join.getVariablesSet(), join.getJoinType()); + ArrayList projects = new ArrayList<>(); + ArrayList names = new ArrayList<>(); + IntStream.range(0, incrementalJoin.getRowType().getFieldList().size()).forEach(i -> { + projects.add(rexBuilder.makeInputRef(incrementalJoin, i)); + names.add(incrementalJoin.getRowType().getFieldNames().get(i)); + }); + return LogicalProject.create(incrementalJoin, projects, names); + } +} diff --git a/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelNodeGenerationTest.java b/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelNodeGenerationTest.java new file mode 100644 index 000000000..42980c0a7 --- /dev/null +++ b/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelNodeGenerationTest.java @@ -0,0 +1,199 @@ +/** + * Copyright 2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.incremental; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.sql.SqlNode; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import com.linkedin.coral.transformers.CoralRelToSqlNodeConverter; + +import static com.linkedin.coral.incremental.TestUtils.*; +import static org.testng.Assert.*; + + +public class RelNodeGenerationTest { + private HiveConf conf; + + @BeforeClass + public void beforeClass() throws HiveException, MetaException, IOException { + conf = TestUtils.loadResourceHiveConf(); + TestUtils.initializeViews(conf); + } + + @AfterTest + public void afterClass() throws IOException { + FileUtils.deleteDirectory(new File(conf.get(CORAL_INCREMENTAL_TEST_DIR))); + } + + public String convert(RelNode relNode) { + CoralRelToSqlNodeConverter converter = new CoralRelToSqlNodeConverter(); + SqlNode sqlNode = converter.convert(relNode); + return sqlNode.toSqlString(converter.INSTANCE).getSql(); + } + + public void checkAllPlans(String sql, List> expected) { + List> plans = getAllPlans(sql); + assertEquals(plans.size(), expected.size()); + for (int i = 0; i < plans.size(); i++) { + List plan = plans.get(i); + List expectedPlan = expected.get(i); + assertEquals(plan.size(), expectedPlan.size()); + for (int j = 0; j < plan.size(); j++) { + String actual = convert(plan.get(j)); + assertEquals(actual, expectedPlan.get(j)); + } + } + } + + public List> getAllPlans(String sql) { + RelNode originalRelNode = hiveToRelConverter.convertSql(sql); + RelNodeGenerationTransformer transformer = new RelNodeGenerationTransformer(); + return transformer.generateIncrementalRelNodes(originalRelNode); + } + + @Test + public void testSimpleJoinPrev() { + String sql = "SELECT * FROM test.bar1 JOIN test.bar2 ON test.bar1.x = test.bar2.x"; + RelNodeGenerationTransformer transformer = new RelNodeGenerationTransformer(); + RelNode originalRelNode = hiveToRelConverter.convertSql(sql); + RelNode prev = transformer.convertRelPrev(originalRelNode); + String prevSql = convert(prev); + String expected = "SELECT *\n" + "FROM test.bar1_prev AS bar1_prev\n" + + "INNER JOIN test.bar2_prev AS bar2_prev ON bar1_prev.x = bar2_prev.x"; + assertEquals(prevSql, expected); + } + + @Test + public void testSimpleSelectAll() { + String sql = "SELECT * FROM test.foo"; + List incremental = Arrays.asList("SELECT *\n" + "FROM test.foo_delta AS foo_delta"); + List batch = Arrays.asList("SELECT *\n" + "FROM test.foo AS foo"); + List> expected = Arrays.asList(incremental, batch); + checkAllPlans(sql, expected); + } + + @Test + public void testSimpleJoin() { + String sql = "SELECT * FROM test.bar1 JOIN test.bar2 ON test.bar1.x = test.bar2.x"; + String incrementalSql = "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1_prev AS bar1_prev\n" + + "INNER JOIN test.bar2_delta AS bar2_delta ON bar1_prev.x = bar2_delta.x\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM test.bar1_delta AS bar1_delta\n" + + "INNER JOIN test.bar2_prev AS bar2_prev ON bar1_delta.x = bar2_prev.x) AS t\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM test.bar1_delta AS bar1_delta0\n" + + "INNER JOIN test.bar2_delta AS bar2_delta0 ON bar1_delta0.x = bar2_delta0.x"; + List incremental = Arrays.asList(incrementalSql); + List batch = + Arrays.asList("SELECT *\n" + "FROM test.bar1 AS bar1\n" + "INNER JOIN test.bar2 AS bar2 ON bar1.x = bar2.x"); + List> expected = Arrays.asList(incremental, batch); + checkAllPlans(sql, expected); + } + + @Test + public void testNestedJoin() { + String nestedJoin = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1"; + String sql = "SELECT a2, g1 FROM (" + nestedJoin + ") AS nj JOIN test.gamma ON nj.a2 = test.gamma.g2"; + String Table0_delta = + "SELECT t0.a1, t0.a2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.alpha_prev AS alpha_prev\n" + + "INNER JOIN test.beta_delta AS beta_delta ON alpha_prev.a1 = beta_delta.b1\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM test.alpha_delta AS alpha_delta\n" + + "INNER JOIN test.beta_prev AS beta_prev ON alpha_delta.a1 = beta_prev.b1) AS t\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM test.alpha_delta AS alpha_delta0\n" + + "INNER JOIN test.beta_delta AS beta_delta0 ON alpha_delta0.a1 = beta_delta0.b1) AS t0"; + String Table1_delta = + "SELECT t0.a2, t0.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0_prev AS Table#0_prev\n" + + "INNER JOIN test.gamma_delta AS gamma_delta ON Table#0_prev.a2 = gamma_delta.g2\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta\n" + + "INNER JOIN test.gamma_prev AS gamma_prev ON Table#0_delta.a2 = gamma_prev.g2) AS t\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta0\n" + + "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t0"; + String Table0 = "SELECT *\n" + "FROM test.alpha AS alpha\n" + "INNER JOIN test.beta AS beta ON alpha.a1 = beta.b1"; + String Table1 = + "SELECT *\n" + "FROM Table#0 AS Table#0\n" + "INNER JOIN test.gamma AS gamma ON Table#0.a2 = gamma.g2"; + List combined = Arrays.asList(Table0_delta, Table1); + List incremental = Arrays.asList(Table0_delta, Table1_delta); + List batch = Arrays.asList(Table0, Table1); + List> expected = Arrays.asList(batch, combined, incremental); + checkAllPlans(sql, expected); + } + + @Test + public void testThreeTablesJoin() { + String sql = + "SELECT a1, a2, g1 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1 JOIN test.gamma ON test.alpha.a2 = test.gamma.g2"; + String Table0_delta = "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.alpha_prev AS alpha_prev\n" + + "INNER JOIN test.beta_delta AS beta_delta ON alpha_prev.a1 = beta_delta.b1\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM test.alpha_delta AS alpha_delta\n" + + "INNER JOIN test.beta_prev AS beta_prev ON alpha_delta.a1 = beta_prev.b1) AS t\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM test.alpha_delta AS alpha_delta0\n" + + "INNER JOIN test.beta_delta AS beta_delta0 ON alpha_delta0.a1 = beta_delta0.b1"; + String Table1_delta = + "SELECT t0.a1, t0.a2, t0.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0_prev AS Table#0_prev\n" + + "INNER JOIN test.gamma_delta AS gamma_delta ON Table#0_prev.a2 = gamma_delta.g2\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta\n" + + "INNER JOIN test.gamma_prev AS gamma_prev ON Table#0_delta.a2 = gamma_prev.g2) AS t\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta0\n" + + "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t0"; + String Table0 = "SELECT *\n" + "FROM test.alpha AS alpha\n" + "INNER JOIN test.beta AS beta ON alpha.a1 = beta.b1"; + String Table1 = + "SELECT *\n" + "FROM Table#0 AS Table#0\n" + "INNER JOIN test.gamma AS gamma ON Table#0.a2 = gamma.g2"; + List combined = Arrays.asList(Table0_delta, Table1); + List incremental = Arrays.asList(Table0_delta, Table1_delta); + List batch = Arrays.asList(Table0, Table1); + List> expected = Arrays.asList(batch, combined, incremental); + checkAllPlans(sql, expected); + } + + @Test + public void testFourTablesJoin() { + String nestedJoin1 = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1"; + String nestedJoin2 = "SELECT a2, g1 FROM (" + nestedJoin1 + ") AS nj1 JOIN test.gamma ON nj1.a2 = test.gamma.g2"; + String sql = "SELECT g1, e2 FROM (" + nestedJoin2 + ") AS nj2 JOIN test.epsilon ON nj2.g1 = test.epsilon.e1"; + String Table0_delta = + "SELECT t0.a1, t0.a2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.alpha_prev AS alpha_prev\n" + + "INNER JOIN test.beta_delta AS beta_delta ON alpha_prev.a1 = beta_delta.b1\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM test.alpha_delta AS alpha_delta\n" + + "INNER JOIN test.beta_prev AS beta_prev ON alpha_delta.a1 = beta_prev.b1) AS t\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM test.alpha_delta AS alpha_delta0\n" + + "INNER JOIN test.beta_delta AS beta_delta0 ON alpha_delta0.a1 = beta_delta0.b1) AS t0"; + String Table1_delta = + "SELECT t0.a2, t0.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0_prev AS Table#0_prev\n" + + "INNER JOIN test.gamma_delta AS gamma_delta ON Table#0_prev.a2 = gamma_delta.g2\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta\n" + + "INNER JOIN test.gamma_prev AS gamma_prev ON Table#0_delta.a2 = gamma_prev.g2) AS t\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta0\n" + + "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t0"; + String Table2_delta = + "SELECT t0.g1, t0.e2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#1_prev AS Table#1_prev\n" + + "INNER JOIN test.epsilon_delta AS epsilon_delta ON Table#1_prev.g1 = epsilon_delta.e1\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#1_delta AS Table#1_delta\n" + + "INNER JOIN test.epsilon_prev AS epsilon_prev ON Table#1_delta.g1 = epsilon_prev.e1) AS t\n" + + "UNION ALL\n" + "SELECT *\n" + "FROM Table#1_delta AS Table#1_delta0\n" + + "INNER JOIN test.epsilon_delta AS epsilon_delta0 ON Table#1_delta0.g1 = epsilon_delta0.e1) AS t0"; + String Table0 = "SELECT *\n" + "FROM test.alpha AS alpha\n" + "INNER JOIN test.beta AS beta ON alpha.a1 = beta.b1"; + String Table1 = + "SELECT *\n" + "FROM Table#0 AS Table#0\n" + "INNER JOIN test.gamma AS gamma ON Table#0.a2 = gamma.g2"; + String Table2 = + "SELECT *\n" + "FROM Table#1 AS Table#1\n" + "INNER JOIN test.epsilon AS epsilon ON Table#1.g1 = epsilon.e1"; + List combined1 = Arrays.asList(Table0_delta, Table1, Table2); + List combined2 = Arrays.asList(Table0_delta, Table1_delta, Table2); + List incremental = Arrays.asList(Table0_delta, Table1_delta, Table2_delta); + List batch = Arrays.asList(Table0, Table1, Table2); + List> expected = Arrays.asList(batch, combined1, combined2, incremental); + checkAllPlans(sql, expected); + } +} diff --git a/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java b/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java index 232705ed4..4c0c12faa 100644 --- a/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java +++ b/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ @@ -60,6 +60,11 @@ public static void initializeViews(HiveConf conf) throws HiveException, MetaExce run(driver, "CREATE TABLE IF NOT EXISTS test.bar1(x int, y double)"); run(driver, "CREATE TABLE IF NOT EXISTS test.bar2(x int, y double)"); run(driver, "CREATE TABLE IF NOT EXISTS test.bar3(x int, y double)"); + + run(driver, "CREATE TABLE IF NOT EXISTS test.alpha(a1 int, a2 double)"); + run(driver, "CREATE TABLE IF NOT EXISTS test.beta(b1 int, b2 double)"); + run(driver, "CREATE TABLE IF NOT EXISTS test.gamma(g1 int, g2 double)"); + run(driver, "CREATE TABLE IF NOT EXISTS test.epsilon(e1 int, e2 double)"); } public static HiveConf loadResourceHiveConf() { diff --git a/coral-service/frontend/env.local b/coral-service/frontend/env.local new file mode 100644 index 000000000..89293e582 --- /dev/null +++ b/coral-service/frontend/env.local @@ -0,0 +1,2 @@ +# Base URL of the Coral Service API, default is http://localhost:8080 +NEXT_PUBLIC_CORAL_SERVICE_API_URL="http://localhost:8080"