From 4056fe5c3a5ae3c05fe290d941a6421d92265bc5 Mon Sep 17 00:00:00 2001 From: Jaeho Yoo <> Date: Wed, 17 Sep 2025 10:00:08 +0900 Subject: [PATCH 1/2] Reformat to match Airlift style --- .../proxyserver/TestProxyRequestHandler.java | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java b/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java index fece434d4..236ba7572 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java +++ b/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java @@ -51,6 +51,10 @@ @TestInstance(PER_CLASS) final class TestProxyRequestHandler { + private static final String OK = "OK"; + private static final int NOT_FOUND = 404; + private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); + private final OkHttpClient httpClient = new OkHttpClient(); private final MockWebServer mockTrinoServer = new MockWebServer(); private final PostgreSQLContainer postgresql = createPostgreSqlContainer(); @@ -58,10 +62,6 @@ final class TestProxyRequestHandler private final int routerPort = 21001 + (int) (Math.random() * 1000); private final int customBackendPort = 21000 + (int) (Math.random() * 1000); - private static final String OK = "OK"; - private static final int NOT_FOUND = 404; - private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); - private final String customPutEndpoint = "/v1/custom"; // this is enabled in test-config-template.yml private final String healthCheckEndpoint = "/v1/info"; @@ -70,7 +70,8 @@ void setup() throws Exception { prepareMockBackend(mockTrinoServer, customBackendPort, "default custom response"); - mockTrinoServer.setDispatcher(new Dispatcher() { + mockTrinoServer.setDispatcher(new Dispatcher() + { @Override public MockResponse dispatch(RecordedRequest request) { @@ -131,18 +132,18 @@ void testGetQueryDetailsFromRequest() { // A sample query longer than 200 characters to test against truncation. String longQuery = """ - SELECT - c.customer_name, - c.customer_region, - COUNT(o.order_id) AS total_orders, - SUM(o.order_value) AS total_revenue - FROM - hive.sales_data.customers AS c - JOIN - hive.sales_data.orders AS o - ON c.customer_id = o.customer_id - WHERE - o.order_date >= date '2023-01-01'"""; + SELECT + c.customer_name, + c.customer_region, + COUNT(o.order_id) AS total_orders, + SUM(o.order_value) AS total_revenue + FROM + hive.sales_data.customers AS c + JOIN + hive.sales_data.orders AS o + ON c.customer_id = o.customer_id + WHERE + o.order_date >= date '2023-01-01'"""; io.airlift.http.client.Request request = preparePost() .setUri(URI.create("http://localhost:" + routerPort + V1_STATEMENT_PATH)) From ebf075d4557fbde1db276dbdb8b4548d28a3f78f Mon Sep 17 00:00:00 2001 From: Jaeho Yoo <> Date: Wed, 17 Sep 2025 10:00:21 +0900 Subject: [PATCH 2/2] Add support for request and response compression --- .../proxyserver/ProxyRequestHandler.java | 12 ++--- .../proxyserver/ProxyResponseHandler.java | 36 +++++++++++-- .../proxyserver/TestProxyRequestHandler.java | 53 +++++++++++++++++++ 3 files changed, 92 insertions(+), 9 deletions(-) diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java index 8da9713ca..c8a95e3a5 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java @@ -166,8 +166,7 @@ private void performRequest( for (String name : list(servletRequest.getHeaderNames())) { for (String value : list(servletRequest.getHeaders(name))) { // TODO: decide what should and shouldn't be forwarded - if (!name.equalsIgnoreCase("Accept-Encoding") - && !name.equalsIgnoreCase("Host") + if (!name.equalsIgnoreCase("Host") && (addXForwardedHeaders || !name.startsWith("X-Forwarded"))) { requestBuilder.addHeader(name, value); } @@ -269,7 +268,8 @@ private static WebApplicationException badRequest(String message) private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response, Optional username, RoutingDestination routingDestination) { - log.debug("For Request [%s] got Response [%s]", request.getUri(), response.body()); + String body = response.decompressedBody(); + log.debug("For Request [%s] got Response [%s]", request.getUri(), body); QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request, username); @@ -277,18 +277,18 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res if (response.statusCode() == OK.getStatusCode()) { try { - HashMap results = OBJECT_MAPPER.readValue(response.body(), HashMap.class); + HashMap results = OBJECT_MAPPER.readValue(body, HashMap.class); queryDetail.setQueryId(results.get("id")); routingManager.setBackendForQueryId(queryDetail.getQueryId(), queryDetail.getBackendUrl()); routingManager.setRoutingGroupForQueryId(queryDetail.getQueryId(), routingDestination.routingGroup()); log.debug("QueryId [%s] mapped with proxy [%s]", queryDetail.getQueryId(), queryDetail.getBackendUrl()); } catch (IOException e) { - log.error("Failed to get QueryId from response [%s] , Status code [%s]", response.body(), response.statusCode()); + log.error("Failed to get QueryId from response [%s] , Status code [%s]", body, response.statusCode()); } } else { - log.error("Non OK HTTP Status code with response [%s] , Status code [%s], user: [%s]", response.body(), response.statusCode(), username.orElse(null)); + log.error("Non OK HTTP Status code with response [%s] , Status code [%s], user: [%s]", body, response.statusCode(), username.orElse(null)); } queryDetail.setRoutingGroup(routingDestination.routingGroup()); queryDetail.setExternalUrl(routingDestination.externalUrl()); diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyResponseHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyResponseHandler.java index e20c18ada..1f636ba20 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyResponseHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyResponseHandler.java @@ -22,9 +22,12 @@ import io.trino.gateway.ha.config.ProxyResponseConfiguration; import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse; +import java.io.ByteArrayInputStream; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import java.io.InputStream; +import java.util.zip.GZIPInputStream; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; public class ProxyResponseHandler @@ -47,7 +50,9 @@ public ProxyResponse handleException(Request request, Exception exception) public ProxyResponse handle(Request request, Response response) { try { - return new ProxyResponse(response.getStatusCode(), response.getHeaders(), new String(response.getInputStream().readNBytes((int) responseSize.toBytes()), StandardCharsets.UTF_8)); + // Store raw bytes to preserve compression + byte[] responseBodyBytes = response.getInputStream().readNBytes((int) responseSize.toBytes()); + return new ProxyResponse(response.getStatusCode(), response.getHeaders(), responseBodyBytes); } catch (IOException e) { throw new ProxyException("Failed reading response from remote Trino server", e); @@ -57,11 +62,36 @@ public ProxyResponse handle(Request request, Response response) public record ProxyResponse( int statusCode, ListMultimap headers, - String body) + byte[] body) { public ProxyResponse { requireNonNull(headers, "headers is null"); + requireNonNull(body, "body is null"); + } + + /** + * Get the response body as a decompressed string for JSON parsing and logging. + * Only call this when you need to parse the content, not when passing through + * to clients. + */ + public String decompressedBody() + { + // Check if the response is gzip-compressed + String contentEncoding = headers.get(HeaderName.of("Content-Encoding")).stream().findFirst().orElse(null); + + if ("gzip".equalsIgnoreCase(contentEncoding)) { + try (InputStream inputStream = new GZIPInputStream(new ByteArrayInputStream(body))) { + return new String(inputStream.readAllBytes(), UTF_8); + } + catch (IOException e) { + // If decompression fails, return the body as UTF-8 string + return new String(body, UTF_8); + } + } + + // Not compressed, convert bytes to string + return new String(body, UTF_8); } } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java b/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java index 236ba7572..b9a66ec4d 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java +++ b/gateway-ha/src/test/java/io/trino/gateway/proxyserver/TestProxyRequestHandler.java @@ -81,6 +81,14 @@ public MockResponse dispatch(RecordedRequest request) .setBody("{\"starting\": false}"); } + if (request.getPath().equals(healthCheckEndpoint + "?test-compression")) { + // Return the Accept-Encoding header value for compression testing + String acceptEncoding = request.getHeader("Accept-Encoding"); + return new MockResponse().setResponseCode(200) + .setHeader(CONTENT_TYPE, JSON_UTF_8) + .setBody(acceptEncoding != null ? acceptEncoding : "null"); + } + if (request.getMethod().equals("PUT") && request.getPath().equals(customPutEndpoint)) { return new MockResponse().setResponseCode(200) .setHeader(CONTENT_TYPE, JSON_UTF_8) @@ -160,4 +168,49 @@ void testGetQueryDetailsFromRequest() assertThat(queryDetail.getSource()).isEqualTo("trino-cli"); assertThat(queryDetail.getBackendUrl()).isEqualTo("http://localhost:" + routerPort); } + + @Test + void testAcceptEncodingHeaderForwarding() + throws Exception + { + // Test that Accept-Encoding header is properly forwarded to backends + String url = "http://localhost:" + routerPort + healthCheckEndpoint + "?test-compression"; + String expectedAcceptEncoding = "gzip, deflate, br"; + + Request request = new Request.Builder() + .url(url) + .get() + .addHeader("Accept-Encoding", expectedAcceptEncoding) + .build(); + + try (Response response = httpClient.newCall(request).execute()) { + assertThat(response.code()).isEqualTo(200); + assertThat(response.body()).isNotNull(); + + // The mock backend returns the Accept-Encoding header value in the response body + assertThat(response.body().string()).isEqualTo(expectedAcceptEncoding); + } + } + + @Test + void testDefaultAcceptEncodingHeaderForwarding() + throws Exception + { + // Test that requests without explicit Accept-Encoding header work correctly + // Note: OkHttp automatically adds "Accept-Encoding: gzip" when none is specified + String url = "http://localhost:" + routerPort + healthCheckEndpoint + "?test-compression"; + + Request request = new Request.Builder() + .url(url) + .get() + .build(); // No explicit Accept-Encoding header + + try (Response response = httpClient.newCall(request).execute()) { + assertThat(response.code()).isEqualTo(200); + assertThat(response.body()).isNotNull(); + + // OkHttp automatically adds "Accept-Encoding: gzip" when none is specified + assertThat(response.body().string()).isEqualTo("gzip"); + } + } }