Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-36608][table-runtime] Support dynamic StreamGraph optimization for AdaptiveBroadcastJoinOperator #25822

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
</tr>
</thead>
<tbody>
<tr>
<td><h5>table.optimizer.adaptive-broadcast-join.strategy</h5><br> <span class="label label-primary">Batch</span></td>
<td style="word-wrap: break-word;">none</td>
<td><p>Enum</p></td>
<td>Flink will perform broadcast hash join optimization when the runtime statistics on one side of a join operator is less than the threshold `table.optimizer.join.broadcast-threshold`. The value of this configuration option decides when Flink should perform this optimization. AUTO means Flink will automatically choose the timing for optimization, RUNTIME_ONLY means broadcast hash join optimization is only performed at runtime, and NONE means the optimization is only carried out at compile time.<br /><br />Possible values:<ul><li>"auto": Flink will automatically choose the timing for optimization</li><li>"runtime_only": Broadcast hash join optimization is only performed at runtime.</li><li>"none": Broadcast hash join optimization is only carried out at compile time.</li></ul></td>
</tr>
<tr>
<td><h5>table.optimizer.agg-phase-strategy</h5><br> <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td>
<td style="word-wrap: break-word;">AUTO</td>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class IntermediateResult {
private final int numParallelProducers;

private final ExecutionPlanSchedulingContext executionPlanSchedulingContext;
private final boolean produceBroadcastResult;

private int partitionsAssigned;

Expand Down Expand Up @@ -102,6 +103,8 @@ public IntermediateResult(
this.shuffleDescriptorCache = new HashMap<>();

this.executionPlanSchedulingContext = checkNotNull(executionPlanSchedulingContext);

this.produceBroadcastResult = intermediateDataSet.isBroadcast();
}

public boolean areAllConsumerVerticesCreated() {
Expand Down Expand Up @@ -207,6 +210,10 @@ public boolean isForward() {
return intermediateDataSet.isForward();
}

public boolean isEveryConsumerConsumeAllSubPartitions() {
return !produceBroadcastResult && intermediateDataSet.isBroadcast();
}

public int getConnectionIndex() {
return connectionIndex;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ public interface IntermediateResultInfo {
*/
boolean isBroadcast();

/**
* Indicates whether every downstream consumer needs to consume all produced sub-partitions.
*
* @return true if every downstream consumer needs to consume all produced sub-partitions, false
* otherwise.
*/
boolean isEveryConsumerConsumeAllSubPartitions();

/**
* Whether it is a pointwise result.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ public static Map<IntermediateDataSetID, JobVertexInputInfo> computeVertexInputI
parallelism,
input::getNumSubpartitions,
isDynamicGraph,
input.isBroadcast()));
input.isBroadcast(),
input.isEveryConsumerConsumeAllSubPartitions()));
}
}

Expand Down Expand Up @@ -124,6 +125,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
1,
() -> numOfSubpartitionsRetriever.apply(start),
isDynamicGraph,
false,
false);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(index, partitionRange, subpartitionRange));
Expand All @@ -145,6 +147,7 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
numConsumers,
() -> numOfSubpartitionsRetriever.apply(finalPartitionNum),
isDynamicGraph,
false,
false);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
Expand All @@ -165,14 +168,16 @@ static JobVertexInputInfo computeVertexInputInfoForPointwise(
* @param numOfSubpartitionsRetriever a retriever to get the number of subpartitions
* @param isDynamicGraph whether is dynamic graph
* @param isBroadcast whether the edge is broadcast
* @param consumeAllSubpartitions whether the edge should consume all subpartitions
* @return the computed {@link JobVertexInputInfo}
*/
static JobVertexInputInfo computeVertexInputInfoForAllToAll(
int sourceCount,
int targetCount,
Function<Integer, Integer> numOfSubpartitionsRetriever,
boolean isDynamicGraph,
boolean isBroadcast) {
boolean isBroadcast,
boolean consumeAllSubpartitions) {
final List<ExecutionVertexInputInfo> executionVertexInputInfos = new ArrayList<>();
IndexRange partitionRange = new IndexRange(0, sourceCount - 1);
for (int i = 0; i < targetCount; ++i) {
Expand All @@ -182,7 +187,8 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
targetCount,
() -> numOfSubpartitionsRetriever.apply(0),
isDynamicGraph,
isBroadcast);
isBroadcast,
consumeAllSubpartitions);
executionVertexInputInfos.add(
new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange));
}
Expand All @@ -199,6 +205,7 @@ static JobVertexInputInfo computeVertexInputInfoForAllToAll(
* @param numOfSubpartitionsSupplier a supplier to get the number of subpartitions
* @param isDynamicGraph whether is dynamic graph
* @param isBroadcast whether the edge is broadcast
* @param consumeAllSubpartitions whether the edge should consume all subpartitions
* @return the computed subpartition range
*/
@VisibleForTesting
Expand All @@ -207,16 +214,21 @@ static IndexRange computeConsumedSubpartitionRange(
int numConsumers,
Supplier<Integer> numOfSubpartitionsSupplier,
boolean isDynamicGraph,
boolean isBroadcast) {
boolean isBroadcast,
boolean consumeAllSubpartitions) {
int consumerIndex = consumerSubtaskIndex % numConsumers;
if (!isDynamicGraph) {
return new IndexRange(consumerIndex, consumerIndex);
} else {
int numSubpartitions = numOfSubpartitionsSupplier.get();
if (isBroadcast) {
// broadcast results have only one subpartition, and be consumed multiple times.
checkArgument(numSubpartitions == 1);
return new IndexRange(0, 0);
if (consumeAllSubpartitions) {
return new IndexRange(0, numSubpartitions - 1);
} else {
// broadcast results have only one subpartition, and be consumed multiple times.
checkArgument(numSubpartitions == 1);
return new IndexRange(0, 0);
}
} else {
checkArgument(consumerIndex < numConsumers);
checkArgument(numConsumers <= numSubpartitions);
Expand Down Expand Up @@ -246,6 +258,11 @@ public boolean isBroadcast() {
return intermediateResult.isBroadcast();
}

@Override
public boolean isEveryConsumerConsumeAllSubPartitions() {
return intermediateResult.isEveryConsumerConsumeAllSubPartitions();
}

@Override
public boolean isPointwise() {
return intermediateResult.getConsumingDistributionPattern()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,18 @@ public void configure(
}
}

public void updateOutputPattern(
DistributionPattern distributionPattern, boolean isBroadcast, boolean isForward) {
checkState(consumers.isEmpty(), "The output job edges have already been added.");
checkState(
numJobEdgesToCreate == 1,
"Modification is not allowed when the subscribing output is reused.");

this.distributionPattern = distributionPattern;
this.isBroadcast = isBroadcast;
this.isForward = isForward;
}

public void increaseNumJobEdgesToCreate() {
this.numJobEdgesToCreate++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ abstract class AbstractBlockingResultInfo implements BlockingResultInfo {
protected final Map<Integer, long[]> subpartitionBytesByPartitionIndex;

AbstractBlockingResultInfo(
IntermediateDataSetID resultId, int numOfPartitions, int numOfSubpartitions) {
IntermediateDataSetID resultId,
int numOfPartitions,
int numOfSubpartitions,
Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
this.resultId = checkNotNull(resultId);
this.numOfPartitions = numOfPartitions;
this.numOfSubpartitions = numOfSubpartitions;
this.subpartitionBytesByPartitionIndex = new HashMap<>();
this.subpartitionBytesByPartitionIndex = subpartitionBytesByPartitionIndex;
}

@Override
Expand All @@ -72,4 +75,9 @@ public void resetPartitionInfo(int partitionIndex) {
int getNumOfRecordedPartitions() {
return subpartitionBytesByPartitionIndex.size();
}

@Override
public Map<Integer, long[]> getSubpartitionBytesByPartitionIndex() {
return new HashMap<>(subpartitionBytesByPartitionIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ public void onNewJobVerticesAdded(List<JobVertex> newVertices, int pendingOperat

// 4. update json plan
getExecutionGraph().setJsonPlan(JsonPlanGenerator.generatePlan(getJobGraph()));

// 5. update the DistributionPattern of the upstream results consumed by the newly created
// JobVertex and aggregate subpartition bytes.
for (JobVertex newVertex : newVertices) {
for (JobEdge input : newVertex.getInputs()) {
tryUpdateResultInfo(input.getSourceId(), input.getDistributionPattern());
Optional.ofNullable(blockingResultInfos.get(input.getSourceId()))
.ifPresent(this::maybeAggregateSubpartitionBytes);
}
}
}

@Override
Expand Down Expand Up @@ -482,15 +492,29 @@ private void updateResultPartitionBytesMetrics(
result.getId(),
(ignored, resultInfo) -> {
if (resultInfo == null) {
resultInfo = createFromIntermediateResult(result);
resultInfo =
createFromIntermediateResult(result, new HashMap<>());
}
resultInfo.recordPartitionInfo(
partitionId.getPartitionNumber(), partitionBytes);
maybeAggregateSubpartitionBytes(resultInfo);
return resultInfo;
});
});
}

private void maybeAggregateSubpartitionBytes(BlockingResultInfo resultInfo) {
IntermediateResult intermediateResult =
getExecutionGraph().getAllIntermediateResults().get(resultInfo.getResultId());

if (intermediateResult.areAllConsumerVerticesCreated()
&& intermediateResult.getConsumerVertices().stream()
.map(this::getExecutionJobVertex)
.allMatch(ExecutionJobVertex::isInitialized)) {
resultInfo.aggregateSubpartitionBytes();
}
}

@Override
public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) {
List<ExecutionVertex> executionVertices =
Expand Down Expand Up @@ -657,6 +681,7 @@ public void initializeVerticesIfPossible() {
parallelismAndInputInfos.getJobVertexInputInfos(),
createTimestamp);
newlyInitializedJobVertices.add(jobVertex);
consumedResultsInfo.get().forEach(this::maybeAggregateSubpartitionBytes);
}
}
}
Expand Down Expand Up @@ -909,21 +934,24 @@ private static void resetDynamicParallelism(Iterable<JobVertex> vertices) {
}
}

private static BlockingResultInfo createFromIntermediateResult(IntermediateResult result) {
private static BlockingResultInfo createFromIntermediateResult(
IntermediateResult result, Map<Integer, long[]> subpartitionBytesByPartitionIndex) {
checkArgument(result != null);
// Note that for dynamic graph, different partitions in the same result have the same number
// of subpartitions.
if (result.getConsumingDistributionPattern() == DistributionPattern.POINTWISE) {
return new PointwiseBlockingResultInfo(
result.getId(),
result.getNumberOfAssignedPartitions(),
result.getPartitions()[0].getNumberOfSubpartitions());
result.getPartitions()[0].getNumberOfSubpartitions(),
subpartitionBytesByPartitionIndex);
} else {
return new AllToAllBlockingResultInfo(
result.getId(),
result.getNumberOfAssignedPartitions(),
result.getPartitions()[0].getNumberOfSubpartitions(),
result.isBroadcast());
result.isBroadcast(),
subpartitionBytesByPartitionIndex);
}
}

Expand All @@ -937,6 +965,26 @@ SpeculativeExecutionHandler getSpeculativeExecutionHandler() {
return speculativeExecutionHandler;
}

private void tryUpdateResultInfo(IntermediateDataSetID id, DistributionPattern targetPattern) {
if (blockingResultInfos.containsKey(id)) {
BlockingResultInfo resultInfo = blockingResultInfos.get(id);
IntermediateResult result = getExecutionGraph().getAllIntermediateResults().get(id);

if ((targetPattern == DistributionPattern.ALL_TO_ALL && resultInfo.isPointwise())
|| (targetPattern == DistributionPattern.POINTWISE
&& !resultInfo.isPointwise())) {

BlockingResultInfo newInfo =
createFromIntermediateResult(
result, resultInfo.getSubpartitionBytesByPartitionIndex());

blockingResultInfos.put(id, newInfo);
} else if (targetPattern == DistributionPattern.ALL_TO_ALL) {
((AllToAllBlockingResultInfo) resultInfo).setBroadcast(result.isBroadcast());
}
}
}

private class DefaultBatchJobRecoveryContext implements BatchJobRecoveryContext {

private final FailoverStrategy restartStrategyOnResultConsumable =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.streaming.api.graph.ExecutionPlan;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.util.DynamicCodeLoadingException;

import java.util.concurrent.Executor;

Expand All @@ -46,7 +47,8 @@ public class AdaptiveExecutionHandlerFactory {
public static AdaptiveExecutionHandler create(
ExecutionPlan executionPlan,
ClassLoader userClassLoader,
Executor serializationExecutor) {
Executor serializationExecutor)
throws DynamicCodeLoadingException {
if (executionPlan instanceof JobGraph) {
return new NonAdaptiveExecutionHandler((JobGraph) executionPlan);
} else {
Expand Down
Loading