Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,18 @@
*/
package io.trino.gateway.ha.handler;

import com.google.common.base.Splitter;
import com.google.common.io.CharStreams;
import io.airlift.log.Logger;
import jakarta.servlet.http.HttpServletRequest;

import java.io.InputStreamReader;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Strings.isNullOrEmpty;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.TRINO_UI_PATH;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.V1_QUERY_PATH;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.ENGLISH;
Expand Down Expand Up @@ -56,41 +53,6 @@ public final class ProxyUtils

private ProxyUtils() {}

public static String getQueryUser(String userHeader, String authorization)
{
if (!isNullOrEmpty(userHeader)) {
log.debug("User from header %s", USER_HEADER);
return userHeader;
}

log.debug("User from basic authentication");
String user = "";
if (authorization == null) {
log.debug("No basic auth header found.");
return user;
}

int space = authorization.indexOf(' ');
if ((space < 0) || !authorization.substring(0, space).equalsIgnoreCase("basic")) {
log.error("Basic auth format is invalid");
return user;
}

String headerInfo = authorization.substring(space + 1).trim();
if (isNullOrEmpty(headerInfo)) {
log.error("Encoded value of basic auth doesn't exist");
return user;
}

String info = new String(Base64.getDecoder().decode(headerInfo), UTF_8);
List<String> parts = Splitter.on(':').limit(2).splitToList(info);
if (parts.size() < 1) {
log.error("No user inside the basic auth text");
return user;
}
return parts.get(0);
}

public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
{
String path = request.getRequestURI();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ private Optional<String> extractUserFromBearerAuth(String header, String userFie

String token = header.substring(space + 1).trim();

if (header.split("\\.").length == 3) { //this is probably a JWS
if (token.split("\\.").length == 3) { //this is probably a JWS
log.debug("Trying to extract from JWS");
try {
DecodedJWT jwt = JWT.decode(header);
DecodedJWT jwt = JWT.decode(token);
if (jwt.getClaims().containsKey(userField)) {
return Optional.of(jwt.getClaim(userField).asString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.ha.router.TrinoRequestUser;
import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
Expand All @@ -43,6 +44,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;

import static com.google.common.collect.ImmutableList.toImmutableList;
Expand All @@ -58,11 +60,8 @@
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.gateway.ha.handler.ProxyUtils.AUTHORIZATION;
import static io.trino.gateway.ha.handler.ProxyUtils.QUERY_TEXT_LENGTH_FOR_HISTORY;
import static io.trino.gateway.ha.handler.ProxyUtils.SOURCE_HEADER;
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY;
import static jakarta.ws.rs.core.Response.Status.OK;
Expand All @@ -86,6 +85,7 @@ public class ProxyRequestHandler
private final boolean addXForwardedHeaders;
private final List<String> statementPaths;
private final boolean includeClusterInfoInResponse;
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;

@Inject
public ProxyRequestHandler(
Expand All @@ -97,6 +97,7 @@ public ProxyRequestHandler(
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.routingManager = requireNonNull(routingManager, "routingManager is null");
this.queryHistoryManager = requireNonNull(queryHistoryManager, "queryHistoryManager is null");
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(haGatewayConfiguration.getRequestAnalyzerConfig());
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout();
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
Expand Down Expand Up @@ -173,7 +174,8 @@ private void performRequest(
FluentFuture<ProxyResponse> future = executeHttp(request);

if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) {
future = future.transform(response -> recordBackendForQueryId(request, response), executor);
Optional<String> username = trinoRequestUserProvider.getInstance(servletRequest).getUser();
future = future.transform(response -> recordBackendForQueryId(request, response, username), executor);
if (includeClusterInfoInResponse) {
cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build());
}
Expand Down Expand Up @@ -250,11 +252,11 @@ private static WebApplicationException badRequest(String message)
.build());
}

private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response)
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response, Optional<String> username)
{
log.debug("For Request [%s] got Response [%s]", request.getUri(), response.body());

QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request, username);

log.debug("Extracting proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getUri());

Expand All @@ -276,12 +278,12 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
return response;
}

public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request)
public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request, Optional<String> username)
{
QueryHistoryManager.QueryDetail queryDetail = new QueryHistoryManager.QueryDetail();
queryDetail.setBackendUrl(getRemoteTarget(request.getUri()));
queryDetail.setCaptureTime(System.currentTimeMillis());
queryDetail.setUser(getQueryUser(request.getHeader(USER_HEADER), request.getHeader(AUTHORIZATION)));
username.ifPresent(queryDetail::setUser);
queryDetail.setSource(request.getHeader(SOURCE_HEADER));

String queryText = new String(((StaticBodyGenerator) request.getBodyGenerator()).getBody(), UTF_8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.List;

import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
import static org.assertj.core.api.Assertions.assertThat;

@TestInstance(Lifecycle.PER_CLASS)
Expand Down Expand Up @@ -61,11 +60,4 @@ void testExtractQueryIdFromUrl()
assertThat(extractQueryIdIfPresent("/ui/", "lang=en&p=1&id=0_1_2_a", statementPaths))
.isNull();
}

@Test
void testGetQueryUser()
{
assertThat(getQueryUser(null, "Basic dGVzdDoxMjPCow==")).isEqualTo("test");
assertThat(getQueryUser("trino_user", "Basic dGVzdDoxMjPCow==")).isEqualTo("trino_user");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,25 @@
*/
package io.trino.gateway.ha.router;

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import io.airlift.json.JsonCodec;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.Test;

import java.time.Instant;
import java.util.Base64;
import java.util.Date;
import java.util.Optional;

import static com.auth0.jwt.algorithms.Algorithm.HMAC256;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

final class TestTrinoRequestUser
{
Expand All @@ -42,4 +55,49 @@ void testJsonCreator()
assertThat(deserializedTrinoRequestUser.getUser()).isEqualTo(trinoRequestUser.getUser());
assertThat(deserializedTrinoRequestUser.getUserInfo()).isEqualTo(trinoRequestUser.getUserInfo());
}

@Test
void testUserFromJwtToken()
{
String claimUserName = "username";
String claimUserValue = "trino";

RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
requestAnalyzerConfig.setTokenUserField(claimUserName);

Algorithm algorithm = HMAC256("random");

Instant expiryTime = Instant.now().plusSeconds(60);
String token = JWT.create()
.withIssuer("gateway")
.withClaim(claimUserName, claimUserValue)
.withExpiresAt(Date.from(expiryTime))
.sign(algorithm);

HttpServletRequest mockRequest = mock(HttpServletRequest.class);
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Bearer " + token);

TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);

assertThat(trinoRequestUser.getUser()).hasValue(claimUserValue);
}

@Test
void testGetBasicAuthUser()
{
String username = "trino_user";
String password = "don't care";
String credentials = username + ":" + password;
String encodedCredentials = Base64.getEncoder().encodeToString(credentials.getBytes(UTF_8));

HttpServletRequest mockRequest = mock(HttpServletRequest.class);
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Basic " + encodedCredentials);

RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);

assertThat(trinoRequestUser.getUser()).hasValue(username);
}
}