Skip to content

Commit 3cd6897

Browse files
committed
Return proxied response body for all status codes
1 parent 3c9b92e commit 3cd6897

File tree

3 files changed

+175
-4
lines changed

3 files changed

+175
-4
lines changed

trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/StreamingResponseHandler.java

+1-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import com.google.common.base.Throwables;
1717
import com.google.common.collect.ImmutableMap;
1818
import io.airlift.http.client.HeaderName;
19-
import io.airlift.http.client.HttpStatus;
2019
import io.airlift.http.client.Request;
2120
import io.airlift.http.client.Response;
2221
import io.airlift.http.client.ResponseHandler;
@@ -78,9 +77,7 @@ public Void handle(Request request, Response response)
7877
};
7978

8079
jakarta.ws.rs.core.Response.ResponseBuilder responseBuilder = jakarta.ws.rs.core.Response.status(response.getStatusCode());
81-
if (HttpStatus.familyForStatusCode(response.getStatusCode()) == HttpStatus.Family.SUCCESSFUL) {
82-
responseBuilder.entity(streamingOutput);
83-
}
80+
responseBuilder.entity(streamingOutput);
8481
response.getHeaders()
8582
.keySet()
8683
.stream()

trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java

+13
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import java.time.Duration;
4545
import java.util.Comparator;
4646
import java.util.List;
47+
import java.util.UUID;
4748
import java.util.concurrent.CopyOnWriteArrayList;
4849
import java.util.concurrent.ExecutorService;
4950
import java.util.concurrent.Executors;
@@ -54,6 +55,7 @@
5455
import static com.google.common.collect.ImmutableList.toImmutableList;
5556
import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination;
5657
import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE;
58+
import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3;
5759
import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage;
5860
import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage;
5961
import static io.trino.aws.proxy.server.testing.TestingUtil.listFilesInS3Bucket;
@@ -249,6 +251,17 @@ public void testPathsNeedingEscaping()
249251
internalClient.deleteBucket(r -> r.bucket(bucket));
250252
}
251253

254+
@Test
255+
public void testKeyOrBucketDoesNotExist()
256+
{
257+
assertFileNotInS3(internalClient, UUID.randomUUID().toString(), UUID.randomUUID().toString());
258+
259+
String newBucketName = "empty-bucket";
260+
remoteClient.createBucket(r -> r.bucket(newBucketName));
261+
262+
assertFileNotInS3(internalClient, newBucketName, UUID.randomUUID().toString());
263+
}
264+
252265
private static String buildLine(int partNumber)
253266
{
254267
// min multi-part is 5MB
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.aws.proxy.server.rest;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import com.google.common.collect.ImmutableMap;
18+
import com.google.inject.BindingAnnotation;
19+
import com.google.inject.Inject;
20+
import com.google.inject.Key;
21+
import io.airlift.http.client.HttpStatus;
22+
import io.airlift.http.server.HttpServerConfig;
23+
import io.airlift.http.server.HttpServerInfo;
24+
import io.airlift.http.server.testing.TestingHttpServer;
25+
import io.airlift.node.NodeInfo;
26+
import io.trino.aws.proxy.server.remote.PathStyleRemoteS3Facade;
27+
import io.trino.aws.proxy.server.testing.TestingRemoteS3Facade;
28+
import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer.Builder;
29+
import io.trino.aws.proxy.server.testing.harness.BuilderFilter;
30+
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
31+
import jakarta.servlet.http.HttpServlet;
32+
import jakarta.servlet.http.HttpServletRequest;
33+
import jakarta.servlet.http.HttpServletResponse;
34+
import org.junit.jupiter.api.Test;
35+
import software.amazon.awssdk.services.s3.S3Client;
36+
import software.amazon.awssdk.services.s3.model.S3Exception;
37+
38+
import java.io.IOException;
39+
import java.lang.annotation.Retention;
40+
import java.lang.annotation.Target;
41+
import java.util.List;
42+
import java.util.Map;
43+
import java.util.Optional;
44+
45+
import static com.google.common.collect.ImmutableMap.toImmutableMap;
46+
import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage;
47+
import static java.lang.annotation.ElementType.FIELD;
48+
import static java.lang.annotation.ElementType.METHOD;
49+
import static java.lang.annotation.ElementType.PARAMETER;
50+
import static java.lang.annotation.RetentionPolicy.RUNTIME;
51+
import static java.util.Objects.requireNonNull;
52+
import static java.util.function.Function.identity;
53+
import static org.assertj.core.api.Assertions.assertThat;
54+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
55+
56+
@TrinoAwsProxyTest(filters = TestProxiedErrorResponses.Filter.class)
57+
public class TestProxiedErrorResponses
58+
{
59+
private final S3Client internalClient;
60+
61+
/**
62+
* Status code taken from https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html
63+
*/
64+
private static final List<HttpStatus> STATUS_CODES = ImmutableList.of(
65+
HttpStatus.BAD_REQUEST,
66+
HttpStatus.FORBIDDEN,
67+
HttpStatus.NOT_FOUND,
68+
HttpStatus.METHOD_NOT_ALLOWED,
69+
HttpStatus.CONFLICT,
70+
HttpStatus.LENGTH_REQUIRED,
71+
HttpStatus.PRECONDITION_FAILED,
72+
HttpStatus.REQUEST_RANGE_NOT_SATISFIABLE,
73+
HttpStatus.INTERNAL_SERVER_ERROR,
74+
HttpStatus.NOT_IMPLEMENTED,
75+
HttpStatus.SERVICE_UNAVAILABLE);
76+
77+
@Retention(RUNTIME)
78+
@Target({FIELD, PARAMETER, METHOD})
79+
@BindingAnnotation
80+
public @interface ForErrorResponseTest {}
81+
82+
public static class Filter
83+
implements BuilderFilter
84+
{
85+
@Override
86+
public Builder filter(Builder builder)
87+
{
88+
TestingHttpServer httpErrorResponseServer;
89+
try {
90+
httpErrorResponseServer = createTestingHttpErrorResponseServer();
91+
httpErrorResponseServer.start();
92+
}
93+
catch (Exception e) {
94+
throw new RuntimeException("Failed to start http error response server", e);
95+
}
96+
return builder.addModule(binder -> binder.bind(Key.get(TestingHttpServer.class, ForErrorResponseTest.class)).toInstance(httpErrorResponseServer));
97+
}
98+
}
99+
100+
@Inject
101+
public TestProxiedErrorResponses(S3Client internalClient, TestingRemoteS3Facade delegatingFacade, @ForErrorResponseTest TestingHttpServer httpErrorResponseServer)
102+
{
103+
this.internalClient = requireNonNull(internalClient, "internal client is null");
104+
delegatingFacade.setDelegate(new PathStyleRemoteS3Facade((_, _) -> httpErrorResponseServer.getBaseUrl().getHost(), false, Optional.of(httpErrorResponseServer.getBaseUrl().getPort())));
105+
}
106+
107+
@Test
108+
public void test()
109+
{
110+
for (HttpStatus status : STATUS_CODES) {
111+
assertThrownAwsError(status);
112+
}
113+
}
114+
115+
private void assertThrownAwsError(HttpStatus status)
116+
{
117+
assertThatExceptionOfType(S3Exception.class).isThrownBy(() -> getFileFromStorage(internalClient, "status", String.valueOf(status.code())))
118+
.satisfies(
119+
exception -> assertThat(exception.statusCode()).isEqualTo(status.code()),
120+
exception -> assertThat(exception.awsErrorDetails().errorCode()).isEqualTo(status.reason()));
121+
}
122+
123+
private static TestingHttpServer createTestingHttpErrorResponseServer()
124+
throws IOException
125+
{
126+
NodeInfo nodeInfo = new NodeInfo("test");
127+
HttpServerConfig config = new HttpServerConfig().setHttpPort(0);
128+
HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo);
129+
return new TestingHttpServer(httpServerInfo, nodeInfo, config, new HttpErrorResponseServlet(), ImmutableMap.of());
130+
}
131+
132+
private static class HttpErrorResponseServlet
133+
extends HttpServlet
134+
{
135+
private static final String RESPONSE_TEMPLATE = """
136+
<?xml version="1.0" encoding="UTF-8"?>
137+
<Error>
138+
<Code>%s</Code>
139+
<Message>Error Message</Message>
140+
<Resource>%s</Resource>
141+
<RequestId>123</RequestId>
142+
</Error>""";
143+
144+
private static final Map<String, HttpStatus> PATH_STATUS_CODE_MAPPING = STATUS_CODES.stream().collect(toImmutableMap(status -> "/status/%d".formatted(status.code()), identity()));
145+
146+
@Override
147+
protected void doGet(HttpServletRequest req, HttpServletResponse resp)
148+
throws IOException
149+
{
150+
String path = req.getPathInfo();
151+
if (PATH_STATUS_CODE_MAPPING.containsKey(path)) {
152+
HttpStatus status = PATH_STATUS_CODE_MAPPING.get(path);
153+
resp.setStatus(status.code());
154+
resp.getWriter().write(RESPONSE_TEMPLATE.formatted(status.reason(), path));
155+
}
156+
else {
157+
resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
158+
}
159+
}
160+
}
161+
}

0 commit comments

Comments
 (0)