diff --git a/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeIncrementalTransformer.java b/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeIncrementalTransformer.java index 5b59c11d1..9681166ae 100644 --- a/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeIncrementalTransformer.java +++ b/coral-incremental/src/main/java/com/linkedin/coral/incremental/RelNodeIncrementalTransformer.java @@ -7,10 +7,14 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.calcite.plan.RelOptSchema; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.prepare.RelOptTableImpl; import org.apache.calcite.rel.RelNode; @@ -29,16 +33,63 @@ public class RelNodeIncrementalTransformer { - private RelNodeIncrementalTransformer() { + private final String TABLE_NAME_PREFIX = "Table#"; + private final String DELTA_SUFFIX = "_delta"; + + private RelOptSchema relOptSchema; + private Map snapshotRelNodes; + private Map deltaRelNodes; + private RelNode tempLastRelNode; + + public RelNodeIncrementalTransformer() { + relOptSchema = null; + snapshotRelNodes = new LinkedHashMap<>(); + deltaRelNodes = new LinkedHashMap<>(); + tempLastRelNode = null; } - public static RelNode convertRelIncremental(RelNode originalNode) { + /** + * Returns snapshotRelNodes with deterministic keys. + */ + public Map getSnapshotRelNodes() { + Map deterministicSnapshotRelNodes = new LinkedHashMap<>(); + for (String description : snapshotRelNodes.keySet()) { + deterministicSnapshotRelNodes.put(getDeterministicDescriptionFromDescription(description, false), + snapshotRelNodes.get(description)); + } + return deterministicSnapshotRelNodes; + } + + /** + * Returns deltaRelNodes with deterministic keys. + */ + public Map getDeltaRelNodes() { + Map deterministicDeltaRelNodes = new LinkedHashMap<>(); + for (String description : deltaRelNodes.keySet()) { + deterministicDeltaRelNodes.put(getDeterministicDescriptionFromDescription(description, true), + deltaRelNodes.get(description)); + } + return deterministicDeltaRelNodes; + } + + /** + * Convert an input RelNode to an incremental RelNode. Populates snapshotRelNodes and deltaRelNodes. + * @param originalNode input RelNode to generate an incremental version for. + */ + public RelNode convertRelIncremental(RelNode originalNode) { RelShuttle converter = new RelShuttleImpl() { @Override public RelNode visit(TableScan scan) { RelOptTable originalTable = scan.getTable(); + + // Set RelNodeIncrementalTransformer class relOptSchema if not already set + if (relOptSchema == null) { + relOptSchema = originalTable.getRelOptSchema(); + } + + // Create delta scan List incrementalNames = new ArrayList<>(originalTable.getQualifiedName()); - String deltaTableName = incrementalNames.remove(incrementalNames.size() - 1) + "_delta"; + String deltaTableName = incrementalNames.remove(incrementalNames.size() - 1) + DELTA_SUFFIX; incrementalNames.add(deltaTableName); RelOptTable incrementalTable = RelOptTableImpl.create(originalTable.getRelOptSchema(), originalTable.getRowType(), incrementalNames, null); @@ -54,12 +105,34 @@ public RelNode visit(LogicalJoin join) { RexBuilder rexBuilder = join.getCluster().getRexBuilder(); + // Check if we can replace the left and right nodes with a scan of a materialized table + String leftDescription = getDescriptionFromRelNode(left, false); + String leftIncrementalDescription = getDescriptionFromRelNode(left, true); + if (snapshotRelNodes.containsKey(leftDescription)) { + left = + susbstituteWithMaterializedView(getDeterministicDescriptionFromDescription(leftDescription, false), left); + incrementalLeft = susbstituteWithMaterializedView( + getDeterministicDescriptionFromDescription(leftIncrementalDescription, true), incrementalLeft); + } + String rightDescription = getDescriptionFromRelNode(right, false); + String rightIncrementalDescription = getDescriptionFromRelNode(right, true); + if (snapshotRelNodes.containsKey(rightDescription)) { + right = susbstituteWithMaterializedView(getDeterministicDescriptionFromDescription(rightDescription, false), + right); + incrementalRight = susbstituteWithMaterializedView( + getDeterministicDescriptionFromDescription(rightIncrementalDescription, true), incrementalRight); + } + + // We need to do this in the join to get potentially updated left and right nodes + tempLastRelNode = createProjectOverJoin(join, left, right, rexBuilder); + LogicalProject p1 = createProjectOverJoin(join, left, incrementalRight, rexBuilder); LogicalProject p2 = createProjectOverJoin(join, incrementalLeft, right, rexBuilder); LogicalProject p3 = createProjectOverJoin(join, incrementalLeft, incrementalRight, rexBuilder); LogicalUnion unionAllJoins = LogicalUnion.create(Arrays.asList(LogicalUnion.create(Arrays.asList(p1, p2), true), p3), true); + return unionAllJoins; } @@ -72,7 +145,16 @@ public RelNode visit(LogicalFilter filter) { @Override public RelNode visit(LogicalProject project) { RelNode transformedChild = convertRelIncremental(project.getInput()); - return LogicalProject.create(transformedChild, project.getProjects(), project.getRowType()); + RelNode materializedProject = getTempLastRelNode(); + if (materializedProject != null) { + snapshotRelNodes.put(getDescriptionFromRelNode(project, false), materializedProject); + } else { + snapshotRelNodes.put(getDescriptionFromRelNode(project, false), project); + } + LogicalProject transformedProject = + LogicalProject.create(transformedChild, project.getProjects(), project.getRowType()); + deltaRelNodes.put(getDescriptionFromRelNode(project, true), transformedProject); + return transformedProject; } @Override @@ -93,8 +175,67 @@ public RelNode visit(LogicalAggregate aggregate) { return originalNode.accept(converter); } - private static LogicalProject createProjectOverJoin(LogicalJoin join, RelNode left, RelNode right, - RexBuilder rexBuilder) { + /** + * Returns the tempLastRelNode and sets the variable back to null. Should only be called once for each retrieval + * instance since subsequent consecutive calls will yield null. + */ + private RelNode getTempLastRelNode() { + RelNode currentTempLastRelNode = tempLastRelNode; + tempLastRelNode = null; + return currentTempLastRelNode; + } + + /** + * Returns the corresponding description for a given RelNode by extracting the identifier (ex. the identifier for + * LogicalProject#22 is 22) and prepending the TABLE_NAME_PREFIX. Depending on the delta value, a delta suffix may be + * appended. + * @param relNode RelNode from which the identifier will be retrieved. + * @param delta configure whether to get the delta name + */ + private String getDescriptionFromRelNode(RelNode relNode, boolean delta) { + String identifier = relNode.getDescription().split("#")[1]; + String description = TABLE_NAME_PREFIX + identifier; + if (delta) { + return description + DELTA_SUFFIX; + } + return description; + } + + /** + * Returns a description based on mapping index order that will stay the same across different runs of the same + * query. The description consists of the table prefix, the index, and optionally, the delta suffix. + * @param description output from calling getDescriptionFromRelNode() + * @param delta configure whether to get the delta name + */ + private String getDeterministicDescriptionFromDescription(String description, boolean delta) { + if (delta) { + List deltaKeyOrdering = new ArrayList<>(deltaRelNodes.keySet()); + return TABLE_NAME_PREFIX + deltaKeyOrdering.indexOf(description) + DELTA_SUFFIX; + } else { + List snapshotKeyOrdering = new ArrayList<>(snapshotRelNodes.keySet()); + return TABLE_NAME_PREFIX + snapshotKeyOrdering.indexOf(description); + } + } + + /** + * Accepts a table name and RelNode and creates a TableScan over the RelNode using the class relOptSchema. + * @param relOptTableName table name corresponding to table to scan over + * @param relNode top-level RelNode that will be replaced with the TableScan + */ + private TableScan susbstituteWithMaterializedView(String relOptTableName, RelNode relNode) { + RelOptTable table = + RelOptTableImpl.create(relOptSchema, relNode.getRowType(), Collections.singletonList(relOptTableName), null); + return LogicalTableScan.create(relNode.getCluster(), table); + } + + /** Creates a LogicalProject whose input is an incremental LogicalJoin node that is constructed from a left and right + * RelNode and LogicalJoin. + * @param join LogicalJoin to create the incremental join from + * @param left left RelNode child of the incremental join + * @param right right RelNode child of the incremental join + * @param rexBuilder RexBuilder for LogicalProject creation + */ + private LogicalProject createProjectOverJoin(LogicalJoin join, RelNode left, RelNode right, RexBuilder rexBuilder) { LogicalJoin incrementalJoin = LogicalJoin.create(left, right, join.getCondition(), join.getVariablesSet(), join.getJoinType()); ArrayList projects = new ArrayList<>(); diff --git a/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelToIncrementalSqlConverterTest.java b/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelToIncrementalSqlConverterTest.java index 3ac0cd683..2c5262ce5 100644 --- a/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelToIncrementalSqlConverterTest.java +++ b/coral-incremental/src/test/java/com/linkedin/coral/incremental/RelToIncrementalSqlConverterTest.java @@ -7,6 +7,8 @@ import java.io.File; import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; import org.apache.calcite.rel.RelNode; import org.apache.calcite.sql.SqlNode; @@ -41,7 +43,8 @@ public void afterClass() throws IOException { } public String convert(RelNode relNode) { - RelNode incrementalRelNode = RelNodeIncrementalTransformer.convertRelIncremental(relNode); + RelNodeIncrementalTransformer transformer = new RelNodeIncrementalTransformer(); + RelNode incrementalRelNode = transformer.convertRelIncremental(relNode); CoralRelToSqlNodeConverter converter = new CoralRelToSqlNodeConverter(); SqlNode sqlNode = converter.convert(incrementalRelNode); return sqlNode.toSqlString(converter.INSTANCE).getSql(); @@ -52,6 +55,28 @@ public String getIncrementalModification(String sql) { return convert(originalRelNode); } + public void checkAllSnapshotAndDeltaQueries(String sql, Map snapshotExpected, + Map deltaExpected) { + RelNode originalRelNode = hiveToRelConverter.convertSql(sql); + CoralRelToSqlNodeConverter converter = new CoralRelToSqlNodeConverter(); + RelNodeIncrementalTransformer transformer = new RelNodeIncrementalTransformer(); + transformer.convertRelIncremental(originalRelNode); + Map snapshotRelNodes = transformer.getSnapshotRelNodes(); + Map deltaRelNodes = transformer.getDeltaRelNodes(); + for (String key : snapshotRelNodes.keySet()) { + RelNode actualSnapshotRelNode = snapshotRelNodes.get(key); + SqlNode sqlNode = converter.convert(actualSnapshotRelNode); + String actualSql = sqlNode.toSqlString(converter.INSTANCE).getSql(); + assertEquals(actualSql, snapshotExpected.get(key)); + } + for (String key : deltaRelNodes.keySet()) { + RelNode actualDeltaRelNode = deltaRelNodes.get(key); + SqlNode sqlNode = converter.convert(actualDeltaRelNode); + String actualSql = sqlNode.toSqlString(converter.INSTANCE).getSql(); + assertEquals(actualSql, deltaExpected.get(key)); + } + } + @Test public void testSimpleSelectAll() { String sql = "SELECT * FROM test.foo"; @@ -81,41 +106,6 @@ public void testJoinWithFilter() { assertEquals(getIncrementalModification(sql), expected); } - @Test - public void testJoinWithNestedFilter() { - String sql = - "WITH tmp AS (SELECT * from test.bar1 WHERE test.bar1.x > 10), tmp2 AS (SELECT * from test.bar2) SELECT * FROM tmp JOIN tmp2 ON tmp.x = tmp2.x"; - String expected = "SELECT *\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar1\n" - + "WHERE bar1.x > 10) AS t\n" + "INNER JOIN test.bar2_delta AS bar2_delta ON t.x = bar2_delta.x\n" - + "UNION ALL\n" + "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1_delta AS bar1_delta\n" - + "WHERE bar1_delta.x > 10) AS t0\n" + "INNER JOIN test.bar2 AS bar2 ON t0.x = bar2.x) AS t1\n" + "UNION ALL\n" - + "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1_delta AS bar1_delta0\n" - + "WHERE bar1_delta0.x > 10) AS t2\n" + "INNER JOIN test.bar2_delta AS bar2_delta0 ON t2.x = bar2_delta0.x"; - assertEquals(getIncrementalModification(sql), expected); - } - - @Test - public void testNestedJoin() { - String sql = - "WITH tmp AS (SELECT * FROM test.bar1 INNER JOIN test.bar2 ON test.bar1.x = test.bar2.x) SELECT * FROM tmp INNER JOIN test.bar3 ON tmp.x = test.bar3.x"; - String expected = "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar1\n" - + "INNER JOIN test.bar2 AS bar2 ON bar1.x = bar2.x\n" - + "INNER JOIN test.bar3_delta AS bar3_delta ON bar1.x = bar3_delta.x\n" + "UNION ALL\n" + "SELECT *\n" - + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar10\n" - + "INNER JOIN test.bar2_delta AS bar2_delta ON bar10.x = bar2_delta.x\n" + "UNION ALL\n" + "SELECT *\n" - + "FROM test.bar1_delta AS bar1_delta\n" + "INNER JOIN test.bar2 AS bar20 ON bar1_delta.x = bar20.x) AS t\n" - + "UNION ALL\n" + "SELECT *\n" + "FROM test.bar1_delta AS bar1_delta0\n" - + "INNER JOIN test.bar2_delta AS bar2_delta0 ON bar1_delta0.x = bar2_delta0.x) AS t0\n" - + "INNER JOIN test.bar3 AS bar3 ON t0.x = bar3.x) AS t1\n" + "UNION ALL\n" + "SELECT *\n" + "FROM (SELECT *\n" - + "FROM (SELECT *\n" + "FROM test.bar1 AS bar11\n" - + "INNER JOIN test.bar2_delta AS bar2_delta1 ON bar11.x = bar2_delta1.x\n" + "UNION ALL\n" + "SELECT *\n" - + "FROM test.bar1_delta AS bar1_delta1\n" + "INNER JOIN test.bar2 AS bar21 ON bar1_delta1.x = bar21.x) AS t2\n" - + "UNION ALL\n" + "SELECT *\n" + "FROM test.bar1_delta AS bar1_delta2\n" - + "INNER JOIN test.bar2_delta AS bar2_delta2 ON bar1_delta2.x = bar2_delta2.x) AS t3\n" - + "INNER JOIN test.bar3_delta AS bar3_delta0 ON t3.x = bar3_delta0.x"; - assertEquals(getIncrementalModification(sql), expected); - } - @Test public void testUnion() { String sql = "SELECT * FROM test.bar1 UNION SELECT * FROM test.bar2 UNION SELECT * FROM test.bar3"; @@ -143,4 +133,68 @@ public void testSelectSpecificJoin() { + "INNER JOIN test.bar2_delta AS bar2_delta0 ON bar1_delta0.x = bar2_delta0.x) AS t0"; assertEquals(getIncrementalModification(sql), expected); } + + @Test + public void testNestedJoin() { + String nestedJoin = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1"; + String sql = "SELECT a2, g1 FROM (" + nestedJoin + ") AS nj JOIN test.gamma ON nj.a2 = test.gamma.g2"; + Map snapshotExpected = new LinkedHashMap<>(); + snapshotExpected.put("Table#0", + "SELECT *\n" + "FROM test.alpha AS alpha\n" + "INNER JOIN test.beta AS beta ON alpha.a1 = beta.b1"); + snapshotExpected.put("Table#1", + "SELECT *\n" + "FROM Table#0 AS Table#0\n" + "INNER JOIN test.gamma AS gamma ON Table#0.a2 = gamma.g2"); + Map deltaExpected = new LinkedHashMap<>(); + deltaExpected.put("Table#0_delta", + "SELECT t0.a1, t0.a2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.alpha AS alpha0\n" + + "INNER JOIN test.beta_delta AS beta_delta ON alpha0.a1 = beta_delta.b1\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM test.alpha_delta AS alpha_delta\n" + + "INNER JOIN test.beta AS beta0 ON alpha_delta.a1 = beta0.b1) AS t\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM test.alpha_delta AS alpha_delta0\n" + + "INNER JOIN test.beta_delta AS beta_delta0 ON alpha_delta0.a1 = beta_delta0.b1) AS t0"); + deltaExpected.put("Table#1_delta", + "SELECT t3.a2, t3.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0 AS Table#00\n" + + "INNER JOIN test.gamma_delta AS gamma_delta ON Table#00.a2 = gamma_delta.g2\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta\n" + + "INNER JOIN test.gamma AS gamma0 ON Table#0_delta.a2 = gamma0.g2) AS t2\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM Table#0_delta AS Table#0_delta0\n" + + "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t3"); + checkAllSnapshotAndDeltaQueries(sql, snapshotExpected, deltaExpected); + } + + @Test + public void testThreeNestedJoins() { + String nestedJoin1 = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1"; + String nestedJoin2 = "SELECT a2, g1 FROM (" + nestedJoin1 + ") AS nj1 JOIN test.gamma ON nj1.a2 = test.gamma.g2"; + String sql = "SELECT g1, e2 FROM (" + nestedJoin2 + ") AS nj2 JOIN test.epsilon ON nj2.g1 = test.epsilon.e1"; + Map snapshotExpected = new LinkedHashMap<>(); + snapshotExpected.put("Table#0", + "SELECT *\n" + "FROM test.alpha AS alpha\n" + "INNER JOIN test.beta AS beta ON alpha.a1 = beta.b1"); + snapshotExpected.put("Table#1", + "SELECT *\n" + "FROM Table#0 AS Table#0\n" + "INNER JOIN test.gamma AS gamma ON Table#0.a2 = gamma.g2"); + snapshotExpected.put("Table#2", + "SELECT *\n" + "FROM Table#1 AS Table#1\n" + "INNER JOIN test.epsilon AS epsilon ON Table#1.g1 = epsilon.e1"); + Map deltaExpected = new LinkedHashMap<>(); + deltaExpected.put("Table#0_delta", + "SELECT t0.a1, t0.a2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.alpha AS alpha0\n" + + "INNER JOIN test.beta_delta AS beta_delta ON alpha0.a1 = beta_delta.b1\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM test.alpha_delta AS alpha_delta\n" + + "INNER JOIN test.beta AS beta0 ON alpha_delta.a1 = beta0.b1) AS t\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM test.alpha_delta AS alpha_delta0\n" + + "INNER JOIN test.beta_delta AS beta_delta0 ON alpha_delta0.a1 = beta_delta0.b1) AS t0"); + deltaExpected.put("Table#1_delta", + "SELECT t3.a2, t3.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0 AS Table#00\n" + + "INNER JOIN test.gamma_delta AS gamma_delta ON Table#00.a2 = gamma_delta.g2\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta\n" + + "INNER JOIN test.gamma AS gamma0 ON Table#0_delta.a2 = gamma0.g2) AS t2\n" + "UNION ALL\n" + "SELECT *\n" + + "FROM Table#0_delta AS Table#0_delta0\n" + + "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t3"); + deltaExpected.put("Table#2_delta", + "SELECT t6.g1, t6.e2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#1 AS Table#10\n" + + "INNER JOIN test.epsilon_delta AS epsilon_delta ON Table#10.g1 = epsilon_delta.e1\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#1_delta AS Table#1_delta\n" + + "INNER JOIN test.epsilon AS epsilon0 ON Table#1_delta.g1 = epsilon0.e1) AS t5\n" + "UNION ALL\n" + + "SELECT *\n" + "FROM Table#1_delta AS Table#1_delta0\n" + + "INNER JOIN test.epsilon_delta AS epsilon_delta0 ON Table#1_delta0.g1 = epsilon_delta0.e1) AS t6"); + checkAllSnapshotAndDeltaQueries(sql, snapshotExpected, deltaExpected); + } } diff --git a/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java b/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java index 232705ed4..7384be9a8 100644 --- a/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java +++ b/coral-incremental/src/test/java/com/linkedin/coral/incremental/TestUtils.java @@ -60,6 +60,11 @@ public static void initializeViews(HiveConf conf) throws HiveException, MetaExce run(driver, "CREATE TABLE IF NOT EXISTS test.bar1(x int, y double)"); run(driver, "CREATE TABLE IF NOT EXISTS test.bar2(x int, y double)"); run(driver, "CREATE TABLE IF NOT EXISTS test.bar3(x int, y double)"); + + run(driver, "CREATE TABLE IF NOT EXISTS test.alpha(a1 int, a2 double)"); + run(driver, "CREATE TABLE IF NOT EXISTS test.beta(b1 int, b2 double)"); + run(driver, "CREATE TABLE IF NOT EXISTS test.gamma(g1 int, g2 double)"); + run(driver, "CREATE TABLE IF NOT EXISTS test.epsilon(e1 int, e2 double)"); } public static HiveConf loadResourceHiveConf() { diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/IncrementalUtils.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/IncrementalUtils.java index 33fbc9023..055066ff6 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/IncrementalUtils.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/IncrementalUtils.java @@ -18,7 +18,8 @@ public class IncrementalUtils { public static String getSparkIncrementalQueryFromUserSql(String query) { RelNode originalNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query); - RelNode incrementalRelNode = RelNodeIncrementalTransformer.convertRelIncremental(originalNode); + RelNodeIncrementalTransformer transformer = new RelNodeIncrementalTransformer(); + RelNode incrementalRelNode = transformer.convertRelIncremental(originalNode); CoralSpark coralSpark = CoralSpark.create(incrementalRelNode); return coralSpark.getSparkSql(); }