|
| 1 | +/* |
| 2 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | + * you may not use this file except in compliance with the License. |
| 4 | + * You may obtain a copy of the License at |
| 5 | + * |
| 6 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | + * |
| 8 | + * Unless required by applicable law or agreed to in writing, software |
| 9 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | + * See the License for the specific language governing permissions and |
| 12 | + * limitations under the License. |
| 13 | + */ |
| 14 | +package io.trino.aws.proxy.server.rest; |
| 15 | + |
| 16 | +import com.google.common.collect.ImmutableList; |
| 17 | +import com.google.common.collect.ImmutableMap; |
| 18 | +import com.google.inject.BindingAnnotation; |
| 19 | +import com.google.inject.Inject; |
| 20 | +import com.google.inject.Key; |
| 21 | +import io.airlift.http.client.HttpStatus; |
| 22 | +import io.airlift.http.server.HttpServerConfig; |
| 23 | +import io.airlift.http.server.HttpServerInfo; |
| 24 | +import io.airlift.http.server.testing.TestingHttpServer; |
| 25 | +import io.airlift.node.NodeInfo; |
| 26 | +import io.trino.aws.proxy.server.remote.PathStyleRemoteS3Facade; |
| 27 | +import io.trino.aws.proxy.server.testing.TestingRemoteS3Facade; |
| 28 | +import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer.Builder; |
| 29 | +import io.trino.aws.proxy.server.testing.harness.BuilderFilter; |
| 30 | +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; |
| 31 | +import jakarta.servlet.http.HttpServlet; |
| 32 | +import jakarta.servlet.http.HttpServletRequest; |
| 33 | +import jakarta.servlet.http.HttpServletResponse; |
| 34 | +import org.junit.jupiter.api.Test; |
| 35 | +import software.amazon.awssdk.services.s3.S3Client; |
| 36 | +import software.amazon.awssdk.services.s3.model.S3Exception; |
| 37 | + |
| 38 | +import java.io.IOException; |
| 39 | +import java.lang.annotation.Retention; |
| 40 | +import java.lang.annotation.Target; |
| 41 | +import java.util.List; |
| 42 | +import java.util.Map; |
| 43 | +import java.util.Optional; |
| 44 | + |
| 45 | +import static com.google.common.collect.ImmutableMap.toImmutableMap; |
| 46 | +import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; |
| 47 | +import static java.lang.annotation.ElementType.FIELD; |
| 48 | +import static java.lang.annotation.ElementType.METHOD; |
| 49 | +import static java.lang.annotation.ElementType.PARAMETER; |
| 50 | +import static java.lang.annotation.RetentionPolicy.RUNTIME; |
| 51 | +import static java.util.Objects.requireNonNull; |
| 52 | +import static java.util.function.Function.identity; |
| 53 | +import static org.assertj.core.api.Assertions.assertThat; |
| 54 | +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; |
| 55 | + |
| 56 | +@TrinoAwsProxyTest(filters = TestProxiedErrorResponses.Filter.class) |
| 57 | +public class TestProxiedErrorResponses |
| 58 | +{ |
| 59 | + private final S3Client internalClient; |
| 60 | + |
| 61 | + /** |
| 62 | + * Status code taken from https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html |
| 63 | + */ |
| 64 | + private static final List<HttpStatus> STATUS_CODES = ImmutableList.of( |
| 65 | + HttpStatus.BAD_REQUEST, |
| 66 | + HttpStatus.FORBIDDEN, |
| 67 | + HttpStatus.NOT_FOUND, |
| 68 | + HttpStatus.METHOD_NOT_ALLOWED, |
| 69 | + HttpStatus.CONFLICT, |
| 70 | + HttpStatus.LENGTH_REQUIRED, |
| 71 | + HttpStatus.PRECONDITION_FAILED, |
| 72 | + HttpStatus.REQUEST_RANGE_NOT_SATISFIABLE, |
| 73 | + HttpStatus.INTERNAL_SERVER_ERROR, |
| 74 | + HttpStatus.NOT_IMPLEMENTED, |
| 75 | + HttpStatus.SERVICE_UNAVAILABLE); |
| 76 | + |
| 77 | + @Retention(RUNTIME) |
| 78 | + @Target({FIELD, PARAMETER, METHOD}) |
| 79 | + @BindingAnnotation |
| 80 | + public @interface ForErrorResponseTest {} |
| 81 | + |
| 82 | + public static class Filter |
| 83 | + implements BuilderFilter |
| 84 | + { |
| 85 | + @Override |
| 86 | + public Builder filter(Builder builder) |
| 87 | + { |
| 88 | + TestingHttpServer httpErrorResponseServer; |
| 89 | + try { |
| 90 | + httpErrorResponseServer = createTestingHttpErrorResponseServer(); |
| 91 | + httpErrorResponseServer.start(); |
| 92 | + } |
| 93 | + catch (Exception e) { |
| 94 | + throw new RuntimeException("Failed to start http error response server", e); |
| 95 | + } |
| 96 | + return builder.addModule(binder -> binder.bind(Key.get(TestingHttpServer.class, ForErrorResponseTest.class)).toInstance(httpErrorResponseServer)); |
| 97 | + } |
| 98 | + } |
| 99 | + |
| 100 | + @Inject |
| 101 | + public TestProxiedErrorResponses(S3Client internalClient, TestingRemoteS3Facade delegatingFacade, @ForErrorResponseTest TestingHttpServer httpErrorResponseServer) |
| 102 | + { |
| 103 | + this.internalClient = requireNonNull(internalClient, "internal client is null"); |
| 104 | + delegatingFacade.setDelegate(new PathStyleRemoteS3Facade((_, _) -> httpErrorResponseServer.getBaseUrl().getHost(), false, Optional.of(httpErrorResponseServer.getBaseUrl().getPort()))); |
| 105 | + } |
| 106 | + |
| 107 | + @Test |
| 108 | + public void test() |
| 109 | + { |
| 110 | + for (HttpStatus status : STATUS_CODES) { |
| 111 | + assertThrownAwsError(status); |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + private void assertThrownAwsError(HttpStatus status) |
| 116 | + { |
| 117 | + assertThatExceptionOfType(S3Exception.class).isThrownBy(() -> getFileFromStorage(internalClient, "status", String.valueOf(status.code()))) |
| 118 | + .satisfies( |
| 119 | + exception -> assertThat(exception.statusCode()).isEqualTo(status.code()), |
| 120 | + exception -> assertThat(exception.awsErrorDetails().errorCode()).isEqualTo(status.reason())); |
| 121 | + } |
| 122 | + |
| 123 | + private static TestingHttpServer createTestingHttpErrorResponseServer() |
| 124 | + throws IOException |
| 125 | + { |
| 126 | + NodeInfo nodeInfo = new NodeInfo("test"); |
| 127 | + HttpServerConfig config = new HttpServerConfig().setHttpPort(0); |
| 128 | + HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo); |
| 129 | + return new TestingHttpServer(httpServerInfo, nodeInfo, config, new HttpErrorResponseServlet(), ImmutableMap.of()); |
| 130 | + } |
| 131 | + |
| 132 | + private static class HttpErrorResponseServlet |
| 133 | + extends HttpServlet |
| 134 | + { |
| 135 | + private static final String RESPONSE_TEMPLATE = """ |
| 136 | +<?xml version="1.0" encoding="UTF-8"?> |
| 137 | +<Error> |
| 138 | + <Code>%s</Code> |
| 139 | + <Message>Error Message</Message> |
| 140 | + <Resource>%s</Resource> |
| 141 | + <RequestId>123</RequestId> |
| 142 | +</Error>"""; |
| 143 | + |
| 144 | + private static final Map<String, HttpStatus> PATH_STATUS_CODE_MAPPING = STATUS_CODES.stream().collect(toImmutableMap(status -> "/status/%d".formatted(status.code()), identity())); |
| 145 | + |
| 146 | + @Override |
| 147 | + protected void doGet(HttpServletRequest req, HttpServletResponse resp) |
| 148 | + throws IOException |
| 149 | + { |
| 150 | + String path = req.getPathInfo(); |
| 151 | + if (PATH_STATUS_CODE_MAPPING.containsKey(path)) { |
| 152 | + HttpStatus status = PATH_STATUS_CODE_MAPPING.get(path); |
| 153 | + resp.setStatus(status.code()); |
| 154 | + resp.getWriter().write(RESPONSE_TEMPLATE.formatted(status.reason(), path)); |
| 155 | + } |
| 156 | + else { |
| 157 | + resp.setStatus(HttpServletResponse.SC_NOT_FOUND); |
| 158 | + } |
| 159 | + } |
| 160 | + } |
| 161 | +} |
0 commit comments