Skip to content

Commit 5d8a741

Browse files
vishalyaebyhr
authored andcommitted
Use TrinoRequestUser to get the user for the query
1 parent a45d249 commit 5d8a741

File tree

5 files changed

+70
-56
lines changed

5 files changed

+70
-56
lines changed

gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,18 @@
1313
*/
1414
package io.trino.gateway.ha.handler;
1515

16-
import com.google.common.base.Splitter;
1716
import com.google.common.io.CharStreams;
1817
import io.airlift.log.Logger;
1918
import jakarta.servlet.http.HttpServletRequest;
2019

2120
import java.io.InputStreamReader;
22-
import java.util.Base64;
2321
import java.util.List;
2422
import java.util.Optional;
2523
import java.util.regex.Matcher;
2624
import java.util.regex.Pattern;
2725

2826
import static com.google.common.base.Strings.isNullOrEmpty;
2927
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.TRINO_UI_PATH;
30-
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
3128
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.V1_QUERY_PATH;
3229
import static java.nio.charset.StandardCharsets.UTF_8;
3330
import static java.util.Locale.ENGLISH;
@@ -56,41 +53,6 @@ public final class ProxyUtils
5653

5754
private ProxyUtils() {}
5855

59-
public static String getQueryUser(String userHeader, String authorization)
60-
{
61-
if (!isNullOrEmpty(userHeader)) {
62-
log.debug("User from header %s", USER_HEADER);
63-
return userHeader;
64-
}
65-
66-
log.debug("User from basic authentication");
67-
String user = "";
68-
if (authorization == null) {
69-
log.debug("No basic auth header found.");
70-
return user;
71-
}
72-
73-
int space = authorization.indexOf(' ');
74-
if ((space < 0) || !authorization.substring(0, space).equalsIgnoreCase("basic")) {
75-
log.error("Basic auth format is invalid");
76-
return user;
77-
}
78-
79-
String headerInfo = authorization.substring(space + 1).trim();
80-
if (isNullOrEmpty(headerInfo)) {
81-
log.error("Encoded value of basic auth doesn't exist");
82-
return user;
83-
}
84-
85-
String info = new String(Base64.getDecoder().decode(headerInfo), UTF_8);
86-
List<String> parts = Splitter.on(':').limit(2).splitToList(info);
87-
if (parts.size() < 1) {
88-
log.error("No user inside the basic auth text");
89-
return user;
90-
}
91-
return parts.get(0);
92-
}
93-
9456
public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
9557
{
9658
String path = request.getRequestURI();

gateway-ha/src/main/java/io/trino/gateway/ha/router/TrinoRequestUser.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,10 @@ private Optional<String> extractUserFromBearerAuth(String header, String userFie
178178

179179
String token = header.substring(space + 1).trim();
180180

181-
if (header.split("\\.").length == 3) { //this is probably a JWS
181+
if (token.split("\\.").length == 3) { //this is probably a JWS
182182
log.debug("Trying to extract from JWS");
183183
try {
184-
DecodedJWT jwt = JWT.decode(header);
184+
DecodedJWT jwt = JWT.decode(token);
185185
if (jwt.getClaims().containsKey(userField)) {
186186
return Optional.of(jwt.getClaim(userField).asString());
187187
}

gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
3030
import io.trino.gateway.ha.router.QueryHistoryManager;
3131
import io.trino.gateway.ha.router.RoutingManager;
32+
import io.trino.gateway.ha.router.TrinoRequestUser;
3233
import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse;
3334
import jakarta.annotation.PreDestroy;
3435
import jakarta.servlet.http.HttpServletRequest;
@@ -43,6 +44,7 @@
4344
import java.util.Arrays;
4445
import java.util.HashMap;
4546
import java.util.List;
47+
import java.util.Optional;
4648
import java.util.concurrent.ExecutorService;
4749

4850
import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -58,11 +60,8 @@
5860
import static io.airlift.http.client.Request.Builder.preparePost;
5961
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
6062
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
61-
import static io.trino.gateway.ha.handler.ProxyUtils.AUTHORIZATION;
6263
import static io.trino.gateway.ha.handler.ProxyUtils.QUERY_TEXT_LENGTH_FOR_HISTORY;
6364
import static io.trino.gateway.ha.handler.ProxyUtils.SOURCE_HEADER;
64-
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
65-
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
6665
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
6766
import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY;
6867
import static jakarta.ws.rs.core.Response.Status.OK;
@@ -86,6 +85,7 @@ public class ProxyRequestHandler
8685
private final boolean addXForwardedHeaders;
8786
private final List<String> statementPaths;
8887
private final boolean includeClusterInfoInResponse;
88+
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
8989

9090
@Inject
9191
public ProxyRequestHandler(
@@ -97,6 +97,7 @@ public ProxyRequestHandler(
9797
this.httpClient = requireNonNull(httpClient, "httpClient is null");
9898
this.routingManager = requireNonNull(routingManager, "routingManager is null");
9999
this.queryHistoryManager = requireNonNull(queryHistoryManager, "queryHistoryManager is null");
100+
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(haGatewayConfiguration.getRequestAnalyzerConfig());
100101
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
101102
asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout();
102103
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
@@ -173,7 +174,8 @@ private void performRequest(
173174
FluentFuture<ProxyResponse> future = executeHttp(request);
174175

175176
if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) {
176-
future = future.transform(response -> recordBackendForQueryId(request, response), executor);
177+
Optional<String> username = trinoRequestUserProvider.getInstance(servletRequest).getUser();
178+
future = future.transform(response -> recordBackendForQueryId(request, response, username), executor);
177179
if (includeClusterInfoInResponse) {
178180
cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build());
179181
}
@@ -250,11 +252,11 @@ private static WebApplicationException badRequest(String message)
250252
.build());
251253
}
252254

253-
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response)
255+
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response, Optional<String> username)
254256
{
255257
log.debug("For Request [%s] got Response [%s]", request.getUri(), response.body());
256258

257-
QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
259+
QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request, username);
258260

259261
log.debug("Extracting proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getUri());
260262

@@ -276,12 +278,12 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
276278
return response;
277279
}
278280

279-
public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request)
281+
public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request, Optional<String> username)
280282
{
281283
QueryHistoryManager.QueryDetail queryDetail = new QueryHistoryManager.QueryDetail();
282284
queryDetail.setBackendUrl(getRemoteTarget(request.getUri()));
283285
queryDetail.setCaptureTime(System.currentTimeMillis());
284-
queryDetail.setUser(getQueryUser(request.getHeader(USER_HEADER), request.getHeader(AUTHORIZATION)));
286+
username.ifPresent(queryDetail::setUser);
285287
queryDetail.setSource(request.getHeader(SOURCE_HEADER));
286288

287289
String queryText = new String(((StaticBodyGenerator) request.getBodyGenerator()).getBody(), UTF_8);

gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.List;
2323

2424
import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
25-
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
2625
import static org.assertj.core.api.Assertions.assertThat;
2726

2827
@TestInstance(Lifecycle.PER_CLASS)
@@ -61,11 +60,4 @@ void testExtractQueryIdFromUrl()
6160
assertThat(extractQueryIdIfPresent("/ui/", "lang=en&p=1&id=0_1_2_a", statementPaths))
6261
.isNull();
6362
}
64-
65-
@Test
66-
void testGetQueryUser()
67-
{
68-
assertThat(getQueryUser(null, "Basic dGVzdDoxMjPCow==")).isEqualTo("test");
69-
assertThat(getQueryUser("trino_user", "Basic dGVzdDoxMjPCow==")).isEqualTo("trino_user");
70-
}
7163
}

gateway-ha/src/test/java/io/trino/gateway/ha/router/TestTrinoRequestUser.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,25 @@
1313
*/
1414
package io.trino.gateway.ha.router;
1515

16+
import com.auth0.jwt.JWT;
17+
import com.auth0.jwt.algorithms.Algorithm;
1618
import io.airlift.json.JsonCodec;
19+
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
20+
import jakarta.servlet.http.HttpServletRequest;
1721
import org.junit.jupiter.api.Test;
1822

23+
import java.time.Instant;
24+
import java.util.Base64;
25+
import java.util.Date;
1926
import java.util.Optional;
2027

28+
import static com.auth0.jwt.algorithms.Algorithm.HMAC256;
29+
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
30+
import static jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION;
31+
import static java.nio.charset.StandardCharsets.UTF_8;
2132
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.mockito.Mockito.mock;
34+
import static org.mockito.Mockito.when;
2235

2336
final class TestTrinoRequestUser
2437
{
@@ -42,4 +55,49 @@ void testJsonCreator()
4255
assertThat(deserializedTrinoRequestUser.getUser()).isEqualTo(trinoRequestUser.getUser());
4356
assertThat(deserializedTrinoRequestUser.getUserInfo()).isEqualTo(trinoRequestUser.getUserInfo());
4457
}
58+
59+
@Test
60+
void testUserFromJwtToken()
61+
{
62+
String claimUserName = "username";
63+
String claimUserValue = "trino";
64+
65+
RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
66+
requestAnalyzerConfig.setTokenUserField(claimUserName);
67+
68+
Algorithm algorithm = HMAC256("random");
69+
70+
Instant expiryTime = Instant.now().plusSeconds(60);
71+
String token = JWT.create()
72+
.withIssuer("gateway")
73+
.withClaim(claimUserName, claimUserValue)
74+
.withExpiresAt(Date.from(expiryTime))
75+
.sign(algorithm);
76+
77+
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
78+
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
79+
when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Bearer " + token);
80+
81+
TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);
82+
83+
assertThat(trinoRequestUser.getUser()).hasValue(claimUserValue);
84+
}
85+
86+
@Test
87+
void testGetBasicAuthUser()
88+
{
89+
String username = "trino_user";
90+
String password = "don't care";
91+
String credentials = username + ":" + password;
92+
String encodedCredentials = Base64.getEncoder().encodeToString(credentials.getBytes(UTF_8));
93+
94+
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
95+
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
96+
when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Basic " + encodedCredentials);
97+
98+
RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
99+
TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);
100+
101+
assertThat(trinoRequestUser.getUser()).hasValue(username);
102+
}
45103
}

0 commit comments

Comments
 (0)