Skip to content

Commit a11e5a4

Browse files
siminyouebyhr
authored andcommitted
Add proxy handler for PUT request
1 parent 62bbe94 commit a11e5a4

File tree

3 files changed

+138
-0
lines changed

3 files changed

+138
-0
lines changed

gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import static io.airlift.http.client.Request.Builder.prepareDelete;
5959
import static io.airlift.http.client.Request.Builder.prepareGet;
6060
import static io.airlift.http.client.Request.Builder.preparePost;
61+
import static io.airlift.http.client.Request.Builder.preparePut;
6162
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
6263
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
6364
import static io.trino.gateway.ha.handler.ProxyUtils.QUERY_TEXT_LENGTH_FOR_HISTORY;
@@ -140,6 +141,17 @@ public void postRequest(
140141
performRequest(remoteUri, servletRequest, asyncResponse, request);
141142
}
142143

144+
public void putRequest(
145+
String statement,
146+
HttpServletRequest servletRequest,
147+
AsyncResponse asyncResponse,
148+
URI remoteUri)
149+
{
150+
Request.Builder request = preparePut()
151+
.setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));
152+
performRequest(remoteUri, servletRequest, asyncResponse, request);
153+
}
154+
143155
private void performRequest(
144156
URI remoteUri,
145157
HttpServletRequest servletRequest,

gateway-ha/src/main/java/io/trino/gateway/proxyserver/RouteToBackendResource.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jakarta.ws.rs.DELETE;
2121
import jakarta.ws.rs.GET;
2222
import jakarta.ws.rs.POST;
23+
import jakarta.ws.rs.PUT;
2324
import jakarta.ws.rs.Path;
2425
import jakarta.ws.rs.container.AsyncResponse;
2526
import jakarta.ws.rs.container.Suspended;
@@ -85,4 +86,15 @@ public void deleteHandler(
8586
String remoteUri = routingTargetHandler.getRoutingDestination(servletRequest);
8687
proxyRequestHandler.deleteRequest(servletRequest, asyncResponse, URI.create(remoteUri));
8788
}
89+
90+
@PUT
91+
public void putHandler(
92+
String body,
93+
@Context HttpServletRequest servletRequest,
94+
@Suspended AsyncResponse asyncResponse)
95+
{
96+
MultiReadHttpServletRequest multiReadHttpServletRequest = new MultiReadHttpServletRequest(servletRequest, body);
97+
String remoteUri = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
98+
proxyRequestHandler.putRequest(body, multiReadHttpServletRequest, asyncResponse, URI.create(remoteUri));
99+
}
88100
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.proxyserver;
15+
16+
import io.trino.gateway.ha.HaGatewayLauncher;
17+
import io.trino.gateway.ha.HaGatewayTestUtils;
18+
import okhttp3.MediaType;
19+
import okhttp3.OkHttpClient;
20+
import okhttp3.Request;
21+
import okhttp3.RequestBody;
22+
import okhttp3.Response;
23+
import okhttp3.mockwebserver.Dispatcher;
24+
import okhttp3.mockwebserver.MockResponse;
25+
import okhttp3.mockwebserver.MockWebServer;
26+
import okhttp3.mockwebserver.RecordedRequest;
27+
import org.junit.jupiter.api.AfterAll;
28+
import org.junit.jupiter.api.BeforeAll;
29+
import org.junit.jupiter.api.Test;
30+
import org.junit.jupiter.api.TestInstance;
31+
32+
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
33+
import static com.google.common.net.MediaType.JSON_UTF_8;
34+
import static io.trino.gateway.ha.HaGatewayTestUtils.buildGatewayConfigAndSeedDb;
35+
import static io.trino.gateway.ha.HaGatewayTestUtils.prepareMockBackend;
36+
import static io.trino.gateway.ha.HaGatewayTestUtils.setUpBackend;
37+
import static org.assertj.core.api.Assertions.assertThat;
38+
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
39+
40+
@TestInstance(PER_CLASS)
41+
final class TestProxyRequestHandler
42+
{
43+
private final OkHttpClient httpClient = new OkHttpClient();
44+
private final MockWebServer mockTrinoServer = new MockWebServer();
45+
46+
private final int routerPort = 21001 + (int) (Math.random() * 1000);
47+
private final int customBackendPort = 21000 + (int) (Math.random() * 1000);
48+
49+
private static final String OK = "OK";
50+
private static final int NOT_FOUND = 404;
51+
private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8");
52+
53+
private final String customPutEndpoint = "/v1/custom"; // this is enabled in test-config-template.yml
54+
private final String healthCheckEndpoint = "/v1/info";
55+
56+
@BeforeAll
57+
void setup()
58+
throws Exception
59+
{
60+
prepareMockBackend(mockTrinoServer, customBackendPort, "default custom response");
61+
mockTrinoServer.setDispatcher(new Dispatcher() {
62+
@Override
63+
public MockResponse dispatch(RecordedRequest request)
64+
{
65+
if (request.getPath().equals(healthCheckEndpoint)) {
66+
return new MockResponse().setResponseCode(200)
67+
.setHeader(CONTENT_TYPE, JSON_UTF_8)
68+
.setBody("{\"starting\": false}");
69+
}
70+
71+
if (request.getMethod().equals("PUT") && request.getPath().equals(customPutEndpoint)) {
72+
return new MockResponse().setResponseCode(200)
73+
.setHeader(CONTENT_TYPE, JSON_UTF_8)
74+
.setBody(OK);
75+
}
76+
77+
return new MockResponse().setResponseCode(NOT_FOUND);
78+
}
79+
});
80+
81+
HaGatewayTestUtils.TestConfig testConfig = buildGatewayConfigAndSeedDb(routerPort, "test-config-template.yml");
82+
83+
String[] args = {testConfig.configFilePath()};
84+
HaGatewayLauncher.main(args);
85+
86+
setUpBackend("custom", "http://localhost:" + customBackendPort, "externalUrl", true, "adhoc", routerPort);
87+
}
88+
89+
@AfterAll
90+
void cleanup()
91+
throws Exception
92+
{
93+
mockTrinoServer.shutdown();
94+
}
95+
96+
@Test
97+
void testPutRequestHandler()
98+
throws Exception
99+
{
100+
String url = "http://localhost:" + routerPort + customPutEndpoint;
101+
RequestBody requestBody = RequestBody.create("SELECT 1", MEDIA_TYPE);
102+
103+
Request putRequest = new Request.Builder().url(url).put(requestBody).build();
104+
try (Response response = httpClient.newCall(putRequest).execute()) {
105+
assertThat(response.body()).isNotNull();
106+
assertThat(response.body().string()).isEqualTo(OK);
107+
}
108+
109+
Request postRequest = new Request.Builder().url(url).post(requestBody).build();
110+
try (Response response = httpClient.newCall(postRequest).execute()) {
111+
assertThat(response.code()).isEqualTo(NOT_FOUND);
112+
}
113+
}
114+
}

0 commit comments

Comments
 (0)