Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gateway-ha/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ routingRules:
rulesEngineEnabled: False
# rulesConfigPath: "src/main/resources/rules/routing_rules.yml"

routing:
# Enable or disable query history recording to database (default: true)
queryHistoryEnabled: true


dataStore:
jdbcUrl: jdbc:postgresql://localhost:5432/trino_gateway_db
user: trino_gateway_db_admin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public class RoutingConfiguration

private String defaultRoutingGroup = "adhoc";

private boolean queryHistoryEnabled = true;

public Duration getAsyncTimeout()
{
return asyncTimeout;
Expand Down Expand Up @@ -54,4 +56,14 @@ public void setDefaultRoutingGroup(String defaultRoutingGroup)
{
this.defaultRoutingGroup = defaultRoutingGroup;
}

public boolean isQueryHistoryEnabled()
{
return queryHistoryEnabled;
}

public void setQueryHistoryEnabled(boolean queryHistoryEnabled)
{
this.queryHistoryEnabled = queryHistoryEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class ProxyRequestHandler
private final List<String> statementPaths;
private final boolean includeClusterInfoInResponse;
private final ProxyResponseConfiguration proxyResponseConfiguration;
private final boolean queryHistoryEnabled;

@Inject
public ProxyRequestHandler(
Expand All @@ -106,6 +107,7 @@ public ProxyRequestHandler(
statementPaths = haGatewayConfiguration.getStatementPaths();
this.includeClusterInfoInResponse = haGatewayConfiguration.isIncludeClusterHostInResponse();
proxyResponseConfiguration = haGatewayConfiguration.getProxyResponseConfiguration();
this.queryHistoryEnabled = haGatewayConfiguration.getRouting().isQueryHistoryEnabled();
}

@PreDestroy
Expand Down Expand Up @@ -292,7 +294,9 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
}
queryDetail.setRoutingGroup(routingDestination.routingGroup());
queryDetail.setExternalUrl(routingDestination.externalUrl());
queryHistoryManager.submitQueryDetail(queryDetail);
if (queryHistoryEnabled) {
queryHistoryManager.submitQueryDetail(queryDetail);
}
return response;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.config;

import io.airlift.units.Duration;
import org.junit.jupiter.api.Test;

import static java.util.concurrent.TimeUnit.MINUTES;
import static org.assertj.core.api.Assertions.assertThat;

class TestRoutingConfiguration
{
@Test
void testDefaultValues()
{
RoutingConfiguration routingConfiguration = new RoutingConfiguration();
assertThat(routingConfiguration.getAsyncTimeout()).isEqualTo(new Duration(2, MINUTES));
assertThat(routingConfiguration.isAddXForwardedHeaders()).isTrue();
assertThat(routingConfiguration.getDefaultRoutingGroup()).isEqualTo("adhoc");
assertThat(routingConfiguration.isQueryHistoryEnabled()).isTrue();
}

@Test
void testQueryHistoryEnabledSetter()
{
RoutingConfiguration routingConfiguration = new RoutingConfiguration();
assertThat(routingConfiguration.isQueryHistoryEnabled()).isTrue();

routingConfiguration.setQueryHistoryEnabled(false);
assertThat(routingConfiguration.isQueryHistoryEnabled()).isFalse();

routingConfiguration.setQueryHistoryEnabled(true);
assertThat(routingConfiguration.isQueryHistoryEnabled()).isTrue();
}

@Test
void testAllSetters()
{
RoutingConfiguration routingConfiguration = new RoutingConfiguration();

Duration customTimeout = new Duration(5, MINUTES);
routingConfiguration.setAsyncTimeout(customTimeout);
assertThat(routingConfiguration.getAsyncTimeout()).isEqualTo(customTimeout);

routingConfiguration.setAddXForwardedHeaders(false);
assertThat(routingConfiguration.isAddXForwardedHeaders()).isFalse();

routingConfiguration.setDefaultRoutingGroup("batch");
assertThat(routingConfiguration.getDefaultRoutingGroup()).isEqualTo("batch");

routingConfiguration.setQueryHistoryEnabled(false);
assertThat(routingConfiguration.isQueryHistoryEnabled()).isFalse();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.proxyserver;

import io.trino.gateway.ha.HaGatewayLauncher;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.jdbi.v3.core.Jdbi;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.testcontainers.containers.PostgreSQLContainer;

import java.io.File;
import java.util.List;
import java.util.Map;

import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static com.google.common.net.MediaType.JSON_UTF_8;
import static io.trino.gateway.ha.HaGatewayTestUtils.buildGatewayConfig;
import static io.trino.gateway.ha.HaGatewayTestUtils.prepareMockBackend;
import static io.trino.gateway.ha.HaGatewayTestUtils.setUpBackend;
import static io.trino.gateway.ha.handler.HttpUtils.V1_STATEMENT_PATH;
import static io.trino.gateway.ha.util.TestcontainersUtils.createPostgreSqlContainer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;

@TestInstance(PER_CLASS)
final class TestProxyRequestHandlerQueryHistoryDisabled
{
private final OkHttpClient httpClient = new OkHttpClient();
private final MockWebServer mockTrinoServer = new MockWebServer();
private final PostgreSQLContainer postgresql = createPostgreSqlContainer();

private final int routerPort = 22001 + (int) (Math.random() * 1000);
private final int customBackendPort = 22000 + (int) (Math.random() * 1000);

private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8");
private static final String TEST_QUERY_ID = "20240101_123456_00000_abcde";

private final String healthCheckEndpoint = "/v1/info";
private Jdbi jdbi;

@BeforeAll
void setup()
throws Exception
{
prepareMockBackend(mockTrinoServer, customBackendPort, "default custom response");
mockTrinoServer.setDispatcher(new Dispatcher() {
@Override
public MockResponse dispatch(RecordedRequest request)
{
if (request.getPath().equals(healthCheckEndpoint)) {
return new MockResponse().setResponseCode(200)
.setHeader(CONTENT_TYPE, JSON_UTF_8)
.setBody("{\"starting\": false}");
}

if (request.getMethod().equals("POST") && request.getPath().equals(V1_STATEMENT_PATH)) {
return new MockResponse().setResponseCode(200)
.setHeader(CONTENT_TYPE, JSON_UTF_8)
.setBody("{\"id\": \"" + TEST_QUERY_ID + "\", \"stats\": {}}");
}

return new MockResponse().setResponseCode(404);
}
});

postgresql.start();

File testConfigFile = buildGatewayConfig(postgresql, routerPort, "test-config-with-query-history-disabled.yml");

String[] args = {testConfigFile.getAbsolutePath()};
HaGatewayLauncher.main(args);

setUpBackend("custom", "http://localhost:" + customBackendPort, "externalUrl", true, "adhoc", routerPort);

jdbi = Jdbi.create(postgresql.getJdbcUrl(), postgresql.getUsername(), postgresql.getPassword());
}

@AfterAll
void cleanup()
throws Exception
{
mockTrinoServer.shutdown();
}

@Test
void testQueryHistoryNotRecordedWhenDisabled()
throws Exception
{
String url = "http://localhost:" + routerPort + V1_STATEMENT_PATH;
String testQuery = "SELECT 1";
RequestBody requestBody = RequestBody.create(testQuery, MEDIA_TYPE);

Request postRequest = new Request.Builder()
.url(url)
.addHeader("X-Trino-User", "test-user")
.post(requestBody)
.build();

try (Response response = httpClient.newCall(postRequest).execute()) {
assertThat(response.isSuccessful()).isTrue();
assertThat(response.body()).isNotNull();
String responseBody = response.body().string();
assertThat(responseBody).contains(TEST_QUERY_ID);
}

// Verify that query history was NOT recorded in the database
List<Map<String, Object>> queryHistory = jdbi.withHandle(handle ->
handle.createQuery("SELECT * FROM query_history WHERE query_id = :queryId")
.bind("queryId", TEST_QUERY_ID)
.mapToMap()
.list());

assertThat(queryHistory).isEmpty();
}
}
Loading