Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.Duration;

import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

import static com.google.common.util.concurrent.Futures.catching;
import static com.google.common.util.concurrent.Futures.withTimeout;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

public class AsyncResponseUtils
{
private AsyncResponseUtils() {}

public static <V> ListenableFuture<V> withFallbackAfterTimeout(ListenableFuture<V> future, Duration timeout, Supplier<V> fallback, Executor responseExecutor, ScheduledExecutorService timeoutExecutor)
public static <V> ListenableFuture<V> withFallbackAfterTimeout(ListenableFuture<V> future, Duration timeout, Supplier<V> fallback, ScheduledExecutorService timeoutExecutor)
{
return catching(withTimeout(future, timeout.toMillis(), MILLISECONDS, timeoutExecutor), TimeoutException.class, _ -> fallback.get(), responseExecutor);
return catching(withTimeout(future, timeout.toMillis(), MILLISECONDS, timeoutExecutor), TimeoutException.class, _ -> fallback.get(), directExecutor());
}
}
22 changes: 17 additions & 5 deletions core/trino-main/src/main/java/io/trino/server/TaskResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
import java.util.concurrent.ThreadLocalRandom;

import static com.google.common.collect.Iterables.transform;
import static com.google.common.util.concurrent.Futures.withTimeout;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.MoreFutures.addTimeout;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
Expand All @@ -83,6 +82,7 @@
import static io.trino.server.InternalHeaders.TRINO_TASK_FAILED;
import static io.trino.server.InternalHeaders.TRINO_TASK_INSTANCE_ID;
import static io.trino.server.security.ResourceSecurity.AccessType.INTERNAL_ONLY;
import static jakarta.ws.rs.core.Response.status;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
Expand Down Expand Up @@ -221,9 +221,11 @@ public void getTaskInfo(
futureTaskInfo = Futures.transform(futureTaskInfo, TaskInfo::summarize, directExecutor());
}

ListenableFuture<Response> response = Futures.transform(futureTaskInfo, taskInfo ->
Response.ok(taskInfo).build(), directExecutor());
// For hard timeout, add an additional time to max wait for thread scheduling contention and GC
Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), MILLISECONDS);
bindAsyncResponse(asyncResponse, withTimeout(futureTaskInfo, timeout.toMillis(), MILLISECONDS, timeoutExecutor), responseExecutor);
bindAsyncResponse(asyncResponse, withFallbackAfterTimeout(response, timeout, () -> serviceUnavailable(timeout), timeoutExecutor), responseExecutor);
}

@ResourceSecurity(INTERNAL_ONLY)
Expand Down Expand Up @@ -266,7 +268,9 @@ public void getTaskStatus(

// For hard timeout, add an additional time to max wait for thread scheduling contention and GC
Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), MILLISECONDS);
bindAsyncResponse(asyncResponse, withTimeout(futureTaskStatus, timeout.toMillis(), MILLISECONDS, timeoutExecutor), responseExecutor);

ListenableFuture<Response> response = Futures.transform(futureTaskStatus, taskStatus -> Response.ok(taskStatus).build(), directExecutor());
bindAsyncResponse(asyncResponse, withFallbackAfterTimeout(response, timeout, () -> serviceUnavailable(timeout), timeoutExecutor), responseExecutor);
}

@ResourceSecurity(INTERNAL_ONLY)
Expand Down Expand Up @@ -367,14 +371,14 @@ public void getResults(
// For hard timeout, add an additional time to max wait for thread scheduling contention and GC
Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), MILLISECONDS);
bindAsyncResponse(asyncResponse,
withFallbackAfterTimeout(responseFuture, timeout, () -> createBufferResultResponse(pagesInputStreamFactory, taskWithResults, emptyBufferResults), responseExecutor, timeoutExecutor), responseExecutor);
withFallbackAfterTimeout(responseFuture, timeout, () -> createBufferResultResponse(pagesInputStreamFactory, taskWithResults, emptyBufferResults), timeoutExecutor), responseExecutor);
responseFuture.addListener(() -> readFromOutputBufferTime.add(Duration.nanosSince(start)), directExecutor());
}

@ResourceSecurity(INTERNAL_ONLY)
@GET
@Path("{taskId}/results/{bufferId}/{token}/acknowledge")
public void acknowledgeResults(
public Response acknowledgeResults(
@PathParam("taskId") TaskId taskId,
@PathParam("bufferId") PipelinedOutputBuffers.OutputBufferId bufferId,
@PathParam("token") long token)
Expand All @@ -383,6 +387,7 @@ public void acknowledgeResults(
requireNonNull(bufferId, "bufferId is null");

taskManager.acknowledgeTaskResults(taskId, bufferId, token);
return Response.ok().build();
}

@ResourceSecurity(INTERNAL_ONLY)
Expand Down Expand Up @@ -561,4 +566,11 @@ private static Response createBufferResultResponse(PagesInputStreamFactory pages
pagesInputStreamFactory.write(output, serializedPages))
.build();
}

private static Response serviceUnavailable(Duration timeout)
{
return status(Response.Status.SERVICE_UNAVAILABLE)
.entity("Timed out after waiting for " + timeout.convertToMostSuccinctTimeUnit())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.base.Throwables;
import com.google.inject.Inject;
import io.airlift.jaxrs.ParsingException;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.InternalServerErrorException;
Expand Down Expand Up @@ -83,6 +84,9 @@ public Response toResponse(Throwable throwable)
case TimeoutException timeoutException -> plainTextError(Response.Status.REQUEST_TIMEOUT)
.entity("Error 408 Timeout: " + timeoutException.getMessage())
.build();
case ParsingException parsingException -> Response.status(Response.Status.BAD_REQUEST)
.entity(Throwables.getStackTraceAsString(parsingException))
.build();
case WebApplicationException webApplicationException -> webApplicationException.getResponse();
default -> {
ResponseBuilder responseBuilder = plainTextError(Response.Status.INTERNAL_SERVER_ERROR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ public void fail(Throwable throwable)
public synchronized void dispose()
{
exchangeDataSource.close();
lastResult = null;
}

public QueryId getQueryId()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.concurrent.ScheduledExecutorService;

import static com.google.common.util.concurrent.Futures.transform;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.server.AsyncResponseUtils.withFallbackAfterTimeout;
import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC;
Expand Down Expand Up @@ -97,8 +98,8 @@ public void getAuthenticationToken(@PathParam("authId") UUID authId, @Suspended
// hang if the client retries the request. The response will timeout eventually.
ListenableFuture<TokenPoll> tokenFuture = tokenExchange.getTokenPoll(authId);
ListenableFuture<Response> responseFuture = withFallbackAfterTimeout(
transform(tokenFuture, OAuth2TokenExchangeResource::toResponse, responseExecutor),
MAX_POLL_TIME, () -> pendingResponse(request), responseExecutor, timeoutExecutor);
transform(tokenFuture, OAuth2TokenExchangeResource::toResponse, directExecutor()),
MAX_POLL_TIME, () -> pendingResponse(request), timeoutExecutor);
bindAsyncResponse(asyncResponse, responseFuture, responseExecutor);
}

Expand Down
Loading