Skip to content

Commit 81f882b

Browse files
noorallzhuzhurk
authored and committed
[FLINK-36989][runtime] Fix scheduler benchmark regression caused by ConsumedSubpartitionContext
1 parent 3084561 commit 81f882b

File tree

4 files changed

+69
-55
lines changed

4 files changed

+69
-55
lines changed

flink-runtime/src/main/java/org/apache/flink/runtime/deployment/ConsumedSubpartitionContext.java

+24-18
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,17 @@
2020

2121
import org.apache.flink.runtime.executiongraph.IndexRange;
2222
import org.apache.flink.runtime.executiongraph.IndexRangeUtil;
23-
import org.apache.flink.runtime.executiongraph.IntermediateResult;
2423
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
24+
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
2525

2626
import java.io.Serializable;
2727
import java.util.ArrayList;
2828
import java.util.Collection;
2929
import java.util.Collections;
30-
import java.util.HashMap;
31-
import java.util.Iterator;
3230
import java.util.LinkedHashMap;
3331
import java.util.List;
3432
import java.util.Map;
33+
import java.util.function.Function;
3534

3635
import static org.apache.flink.util.Preconditions.checkNotNull;
3736
import static org.apache.flink.util.Preconditions.checkState;
@@ -113,34 +112,41 @@ public IndexRange getConsumedSubpartitionRange(int shuffleDescriptorIndex) {
113112
*
114113
* @param consumedSubpartitionGroups a mapping of consumed partition index ranges to
115114
* subpartition ranges.
116-
* @param consumedResultPartitions an iterator of {@link IntermediateResultPartitionID} for the
117-
* consumed result partitions.
118-
* @param partitions all partition ids of consumed {@link IntermediateResult}.
115+
* @param consumedPartitionGroup partition group consumed by the task.
116+
* @param partitionIdRetriever a function that retrieves the {@link
117+
* IntermediateResultPartitionID} for a given index.
119118
* @return a {@link ConsumedSubpartitionContext} instance constructed from the input parameters.
120119
*/
121120
public static ConsumedSubpartitionContext buildConsumedSubpartitionContext(
122121
Map<IndexRange, IndexRange> consumedSubpartitionGroups,
123-
Iterator<IntermediateResultPartitionID> consumedResultPartitions,
124-
IntermediateResultPartitionID[] partitions) {
125-
Map<IntermediateResultPartitionID, Integer> partitionIdToShuffleDescriptorIndexMap =
126-
new HashMap<>();
127-
while (consumedResultPartitions.hasNext()) {
128-
IntermediateResultPartitionID partitionId = consumedResultPartitions.next();
129-
partitionIdToShuffleDescriptorIndexMap.put(
130-
partitionId, partitionIdToShuffleDescriptorIndexMap.size());
122+
ConsumedPartitionGroup consumedPartitionGroup,
123+
Function<Integer, IntermediateResultPartitionID> partitionIdRetriever) {
124+
Map<IntermediateResultPartitionID, Integer> resultPartitionsInOrder =
125+
consumedPartitionGroup.getResultPartitionsInOrder();
126+
// If only one range is included and the index range size is the same as the number of
127+
// shuffle descriptors, it means that the task will subscribe to all partitions, i.e., the
128+
// partition range is one-to-one corresponding to the shuffle descriptors. Therefore, we can
129+
// directly construct the ConsumedSubpartitionContext using the subpartition range.
130+
if (consumedSubpartitionGroups.size() == 1
131+
&& consumedSubpartitionGroups.keySet().iterator().next().size()
132+
== resultPartitionsInOrder.size()) {
133+
return buildConsumedSubpartitionContext(
134+
resultPartitionsInOrder.size(),
135+
consumedSubpartitionGroups.values().iterator().next());
131136
}
132137

133138
Map<IndexRange, IndexRange> consumedShuffleDescriptorToSubpartitionRangeMap =
134139
new LinkedHashMap<>();
135140
for (Map.Entry<IndexRange, IndexRange> entry : consumedSubpartitionGroups.entrySet()) {
136141
IndexRange partitionRange = entry.getKey();
137142
IndexRange subpartitionRange = entry.getValue();
143+
// The shuffle descriptor index is consistent with the index in resultPartitionsInOrder.
138144
IndexRange shuffleDescriptorRange =
139145
new IndexRange(
140-
partitionIdToShuffleDescriptorIndexMap.get(
141-
partitions[partitionRange.getStartIndex()]),
142-
partitionIdToShuffleDescriptorIndexMap.get(
143-
partitions[partitionRange.getEndIndex()]));
146+
resultPartitionsInOrder.get(
147+
partitionIdRetriever.apply(partitionRange.getStartIndex())),
148+
resultPartitionsInOrder.get(
149+
partitionIdRetriever.apply(partitionRange.getEndIndex())));
144150
checkState(
145151
partitionRange.size() == shuffleDescriptorRange.size()
146152
&& !consumedShuffleDescriptorToSubpartitionRangeMap.containsKey(

flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java

+3-5
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
import java.io.IOException;
5454
import java.io.Serializable;
5555
import java.util.ArrayList;
56-
import java.util.Arrays;
5756
import java.util.Collection;
5857
import java.util.HashMap;
5958
import java.util.List;
@@ -151,6 +150,7 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
151150

152151
IntermediateDataSetID resultId = consumedIntermediateResult.getId();
153152
ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
153+
IntermediateResultPartition[] partitions = consumedIntermediateResult.getPartitions();
154154

155155
inputGates.add(
156156
new InputGateDeploymentDescriptor(
@@ -160,10 +160,8 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
160160
executionVertex
161161
.getExecutionVertexInputInfo(resultId)
162162
.getConsumedSubpartitionGroups(),
163-
consumedPartitionGroup.iterator(),
164-
Arrays.stream(consumedIntermediateResult.getPartitions())
165-
.map(IntermediateResultPartition::getPartitionId)
166-
.toArray(IntermediateResultPartitionID[]::new)),
163+
consumedPartitionGroup,
164+
index -> partitions[index].getPartitionId()),
167165
consumedPartitionGroup.size(),
168166
getConsumedPartitionShuffleDescriptors(
169167
consumedIntermediateResult,

flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java

+19-8
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727

2828
import java.util.Collections;
2929
import java.util.Iterator;
30+
import java.util.LinkedHashMap;
3031
import java.util.List;
32+
import java.util.Map;
3133
import java.util.concurrent.atomic.AtomicInteger;
3234

3335
import static org.apache.flink.util.Preconditions.checkArgument;
@@ -40,7 +42,10 @@
4042
*/
4143
public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartitionID> {
4244

43-
private final List<IntermediateResultPartitionID> resultPartitions;
45+
// The key is the result partition ID, the value is the index of the result partition in the
46+
// original construction list.
47+
private final Map<IntermediateResultPartitionID, Integer> resultPartitionsInOrder =
48+
new LinkedHashMap<>();
4449

4550
private final AtomicInteger unfinishedPartitions;
4651

@@ -64,13 +69,15 @@ private ConsumedPartitionGroup(
6469
this.intermediateDataSetID = resultPartitions.get(0).getIntermediateDataSetID();
6570
this.resultPartitionType = Preconditions.checkNotNull(resultPartitionType);
6671

67-
// Sanity check: all the partitions in one ConsumedPartitionGroup should have the same
68-
// IntermediateDataSetID
69-
for (IntermediateResultPartitionID resultPartition : resultPartitions) {
72+
for (int i = 0; i < resultPartitions.size(); i++) {
73+
// Sanity check: all the partitions in one ConsumedPartitionGroup should have the same
74+
// IntermediateDataSetID
75+
IntermediateResultPartitionID resultPartition = resultPartitions.get(i);
7076
checkArgument(
7177
resultPartition.getIntermediateDataSetID().equals(this.intermediateDataSetID));
78+
79+
resultPartitionsInOrder.put(resultPartition, i);
7280
}
73-
this.resultPartitions = resultPartitions;
7481

7582
this.unfinishedPartitions = new AtomicInteger(resultPartitions.size());
7683
}
@@ -92,15 +99,19 @@ public static ConsumedPartitionGroup fromSinglePartition(
9299

93100
@Override
94101
public Iterator<IntermediateResultPartitionID> iterator() {
95-
return resultPartitions.iterator();
102+
return resultPartitionsInOrder.keySet().iterator();
103+
}
104+
105+
public Map<IntermediateResultPartitionID, Integer> getResultPartitionsInOrder() {
106+
return Collections.unmodifiableMap(resultPartitionsInOrder);
96107
}
97108

98109
public int size() {
99-
return resultPartitions.size();
110+
return resultPartitionsInOrder.size();
100111
}
101112

102113
public boolean isEmpty() {
103-
return resultPartitions.isEmpty();
114+
return resultPartitionsInOrder.isEmpty();
104115
}
105116

106117
/**

flink-runtime/src/test/java/org/apache/flink/runtime/deployment/ConsumedSubpartitionContextTest.java

+23-24
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.flink.runtime.executiongraph.IndexRange;
2222
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
2323
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
24+
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
2425

2526
import org.junit.jupiter.api.Test;
2627

@@ -29,6 +30,7 @@
2930
import java.util.List;
3031
import java.util.Map;
3132

33+
import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
3234
import static org.assertj.core.api.Assertions.assertThat;
3335

3436
/** Tests for {@link ConsumedSubpartitionContext}. */
@@ -40,17 +42,13 @@ void testBuildConsumedSubpartitionContextWithGroups() {
4042
new IndexRange(0, 1), new IndexRange(0, 2),
4143
new IndexRange(2, 3), new IndexRange(3, 5));
4244

43-
List<IntermediateResultPartitionID> consumedPartitionIds = new ArrayList<>();
44-
45-
IntermediateResultPartitionID[] partitions = new IntermediateResultPartitionID[4];
46-
for (int i = 0; i < partitions.length; i++) {
47-
partitions[i] = new IntermediateResultPartitionID(new IntermediateDataSetID(), i);
48-
consumedPartitionIds.add(partitions[i]);
49-
}
45+
List<IntermediateResultPartitionID> partitions = createPartitions();
46+
ConsumedPartitionGroup consumedPartitionGroup =
47+
ConsumedPartitionGroup.fromMultiplePartitions(4, partitions, BLOCKING);
5048

5149
ConsumedSubpartitionContext context =
5250
ConsumedSubpartitionContext.buildConsumedSubpartitionContext(
53-
consumedSubpartitionGroups, consumedPartitionIds.iterator(), partitions);
51+
consumedSubpartitionGroups, consumedPartitionGroup, partitions::get);
5452

5553
assertThat(context.getNumConsumedShuffleDescriptors()).isEqualTo(4);
5654

@@ -71,17 +69,13 @@ void testBuildConsumedSubpartitionContextWithUnorderedGroups() {
7169
new IndexRange(3, 3), new IndexRange(1, 1),
7270
new IndexRange(0, 0), new IndexRange(0, 1));
7371

74-
List<IntermediateResultPartitionID> consumedPartitionIds = new ArrayList<>();
75-
76-
IntermediateResultPartitionID[] partitions = new IntermediateResultPartitionID[4];
77-
for (int i = 0; i < partitions.length; i++) {
78-
partitions[i] = new IntermediateResultPartitionID(new IntermediateDataSetID(), i);
79-
consumedPartitionIds.add(partitions[i]);
80-
}
72+
List<IntermediateResultPartitionID> partitions = createPartitions();
73+
ConsumedPartitionGroup consumedPartitionGroup =
74+
ConsumedPartitionGroup.fromMultiplePartitions(4, partitions, BLOCKING);
8175

8276
ConsumedSubpartitionContext context =
8377
ConsumedSubpartitionContext.buildConsumedSubpartitionContext(
84-
consumedSubpartitionGroups, consumedPartitionIds.iterator(), partitions);
78+
consumedSubpartitionGroups, consumedPartitionGroup, partitions::get);
8579

8680
assertThat(context.getNumConsumedShuffleDescriptors()).isEqualTo(2);
8781

@@ -100,17 +94,13 @@ void testBuildConsumedSubpartitionContextWithOverlapGroups() {
10094
new IndexRange(0, 3), new IndexRange(1, 1),
10195
new IndexRange(0, 1), new IndexRange(2, 2));
10296

103-
List<IntermediateResultPartitionID> consumedPartitionIds = new ArrayList<>();
104-
105-
IntermediateResultPartitionID[] partitions = new IntermediateResultPartitionID[4];
106-
for (int i = 0; i < partitions.length; i++) {
107-
partitions[i] = new IntermediateResultPartitionID(new IntermediateDataSetID(), i);
108-
consumedPartitionIds.add(partitions[i]);
109-
}
97+
List<IntermediateResultPartitionID> partitions = createPartitions();
98+
ConsumedPartitionGroup consumedPartitionGroup =
99+
ConsumedPartitionGroup.fromMultiplePartitions(4, partitions, BLOCKING);
110100

111101
ConsumedSubpartitionContext context =
112102
ConsumedSubpartitionContext.buildConsumedSubpartitionContext(
113-
consumedSubpartitionGroups, consumedPartitionIds.iterator(), partitions);
103+
consumedSubpartitionGroups, consumedPartitionGroup, partitions::get);
114104

115105
assertThat(context.getNumConsumedShuffleDescriptors()).isEqualTo(4);
116106

@@ -144,4 +134,13 @@ void testBuildConsumedSubpartitionContextWithRange() {
144134
IndexRange subpartitionRange = context.getConsumedSubpartitionRange(2);
145135
assertThat(subpartitionRange).isEqualTo(consumedSubpartitionRange);
146136
}
137+
138+
private static List<IntermediateResultPartitionID> createPartitions() {
139+
List<IntermediateResultPartitionID> partitions = new ArrayList<>();
140+
IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
141+
for (int i = 0; i < 4; i++) {
142+
partitions.add(new IntermediateResultPartitionID(intermediateDataSetID, i));
143+
}
144+
return partitions;
145+
}
147146
}

0 commit comments

Comments
 (0)