Skip to content

Commit 1916ef9

Browse files
committed
create PruneShuffleAndSort physical rule
1 parent 88fc8db commit 1916ef9

File tree

4 files changed

+87
-9
lines changed

4 files changed

+87
-9
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
3434
import org.apache.spark.sql.catalyst.util.truncatedString
3535
import org.apache.spark.sql.dynamicpruning.PlanDynamicPruningFilters
3636
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
37-
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
37+
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, PruneShuffleAndSort, ReuseExchange}
3838
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
3939
import org.apache.spark.sql.internal.SQLConf
4040
import org.apache.spark.sql.streaming.OutputMode
@@ -279,6 +279,7 @@ object QueryExecution {
279279
PlanDynamicPruningFilters(sparkSession),
280280
PlanSubqueries(sparkSession),
281281
EnsureRequirements(sparkSession.sessionState.conf),
282+
PruneShuffleAndSort(),
282283
ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
283284
sparkSession.sessionState.columnarRules),
284285
CollapseCodegenStages(sparkSession.sessionState.conf),

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
217217
}
218218

219219
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
220-
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
221-
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
222-
child.outputPartitioning match {
223-
case lower: HashPartitioning if upper.semanticEquals(lower) => child
224-
case _ => operator
225-
}
226220
case operator: SparkPlan =>
227221
ensureDistributionAndOrdering(reorderJoinPredicates(operator))
228222
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.exchange
19+
20+
import org.apache.spark.sql.catalyst.expressions.SortOrder
21+
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
22+
import org.apache.spark.sql.catalyst.rules.Rule
23+
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
24+
25+
case class PruneShuffleAndSort() extends Rule[SparkPlan] {
26+
27+
override def apply(plan: SparkPlan): SparkPlan = {
28+
plan.transformUp {
29+
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
30+
child.outputPartitioning match {
31+
case lower: HashPartitioning if upper.semanticEquals(lower) => child
32+
case _ @ PartitioningCollection(partitionings) =>
33+
if (partitionings.exists{case lower: HashPartitioning =>
34+
upper.semanticEquals(lower)
35+
}) {
36+
child
37+
} else {
38+
operator
39+
}
40+
case _ => operator
41+
}
42+
case SortExec(upper, false, child, _)
43+
if SortOrder.orderingSatisfies(child.outputOrdering, upper) => child
44+
case subPlan: SparkPlan => subPlan
45+
}
46+
}
47+
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan,
2626
import org.apache.spark.sql.catalyst.plans.physical._
2727
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
2828
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
29-
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
29+
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, PruneShuffleAndSort, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
3030
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
3131
import org.apache.spark.sql.functions._
3232
import org.apache.spark.sql.internal.SQLConf
@@ -482,7 +482,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
482482
val inputPlan = ShuffleExchangeExec(
483483
partitioning,
484484
DummySparkPlan(outputPartitioning = partitioning))
485-
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
485+
val outputPlan = PruneShuffleAndSort().apply(inputPlan)
486486
assertDistributionRequirementsAreSatisfied(outputPlan)
487487
if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {
488488
fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
@@ -775,6 +775,42 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
775775
}
776776
}
777777

778+
test("SPARK-28148: repartition after join is not optimized away") {
  val smallRange = spark.range(0, 5000000, 1, 5)
  val largeRange = spark.range(0, 10000000, 1, 5)

  // Left join, then repartition and per-partition sort on the join key: the executed plan is
  // expected to contain exactly two sorts and two shuffles.
  val leftJoinPlan = smallRange.join(largeRange, Seq("id"), "left")
    .repartition(smallRange("id"))
    .sortWithinPartitions(smallRange("id"))
    .queryExecution.executedPlan
  val leftJoinSorts = leftJoinPlan.collect { case sort: SortExec => sort }
  val leftJoinShuffles = leftJoinPlan.collect { case shuffle: ShuffleExchangeExec => shuffle }
  assert(leftJoinSorts.length == 2)
  assert(leftJoinShuffles.length == 2)

  // Same pipeline on top of an inner join: again exactly two sorts and two shuffles.
  val innerJoinPlan = smallRange.join(largeRange, Seq("id"))
    .repartition(smallRange("id"))
    .sortWithinPartitions(smallRange("id"))
    .queryExecution.executedPlan
  val innerJoinSorts = innerJoinPlan.collect { case sort: SortExec => sort }
  val innerJoinShuffles = innerJoinPlan.collect { case shuffle: ShuffleExchangeExec => shuffle }
  assert(innerJoinSorts.length == 2)
  assert(innerJoinShuffles.length == 2)

  // A global orderBy must be kept: three sorts and three shuffles are expected.
  val globalSortPlan = smallRange.join(largeRange, Seq("id"))
    .orderBy(smallRange("id"))
    .queryExecution.executedPlan
  val globalSortSorts = globalSortPlan.collect { case sort: SortExec => sort }
  val globalSortShuffles = globalSortPlan.collect { case shuffle: ShuffleExchangeExec => shuffle }
  assert(globalSortSorts.length == 3)
  assert(globalSortShuffles.length == 3)
}
813+
778814
test("SPARK-24500: create union with stream of children") {
779815
val df = Union(Stream(
780816
Range(1, 1, 1, 1),

0 commit comments

Comments
 (0)