Skip to content

Commit 2168a08

Browse files
committed
Use TrinoRequestUser to get the user for the query
1 parent a45d249 commit 2168a08

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,10 @@ private Optional<String> extractUserFromBearerAuth(String header, String userFie
178178

179179
String token = header.substring(space + 1).trim();
180180

181-
if (header.split("\\.").length == 3) { //this is probably a JWS
181+
if (token.split("\\.").length == 3) { //this is probably a JWS
182182
log.debug("Trying to extract from JWS");
183183
try {
184-
DecodedJWT jwt = JWT.decode(header);
184+
DecodedJWT jwt = JWT.decode(token);
185185
if (jwt.getClaims().containsKey(userField)) {
186186
return Optional.of(jwt.getClaim(userField).asString());
187187
}

gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
3030
import io.trino.gateway.ha.router.QueryHistoryManager;
3131
import io.trino.gateway.ha.router.RoutingManager;
32+
import io.trino.gateway.ha.router.TrinoRequestUser;
3233
import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse;
3334
import jakarta.annotation.PreDestroy;
3435
import jakarta.servlet.http.HttpServletRequest;
@@ -58,11 +59,8 @@
5859
import static io.airlift.http.client.Request.Builder.preparePost;
5960
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
6061
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
61-
import static io.trino.gateway.ha.handler.ProxyUtils.AUTHORIZATION;
6262
import static io.trino.gateway.ha.handler.ProxyUtils.QUERY_TEXT_LENGTH_FOR_HISTORY;
6363
import static io.trino.gateway.ha.handler.ProxyUtils.SOURCE_HEADER;
64-
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
65-
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
6664
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
6765
import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY;
6866
import static jakarta.ws.rs.core.Response.Status.OK;
@@ -86,6 +84,7 @@ public class ProxyRequestHandler
8684
private final boolean addXForwardedHeaders;
8785
private final List<String> statementPaths;
8886
private final boolean includeClusterInfoInResponse;
87+
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
8988

9089
@Inject
9190
public ProxyRequestHandler(
@@ -97,6 +96,7 @@ public ProxyRequestHandler(
9796
this.httpClient = requireNonNull(httpClient, "httpClient is null");
9897
this.routingManager = requireNonNull(routingManager, "routingManager is null");
9998
this.queryHistoryManager = requireNonNull(queryHistoryManager, "queryHistoryManager is null");
99+
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(haGatewayConfiguration.getRequestAnalyzerConfig());
100100
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
101101
asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout();
102102
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
@@ -173,7 +173,8 @@ private void performRequest(
173173
FluentFuture<ProxyResponse> future = executeHttp(request);
174174

175175
if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) {
176-
future = future.transform(response -> recordBackendForQueryId(request, response), executor);
176+
String username = trinoRequestUserProvider.getInstance(servletRequest).getUser().orElse("Unknown");
177+
future = future.transform(response -> recordBackendForQueryId(request, response, username), executor);
177178
if (includeClusterInfoInResponse) {
178179
cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build());
179180
}
@@ -250,11 +251,11 @@ private static WebApplicationException badRequest(String message)
250251
.build());
251252
}
252253

253-
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response)
254+
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response, String username)
254255
{
255256
log.debug("For Request [%s] got Response [%s]", request.getUri(), response.body());
256257

257-
QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
258+
QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request, username);
258259

259260
log.debug("Extracting proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getUri());
260261

@@ -276,12 +277,12 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
276277
return response;
277278
}
278279

279-
public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request)
280+
public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request, String username)
280281
{
281282
QueryHistoryManager.QueryDetail queryDetail = new QueryHistoryManager.QueryDetail();
282283
queryDetail.setBackendUrl(getRemoteTarget(request.getUri()));
283284
queryDetail.setCaptureTime(System.currentTimeMillis());
284-
queryDetail.setUser(getQueryUser(request.getHeader(USER_HEADER), request.getHeader(AUTHORIZATION)));
285+
queryDetail.setUser(username);
285286
queryDetail.setSource(request.getHeader(SOURCE_HEADER));
286287

287288
String queryText = new String(((StaticBodyGenerator) request.getBodyGenerator()).getBody(), UTF_8);

gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoRequestUser.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,22 @@
1313
*/
1414
package io.trino.gateway.ha.router;
1515

16+
import com.auth0.jwt.JWT;
17+
import com.auth0.jwt.algorithms.Algorithm;
1618
import io.airlift.json.JsonCodec;
19+
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
20+
import jakarta.servlet.http.HttpServletRequest;
21+
import jakarta.ws.rs.core.HttpHeaders;
1722
import org.junit.jupiter.api.Test;
1823

24+
import java.time.Instant;
25+
import java.util.Date;
1926
import java.util.Optional;
2027

28+
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
2129
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.mockito.Mockito.mock;
31+
import static org.mockito.Mockito.when;
2232

2333
final class TestTrinoRequestUser
2434
{
@@ -42,4 +52,31 @@ void testJsonCreator()
4252
assertThat(deserializedTrinoRequestUser.getUser()).isEqualTo(trinoRequestUser.getUser());
4353
assertThat(deserializedTrinoRequestUser.getUserInfo()).isEqualTo(trinoRequestUser.getUserInfo());
4454
}
55+
56+
@Test
57+
void testUserFromJwtToken()
58+
{
59+
String claimUserName = "username";
60+
String claimUserValue = "trino";
61+
62+
RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
63+
requestAnalyzerConfig.setTokenUserField(claimUserName);
64+
65+
Algorithm algorithm = Algorithm.HMAC256("random");
66+
67+
Instant expiryTime = Instant.now().plusSeconds(60);
68+
String token = JWT.create()
69+
.withIssuer("gateway")
70+
.withClaim(claimUserName, claimUserValue)
71+
.withExpiresAt(Date.from(expiryTime))
72+
.sign(algorithm);
73+
74+
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
75+
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
76+
when(mockRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + token);
77+
78+
TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);
79+
80+
assertThat(trinoRequestUser.getUser()).hasValue(claimUserValue);
81+
}
4582
}

0 commit comments

Comments
 (0)