diff --git a/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java b/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java index f67495611..c40559e82 100644 --- a/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java +++ b/tracing-jersey/src/test/java/com/palantir/tracing/jersey/TraceEnrichingFilterTest.java @@ -16,6 +16,7 @@ package com.palantir.tracing.jersey; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -27,6 +28,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.Lists; +import com.palantir.tracing.Trace; import com.palantir.tracing.TraceSampler; import com.palantir.tracing.Tracer; import com.palantir.tracing.Tracers; @@ -38,6 +41,17 @@ import io.dropwizard.Configuration; import io.dropwizard.setup.Environment; import io.dropwizard.testing.junit.DropwizardAppRule; +import java.io.IOException; +import java.util.EnumSet; +import java.util.List; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.FilterRegistration; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.ws.rs.Consumes; import javax.ws.rs.GET; import javax.ws.rs.POST; @@ -46,7 +60,9 @@ import javax.ws.rs.client.Client; import javax.ws.rs.client.Entity; import javax.ws.rs.client.WebTarget; +import javax.ws.rs.container.AsyncResponse; import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.container.Suspended; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.UriInfo; @@ -54,6 +70,7 @@ import org.junit.After; import org.junit.Before; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Captor; @@ -81,8 +98,11 @@ public final class TraceEnrichingFilterTest { private WebTarget target; + private static final List unexpectedTraces = Lists.newCopyOnWriteArrayList(); + @Before public void before() { + unexpectedTraces.clear(); MockitoAnnotations.initMocks(this); String endpointUri = "http://localhost:" + APP.getLocalPort(); JerseyClientBuilder builder = new JerseyClientBuilder(); @@ -102,6 +122,11 @@ public void before() { @After public void after() { Tracer.unsubscribe(""); + try { + assertThat(unexpectedTraces, is(empty())); + } finally { + unexpectedTraces.clear(); + } } @Test @@ -156,6 +181,17 @@ public void testTraceState_withoutRequestHeadersGeneratesValidTraceResponseHeade assertThat(spanCaptor.getValue().getOperation(), is("GET /trace")); } + @Test + @Ignore("https://github.com/palantir/tracing-java/issues/28") + public void testTraceState_asyncDoesNotLeakState() { + Response response = target.path("/trace/async").request().get(); + assertThat(response.getHeaderString(TraceHttpHeaders.TRACE_ID), not(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.PARENT_SPAN_ID), is(nullValue())); + assertThat(response.getHeaderString(TraceHttpHeaders.SPAN_ID), is(nullValue())); + verify(observer).consume(spanCaptor.capture()); + assertThat(spanCaptor.getValue().getOperation(), is("GET /trace/async")); + } + @Test public void testTraceState_withSamplingHeaderWithoutTraceIdDoesNotUseTraceSampler() { target.path("/trace").request() @@ -209,6 +245,29 @@ public void testFilter_setsMdcIfTraceIdHeaderIsNotePresent() throws Exception { public static class TracingTestServer extends Application { @Override public final void run(Configuration config, final Environment env) throws Exception { + FilterRegistration.Dynamic filter = env.servlets().addFilter("trace-leak-validator", new Filter() { + @Override + public void init(FilterConfig filterConfig) {} + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + if (Tracer.hasTraceId()) { + unexpectedTraces.add(Tracer.getAndClearTrace()); + } + try { + chain.doFilter(request, response); + } finally { + if (Tracer.hasTraceId()) { + unexpectedTraces.add(Tracer.getAndClearTrace()); + } + } + } + + @Override + public void destroy() {} + }); + filter.addMappingForUrlPatterns(EnumSet.allOf(DispatcherType.class), false, "/*"); env.jersey().register(new TraceEnrichingFilter()); env.jersey().register(new TracingTestResource()); } @@ -223,6 +282,11 @@ public void postTraceOperation() {} @Override public void getTraceWithPathParam() {} + + @Override + public void getAsync(AsyncResponse asyncResponse) { + new Thread(() -> asyncResponse.resume("complete")).start(); + } } @Path("/") @@ -240,5 +304,9 @@ public interface TracingTestService { @GET @Path("/trace/{param}") void getTraceWithPathParam(); + + @GET + @Path("/trace/async") + void getAsync(@Suspended AsyncResponse asyncResponse); } }