Skip to content

Commit 4ca7e89

Browse files
authored
Add ServerErrorHandler to inject dependencies in annotations (#5446)
Motivation: - It seems good to directly inject ServerErrorHandler via a bean. - #5440 Modifications: - Introduced `Optional<List<ServerErrorHandler>>` serverErrorHandlers in the `armeriaServer` method to enable the injection of ServerErrorHandler beans. - Updated the `configureServerWithArmeriaSettings` method to incorporate error handlers. Result: - Users can now define custom ServerErrorHandler beans and utilize an Armeria server that has been pre-configured with these user-defined error handlers. - Closes #5443
1 parent b1a1044 commit 4ca7e89

File tree

5 files changed

+103
-2
lines changed

5 files changed

+103
-2
lines changed

spring/boot3-autoconfigure/src/main/java/com/linecorp/armeria/internal/spring/ArmeriaConfigurationUtil.java

+3
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import com.linecorp.armeria.common.metric.MeterIdPrefixFunction;
6565
import com.linecorp.armeria.server.HttpService;
6666
import com.linecorp.armeria.server.ServerBuilder;
67+
import com.linecorp.armeria.server.ServerErrorHandler;
6768
import com.linecorp.armeria.server.encoding.EncodingService;
6869
import com.linecorp.armeria.server.metric.MetricCollectingService;
6970
import com.linecorp.armeria.server.metric.MetricCollectingServiceBuilder;
@@ -110,6 +111,7 @@ public static void configureServerWithArmeriaSettings(
110111
MeterIdPrefixFunction meterIdPrefixFunction,
111112
List<MetricCollectingServiceConfigurator> metricCollectingServiceConfigurators,
112113
List<DependencyInjector> dependencyInjectors,
114+
List<ServerErrorHandler> serverErrorHandlers,
113115
BeanFactory beanFactory) {
114116

115117
requireNonNull(server, "server");
@@ -203,6 +205,7 @@ public static void configureServerWithArmeriaSettings(
203205
if (settings.isEnableAutoInjection()) {
204206
server.dependencyInjector(SpringDependencyInjector.of(beanFactory), false);
205207
}
208+
serverErrorHandlers.forEach(server::errorHandler);
206209
}
207210

208211
private static void configureInternalService(ServerBuilder server, InternalServiceId serviceId,

spring/boot3-autoconfigure/src/main/java/com/linecorp/armeria/spring/AbstractArmeriaAutoConfiguration.java

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import com.linecorp.armeria.common.metric.MeterIdPrefixFunction;
4141
import com.linecorp.armeria.server.Server;
4242
import com.linecorp.armeria.server.ServerBuilder;
43+
import com.linecorp.armeria.server.ServerErrorHandler;
4344
import com.linecorp.armeria.server.ServerPort;
4445
import com.linecorp.armeria.server.docs.DocService;
4546
import com.linecorp.armeria.server.healthcheck.HealthCheckService;
@@ -73,6 +74,7 @@ public Server armeriaServer(
7374
Optional<List<ArmeriaServerConfigurator>> armeriaServerConfigurators,
7475
Optional<List<Consumer<ServerBuilder>>> armeriaServerBuilderConsumers,
7576
Optional<List<DependencyInjector>> dependencyInjectors,
77+
Optional<List<ServerErrorHandler>> serverErrorHandlers,
7678
BeanFactory beanFactory) {
7779

7880
if (!armeriaServerConfigurators.isPresent() &&
@@ -98,6 +100,7 @@ public Server armeriaServer(
98100
MeterIdPrefixFunction.ofDefault("armeria.server")),
99101
metricCollectingServiceConfigurators.orElse(ImmutableList.of()),
100102
dependencyInjectors.orElse(ImmutableList.of()),
103+
serverErrorHandlers.orElse(ImmutableList.of()),
101104
beanFactory);
102105

103106
return serverBuilder.build();

spring/boot3-autoconfigure/src/test/java/com/linecorp/armeria/spring/ArmeriaAutoConfigurationTest.java

+72-2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
5555
import com.linecorp.armeria.common.metric.MeterIdPrefixFunction;
5656
import com.linecorp.armeria.server.Server;
57+
import com.linecorp.armeria.server.ServerErrorHandler;
5758
import com.linecorp.armeria.server.ServerPort;
5859
import com.linecorp.armeria.server.ServiceRequestContext;
5960
import com.linecorp.armeria.server.annotation.ExceptionHandlerFunction;
@@ -170,6 +171,26 @@ public MetricCollectingServiceConfigurator metricCollectingServiceConfigurator()
170171
return (statusCode >= 200 && statusCode < 400) || statusCode == 404;
171172
});
172173
}
174+
175+
@Bean
176+
public ServerErrorHandler serverErrorHandler1() {
177+
return (ctx, cause) -> {
178+
if (cause instanceof ArithmeticException) {
179+
return HttpResponse.of("ArithmeticException was handled by serverErrorHandler!");
180+
}
181+
return null;
182+
};
183+
}
184+
185+
@Bean
186+
public ServerErrorHandler serverErrorHandler2() {
187+
return (ctx, cause) -> {
188+
if (cause instanceof IllegalStateException) {
189+
return HttpResponse.of("IllegalStateException was handled by serverErrorHandler!");
190+
}
191+
return null;
192+
};
193+
}
173194
}
174195

175196
public static class IllegalArgumentExceptionHandler implements ExceptionHandlerFunction {
@@ -221,6 +242,21 @@ public AggregatedHttpResponse getV2() {
221242
public JsonNode post(@RequestObject JsonNode jsonNode) {
222243
return jsonNode;
223244
}
245+
246+
@Get("/unhandled1")
247+
public AggregatedHttpResponse unhandled1() throws Exception {
248+
throw new ArithmeticException();
249+
}
250+
251+
@Get("/unhandled2")
252+
public AggregatedHttpResponse unhandled2() throws Exception {
253+
throw new IllegalStateException();
254+
}
255+
256+
@Get("/unhandled3")
257+
public AggregatedHttpResponse unhandled3() throws Exception {
258+
throw new IllegalAccessException();
259+
}
224260
}
225261

226262
public static class HelloGrpcService extends TestServiceImplBase {
@@ -294,7 +330,7 @@ void testAnnotatedService() throws Exception {
294330
@Test
295331
void testThriftService() throws Exception {
296332
final TestService.Iface client = ThriftClients.newClient(newUrl("h1c") + "/thrift",
297-
TestService.Iface.class);
333+
TestService.Iface.class);
298334
assertThat(client.hello("world")).isEqualTo("hello world");
299335

300336
final WebClient webClient = WebClient.of(newUrl("h1c"));
@@ -314,7 +350,7 @@ void testThriftService() throws Exception {
314350
@Test
315351
void testGrpcService() throws Exception {
316352
final TestServiceBlockingStub client = GrpcClients.newClient(newUrl("h2c") + '/',
317-
TestServiceBlockingStub.class);
353+
TestServiceBlockingStub.class);
318354
final HelloRequest request = HelloRequest.newBuilder()
319355
.setName("world")
320356
.build();
@@ -394,4 +430,38 @@ void testHealthCheckService() throws Exception {
394430
res = response.aggregate().get();
395431
assertThat(res.status()).isEqualTo(HttpStatus.SERVICE_UNAVAILABLE);
396432
}
433+
434+
/**
435+
* When a ServerErrorHandler @Bean is present,
436+
* Server.config().errorHandler() does not register a DefaultServerErrorHandler.
437+
* Since DefaultServerErrorHandler is not public, test were forced to compare toString.
438+
* Needs to be improved.
439+
*/
440+
@Test
441+
void testServerErrorHandlerRegistration() {
442+
assertThat(server.config().errorHandler().toString()).isNotEqualTo("INSTANCE");
443+
}
444+
445+
@Test
446+
void testServerErrorHandler() throws Exception {
447+
final WebClient client = WebClient.of(newUrl("h1c"));
448+
449+
// ArithmeticException will be handled by serverErrorHandler
450+
final HttpResponse response1 = client.get("/annotated/unhandled1");
451+
final AggregatedHttpResponse res1 = response1.aggregate().join();
452+
assertThat(res1.status()).isEqualTo(HttpStatus.OK);
453+
assertThat(res1.contentUtf8()).isEqualTo("ArithmeticException was handled by serverErrorHandler!");
454+
455+
// IllegalStateException will be handled by serverErrorHandler
456+
final HttpResponse response2 = client.get("/annotated/unhandled2");
457+
final AggregatedHttpResponse res2 = response2.aggregate().join();
458+
assertThat(res2.status()).isEqualTo(HttpStatus.OK);
459+
assertThat(res2.contentUtf8()).isEqualTo("IllegalStateException was handled by serverErrorHandler!");
460+
461+
// IllegalAccessException will be handled by DefaultServerErrorHandler which is used as the
462+
// final fallback when all customized handlers return null
463+
final HttpResponse response3 = client.get("/annotated/unhandled3");
464+
final AggregatedHttpResponse res3 = response3.aggregate().join();
465+
assertThat(res3.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR);
466+
}
397467
}

spring/boot3-webflux-autoconfigure/src/main/java/com/linecorp/armeria/spring/web/reactive/ArmeriaReactiveWebServerFactory.java

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import com.linecorp.armeria.server.Route;
6565
import com.linecorp.armeria.server.Server;
6666
import com.linecorp.armeria.server.ServerBuilder;
67+
import com.linecorp.armeria.server.ServerErrorHandler;
6768
import com.linecorp.armeria.server.ServerPort;
6869
import com.linecorp.armeria.spring.ArmeriaServerConfigurator;
6970
import com.linecorp.armeria.spring.ArmeriaSettings;
@@ -167,6 +168,7 @@ public WebServer getWebServer(HttpHandler httpHandler) {
167168
meterIdPrefixFunctionOrDefault(),
168169
findBeans(MetricCollectingServiceConfigurator.class),
169170
findBeans(DependencyInjector.class),
171+
findBeans(ServerErrorHandler.class),
170172
beanFactory);
171173
}
172174

spring/boot3-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ArmeriaReactiveWebServerFactoryTest.java

+23
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.junit.jupiter.api.Test;
2929
import org.junit.jupiter.params.ParameterizedTest;
3030
import org.junit.jupiter.params.provider.CsvSource;
31+
import org.springframework.beans.factory.config.BeanDefinition;
3132
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
3233
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
3334
import org.springframework.beans.factory.support.RootBeanDefinition;
@@ -62,12 +63,14 @@
6263
import com.linecorp.armeria.common.HttpData;
6364
import com.linecorp.armeria.common.HttpHeaderNames;
6465
import com.linecorp.armeria.common.HttpMethod;
66+
import com.linecorp.armeria.common.HttpResponse;
6567
import com.linecorp.armeria.common.MediaType;
6668
import com.linecorp.armeria.common.RequestHeaders;
6769
import com.linecorp.armeria.common.metric.PrometheusMeterRegistries;
6870
import com.linecorp.armeria.internal.common.util.PortUtil;
6971
import com.linecorp.armeria.internal.testing.MockAddressResolverGroup;
7072
import com.linecorp.armeria.server.HttpStatusException;
73+
import com.linecorp.armeria.server.ServerErrorHandler;
7174
import com.linecorp.armeria.server.annotation.Get;
7275
import com.linecorp.armeria.server.annotation.Param;
7376
import com.linecorp.armeria.server.healthcheck.HealthChecker;
@@ -567,4 +570,24 @@ void testManagementPort() throws JsonProcessingException {
567570
.isEqualTo("/hello/foo");
568571
}
569572
}
573+
574+
@Test
575+
void testServerErrorHandlerRegistration() {
576+
beanFactory.registerBeanDefinition("armeriaSettings", new RootBeanDefinition(ArmeriaSettings.class));
577+
registerInternalServices(beanFactory);
578+
579+
// Add ServerErrorHandler @Bean which handles all exceptions and returns 200 with empty string content.
580+
final ServerErrorHandler handler = (ctx, req) -> HttpResponse.of("");
581+
final BeanDefinition rbd2 = new RootBeanDefinition(ServerErrorHandler.class, () -> handler);
582+
beanFactory.registerBeanDefinition("serverErrorHandler", rbd2);
583+
584+
final ArmeriaReactiveWebServerFactory factory = factory();
585+
runServer(factory, (req, res) -> {
586+
throw new IllegalArgumentException(); // Always raise exception handler
587+
}, server -> {
588+
final WebClient client = httpClient(server);
589+
final AggregatedHttpResponse res1 = client.post("/hello", "hello").aggregate().join();
590+
assertThat(res1.status()).isEqualTo(com.linecorp.armeria.common.HttpStatus.OK);
591+
});
592+
}
570593
}

0 commit comments

Comments
 (0)