Skip to content

Commit

Permalink
Remove SearchPhaseContext (#116471)
Browse files Browse the repository at this point in the history
The only production implementation of this thing is
AbstractSearchAsyncAction, no need to keep a separate interface around.
This makes the logic a lot more obvious in terms of the lifecycle of
"context" and how it's essentially just the "main" search phase.
Plus it outright saves a lot of code, even though it adds a little on
the test side.
  • Loading branch information
original-brownbear authored Nov 8, 2024
1 parent 0f9ac9d commit 66123cf
Show file tree
Hide file tree
Showing 17 changed files with 168 additions and 226 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
* The fan out and collect algorithm is traditionally used as the initial phase which can either be a query execution or collection of
* distributed frequencies
*/
abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> extends SearchPhase implements SearchPhaseContext {
abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> extends SearchPhase {
private static final float DEFAULT_INDEX_BOOST = 1.0f;
private final Logger logger;
private final NamedWriteableRegistry namedWriteableRegistry;
Expand Down Expand Up @@ -106,7 +106,8 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final boolean throttleConcurrentRequests;
private final AtomicBoolean requestCancelled = new AtomicBoolean();

private final List<Releasable> releasables = new ArrayList<>();
// protected for tests
protected final List<Releasable> releasables = new ArrayList<>();

AbstractSearchAsyncAction(
String name,
Expand Down Expand Up @@ -194,7 +195,9 @@ protected void notifyListShards(
);
}

@Override
/**
 * Registers a {@link Releasable} that will be closed when the search request finishes or fails.
 *
 * @param releasable resource to release once this search action completes
 */
public void addReleasable(Releasable releasable) {
releasables.add(releasable);
}
Expand Down Expand Up @@ -333,8 +336,12 @@ protected abstract void executePhaseOnShard(
SearchActionListener<Result> listener
);

@Override
public final void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase> nextPhaseSupplier) {
/**
* Processes the phase transition from one phase to another. This method handles all errors that happen during the initial run execution
* of the next phase. If there are no successful operations in the context when this method is executed the search is aborted and
* a response is returned to the user indicating that all shards have failed.
*/
protected void executeNextPhase(SearchPhase currentPhase, Supplier<SearchPhase> nextPhaseSupplier) {
/* This is the main search phase transition where we move to the next phase. If all shards
* failed or if there was a failure and partial results are not allowed, then we immediately
* fail. Otherwise we continue to the next phase.
Expand Down Expand Up @@ -470,8 +477,7 @@ protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget
* @param shardTarget the shard target for this failure
* @param e the failure reason
*/
@Override
public final void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Exception e) {
void onShardFailure(final int shardIndex, SearchShardTarget shardTarget, Exception e) {
if (TransportActions.isShardNotAvailableException(e)) {
// Groups shard not available exceptions under a generic exception that returns a SERVICE_UNAVAILABLE(503)
// temporary error.
Expand Down Expand Up @@ -568,32 +574,45 @@ private void successfulShardExecution(SearchShardIterator shardsIt) {
}
}

@Override
/**
 * Returns the total number of shards of the current search across all indices.
 */
public final int getNumShards() {
return results.getNumShards();
}

@Override
/**
 * Returns a logger for this context so that individual phases do not each need to create their own.
 */
public final Logger getLogger() {
return logger;
}

@Override
/**
 * Returns the currently executing search task.
 */
public final SearchTask getTask() {
return task;
}

@Override
/**
 * Returns the currently executing search request.
 */
public final SearchRequest getRequest() {
return request;
}

@Override
/**
 * Returns the targeted {@link OriginalIndices} for the provided {@code shardIndex}.
 *
 * @param shardIndex index into this action's shard iterators
 */
public OriginalIndices getOriginalIndices(int shardIndex) {
return shardIterators[shardIndex].getOriginalIndices();
}

@Override
/**
* Checks if the given context id is part of the point in time of this search (if exists).
* We should not release search contexts that belong to the point in time during or after searches.
*/
public boolean isPartOfPointInTime(ShardSearchContextId contextId) {
final PointInTimeBuilder pointInTimeBuilder = request.pointInTimeBuilder();
if (pointInTimeBuilder != null) {
Expand Down Expand Up @@ -630,7 +649,12 @@ boolean buildPointInTimeFromSearchResults() {
return false;
}

@Override
/**
* Builds and sends the final search response back to the user.
*
* @param internalSearchResponse the internal search response
* @param queryResults the results of the query phase
*/
public void sendSearchResponse(SearchResponseSections internalSearchResponse, AtomicArray<SearchPhaseResult> queryResults) {
ShardSearchFailure[] failures = buildShardFailures();
Boolean allowPartialResults = request.allowPartialSearchResults();
Expand All @@ -655,8 +679,14 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At
}
}

@Override
public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
/**
 * This method will communicate a fatal phase failure back to the user. In contrast to a shard failure,
 * this method will immediately fail the search request and return the failure to the issuer of the request.
 * @param phase the phase that failed
 * @param msg an optional message
 * @param cause the cause of the phase failure
 */
public void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures()));
}

Expand All @@ -683,6 +713,19 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) {
listener.onFailure(exception);
}

/**
 * Releases a search context with the given context id on the node the given connection is connected to.
 * A {@code null} connection makes this a no-op. Must never be invoked for a context that belongs to
 * this search's point in time.
 *
 * @see org.elasticsearch.search.query.QuerySearchResult#getContextId()
 * @see org.elasticsearch.search.fetch.FetchSearchResult#getContextId()
 */
void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connection connection, OriginalIndices originalIndices) {
    assert isPartOfPointInTime(contextId) == false : "Must not release point in time context [" + contextId + "]";
    if (connection == null) {
        return;
    }
    searchTransportService.sendFreeContext(connection, contextId, originalIndices);
}

/**
* Executed once all shard results have been received and processed
* @see #onShardFailure(int, SearchShardTarget, Exception)
Expand All @@ -692,23 +735,29 @@ final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim()
executeNextPhase(this, this::getNextPhase);
}

@Override
/**
 * Returns a connection to the node if connected; otherwise a
 * {@link org.elasticsearch.transport.ConnectTransportException} will be thrown.
 */
public final Transport.Connection getConnection(String clusterAlias, String nodeId) {
return nodeIdToConnection.apply(clusterAlias, nodeId);
}

@Override
public final SearchTransportService getSearchTransport() {
/**
 * Returns the {@link SearchTransportService} used to send shard requests to other nodes.
 */
public SearchTransportService getSearchTransport() {
return searchTransportService;
}

@Override
/**
 * Forwards the given command to this action's {@code executor}.
 */
public final void execute(Runnable command) {
executor.execute(command);
}

@Override
public final void onFailure(Exception e) {
/**
 * Notifies the top-level listener of the provided exception.
 */
public void onFailure(Exception e) {
listener.onFailure(e);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ final class CountedCollector<R extends SearchPhaseResult> {
private final SearchPhaseResults<R> resultConsumer;
private final CountDown counter;
private final Runnable onFinish;
private final SearchPhaseContext context;
private final AbstractSearchAsyncAction<?> context;

CountedCollector(SearchPhaseResults<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
CountedCollector(SearchPhaseResults<R> resultConsumer, int expectedOps, Runnable onFinish, AbstractSearchAsyncAction<?> context) {
this.resultConsumer = resultConsumer;
this.counter = new CountDown(expectedOps);
this.onFinish = onFinish;
Expand All @@ -50,7 +50,7 @@ void onResult(R result) {
}

/**
* Escalates the failure via {@link SearchPhaseContext#onShardFailure(int, SearchShardTarget, Exception)}
* Escalates the failure via {@link AbstractSearchAsyncAction#onShardFailure(int, SearchShardTarget, Exception)}
* and then runs {@link #countDown()}
*/
void onFailure(final int shardIndex, @Nullable SearchShardTarget shardTarget, Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ final class DfsQueryPhase extends SearchPhase {
private final AggregatedDfs dfs;
private final List<DfsKnnResults> knnResults;
private final Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
private final SearchPhaseContext context;
private final AbstractSearchAsyncAction<?> context;
private final SearchTransportService searchTransportService;
private final SearchProgressListener progressListener;

Expand All @@ -54,7 +54,7 @@ final class DfsQueryPhase extends SearchPhase {
List<DfsKnnResults> knnResults,
SearchPhaseResults<SearchPhaseResult> queryResult,
Function<SearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context
AbstractSearchAsyncAction<?> context
) {
super("dfs_query");
this.progressListener = context.getTask().getProgressListener();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
* forwards to the next phase immediately.
*/
final class ExpandSearchPhase extends SearchPhase {
private final SearchPhaseContext context;
private final AbstractSearchAsyncAction<?> context;
private final SearchHits searchHits;
private final Supplier<SearchPhase> nextPhase;

ExpandSearchPhase(SearchPhaseContext context, SearchHits searchHits, Supplier<SearchPhase> nextPhase) {
ExpandSearchPhase(AbstractSearchAsyncAction<?> context, SearchHits searchHits, Supplier<SearchPhase> nextPhase) {
super("expand");
this.context = context;
this.searchHits = searchHits;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@
* @see org.elasticsearch.index.mapper.LookupRuntimeFieldType
*/
final class FetchLookupFieldsPhase extends SearchPhase {
private final SearchPhaseContext context;
private final AbstractSearchAsyncAction<?> context;
private final SearchResponseSections searchResponse;
private final AtomicArray<SearchPhaseResult> queryResults;

FetchLookupFieldsPhase(SearchPhaseContext context, SearchResponseSections searchResponse, AtomicArray<SearchPhaseResult> queryResults) {
FetchLookupFieldsPhase(
AbstractSearchAsyncAction<?> context,
SearchResponseSections searchResponse,
AtomicArray<SearchPhaseResult> queryResults
) {
super("fetch_lookup_fields");
this.context = context;
this.searchResponse = searchResponse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
final class FetchSearchPhase extends SearchPhase {
private final AtomicArray<SearchPhaseResult> searchPhaseShardResults;
private final BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
private final SearchPhaseContext context;
private final AbstractSearchAsyncAction<?> context;
private final Logger logger;
private final SearchProgressListener progressListener;
private final AggregatedDfs aggregatedDfs;
Expand All @@ -47,7 +47,7 @@ final class FetchSearchPhase extends SearchPhase {
FetchSearchPhase(
SearchPhaseResults<SearchPhaseResult> resultConsumer,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
AbstractSearchAsyncAction<?> context,
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
this(
Expand All @@ -66,7 +66,7 @@ final class FetchSearchPhase extends SearchPhase {
FetchSearchPhase(
SearchPhaseResults<SearchPhaseResult> resultConsumer,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
AbstractSearchAsyncAction<?> context,
@Nullable SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
public class RankFeaturePhase extends SearchPhase {

private static final Logger logger = LogManager.getLogger(RankFeaturePhase.class);
private final SearchPhaseContext context;
private final AbstractSearchAsyncAction<?> context;
final SearchPhaseResults<SearchPhaseResult> queryPhaseResults;
final SearchPhaseResults<SearchPhaseResult> rankPhaseResults;
private final AggregatedDfs aggregatedDfs;
Expand All @@ -48,7 +48,7 @@ public class RankFeaturePhase extends SearchPhase {
RankFeaturePhase(
SearchPhaseResults<SearchPhaseResult> queryPhaseResults,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
AbstractSearchAsyncAction<?> context,
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext
) {
super("rank-feature");
Expand Down Expand Up @@ -179,22 +179,25 @@ private void onPhaseDone(
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(context, new ActionListener<>() {
@Override
public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores);
SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults(
reducedQueryPhase,
topResults
);
moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase);
}
ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(
context::execute,
new ActionListener<>() {
@Override
public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores);
SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults(
reducedQueryPhase,
topResults
);
moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase);
}

@Override
public void onFailure(Exception e) {
context.onPhaseFailure(RankFeaturePhase.this, "Computing updated ranks for results failed", e);
@Override
public void onFailure(Exception e) {
context.onPhaseFailure(RankFeaturePhase.this, "Computing updated ranks for results failed", e);
}
}
});
);
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(
rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(),
rankResultListener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ protected void doCheckNoMissingShards(String phaseName, SearchRequest request, G
/**
* Releases shard targets that are not used in the docsIdsToLoad.
*/
protected void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, SearchPhaseContext context) {
protected void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, AbstractSearchAsyncAction<?> context) {
// we only release search context that we did not fetch from, if we are not scrolling
// or using a PIT and if it has at least one hit that didn't make it to the global topDocs
if (searchPhaseResult == null) {
Expand Down
Loading

0 comments on commit 66123cf

Please sign in to comment.