Skip to content

Commit d0c2b8a

Browse files
committed
Implement support for JWT-based authentication in the Trino Gateway.
1 parent 8b88412 commit d0c2b8a

File tree

16 files changed

+2475
-8
lines changed

16 files changed

+2475
-8
lines changed

gateway-ha/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@
131131
<artifactId>concurrent</artifactId>
132132
</dependency>
133133

134+
<dependency>
135+
<groupId>io.airlift</groupId>
136+
<artifactId>configuration</artifactId>
137+
</dependency>
138+
134139
<dependency>
135140
<groupId>io.airlift</groupId>
136141
<artifactId>http-client</artifactId>

gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthenticationConfiguration.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ public class AuthenticationConfiguration
1818
private String defaultType;
1919
private OAuthConfiguration oauth;
2020
private FormAuthConfiguration form;
21+
private JwtConfiguration jwt;
2122

22-
public AuthenticationConfiguration(String defaultType, OAuthConfiguration oauth, FormAuthConfiguration form)
23+
public AuthenticationConfiguration(String defaultType, OAuthConfiguration oauth, FormAuthConfiguration form, JwtConfiguration jwt)
2324
{
2425
this.defaultType = defaultType;
2526
this.oauth = oauth;
2627
this.form = form;
28+
this.jwt = jwt;
2729
}
2830

2931
public AuthenticationConfiguration() {}
@@ -57,4 +59,14 @@ public void setForm(FormAuthConfiguration form)
5759
{
5860
this.form = form;
5961
}
62+
63+
public JwtConfiguration getJwt()
64+
{
65+
return this.jwt;
66+
}
67+
68+
public void setJwt(JwtConfiguration jwt)
69+
{
70+
this.jwt = jwt;
71+
}
6072
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.ha.config;
15+
16+
import io.airlift.configuration.validation.FileExists;
17+
18+
import java.io.File;
19+
import java.util.Optional;
20+
21+
public class JwtConfiguration
22+
{
23+
private String keyFile;
24+
private String requiredIssuer;
25+
private String requiredAudience;
26+
private String principalField = "sub";
27+
private Optional<String> userMappingPattern = Optional.empty();
28+
private Optional<File> userMappingFile = Optional.empty();
29+
30+
public JwtConfiguration() {}
31+
32+
public String getKeyFile()
33+
{
34+
return this.keyFile;
35+
}
36+
37+
public void setKeyFile(String keyFile)
38+
{
39+
this.keyFile = keyFile;
40+
}
41+
42+
public String getRequiredIssuer()
43+
{
44+
return this.requiredIssuer;
45+
}
46+
47+
public void setRequiredIssuer(String requiredIssuer)
48+
{
49+
this.requiredIssuer = requiredIssuer;
50+
}
51+
52+
public String getRequiredAudience()
53+
{
54+
return this.requiredAudience;
55+
}
56+
57+
public void setRequiredAudience(String requiredAudience)
58+
{
59+
this.requiredAudience = requiredAudience;
60+
}
61+
62+
public String getPrincipalField()
63+
{
64+
return this.principalField;
65+
}
66+
67+
public void setPrincipalField(String principalField)
68+
{
69+
this.principalField = principalField;
70+
}
71+
72+
public Optional<String> getUserMappingPattern()
73+
{
74+
return this.userMappingPattern;
75+
}
76+
77+
public void setUserMappingPattern(String userMappingPattern)
78+
{
79+
this.userMappingPattern = Optional.ofNullable(userMappingPattern);
80+
}
81+
82+
public Optional<@FileExists File> getUserMappingFile()
83+
{
84+
return this.userMappingFile;
85+
}
86+
87+
public void setUserMappingFile(File userMappingFile)
88+
{
89+
this.userMappingFile = Optional.ofNullable(userMappingFile);
90+
}
91+
}

gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
import io.trino.gateway.ha.security.LbAuthorizer;
6161
import io.trino.gateway.ha.security.LbFilter;
6262
import io.trino.gateway.ha.security.LbFormAuthManager;
63+
import io.trino.gateway.ha.security.LbJwtAuthenticator;
64+
import io.trino.gateway.ha.security.LbJwtManager;
6365
import io.trino.gateway.ha.security.LbOAuthManager;
6466
import io.trino.gateway.ha.security.LbUnauthorizedHandler;
6567
import io.trino.gateway.ha.security.NoopAuthorizer;
@@ -83,6 +85,7 @@ public class HaGatewayProviderModule
8385
{
8486
private final LbOAuthManager oauthManager;
8587
private final LbFormAuthManager formAuthManager;
88+
private final LbJwtManager jwtManager;
8689
private final AuthorizationManager authorizationManager;
8790
private final ResourceSecurityDynamicFeature resourceSecurityDynamicFeature;
8891
private final HaGatewayConfiguration configuration;
@@ -110,6 +113,7 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration)
110113

111114
oauthManager = getOAuthManager(configuration);
112115
formAuthManager = getFormAuthManager(configuration);
116+
jwtManager = getJwtManager(configuration);
113117

114118
authorizationManager = new AuthorizationManager(configuration.getAuthorization(), presetUsers);
115119
resourceSecurityDynamicFeature = getAuthFilter(configuration);
@@ -146,6 +150,15 @@ private LbFormAuthManager getFormAuthManager(HaGatewayConfiguration configuratio
146150
return null;
147151
}
148152

153+
private LbJwtManager getJwtManager(HaGatewayConfiguration configuration)
154+
{
155+
AuthenticationConfiguration authenticationConfiguration = configuration.getAuthentication();
156+
if (authenticationConfiguration != null && authenticationConfiguration.getJwt() != null) {
157+
return new LbJwtManager(authenticationConfiguration.getJwt(), configuration.getPagePermissions());
158+
}
159+
return null;
160+
}
161+
149162
private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration config, Authorizer authorizer)
150163
{
151164
ImmutableList.Builder<ContainerRequestFilter> authFilters = ImmutableList.builder();
@@ -171,6 +184,14 @@ private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration c
171184
new LbUnauthorizedHandler(defaultType)));
172185
}
173186

187+
if (jwtManager != null) {
188+
authFilters.add(new LbFilter(
189+
new LbJwtAuthenticator(jwtManager, authorizationManager),
190+
authorizer,
191+
"Bearer",
192+
new LbUnauthorizedHandler(defaultType)));
193+
}
194+
174195
return new ChainedAuthFilter(authFilters.build());
175196
}
176197

@@ -203,6 +224,13 @@ public LbFormAuthManager getFormAuthentication()
203224
return this.formAuthManager;
204225
}
205226

227+
@Provides
228+
@Singleton
229+
public LbJwtManager getJwtAuthentication()
230+
{
231+
return this.jwtManager;
232+
}
233+
206234
@Provides
207235
@Singleton
208236
public AuthorizationManager getAuthorizationManager()

gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io.trino.gateway.ha.domain.Result;
2020
import io.trino.gateway.ha.domain.request.RestLoginRequest;
2121
import io.trino.gateway.ha.security.LbFormAuthManager;
22+
import io.trino.gateway.ha.security.LbJwtManager;
2223
import io.trino.gateway.ha.security.LbOAuthManager;
2324
import io.trino.gateway.ha.security.LbPrincipal;
2425
import io.trino.gateway.ha.security.OidcCookie;
@@ -34,6 +35,7 @@
3435
import jakarta.ws.rs.WebApplicationException;
3536
import jakarta.ws.rs.core.Context;
3637
import jakarta.ws.rs.core.Cookie;
38+
import jakarta.ws.rs.core.HttpHeaders;
3739
import jakarta.ws.rs.core.MediaType;
3840
import jakarta.ws.rs.core.Response;
3941
import jakarta.ws.rs.core.SecurityContext;
@@ -56,12 +58,14 @@ public class LoginResource
5658

5759
private final LbOAuthManager oauthManager;
5860
private final LbFormAuthManager formAuthManager;
61+
private final LbJwtManager jwtManager;
5962

6063
@Inject
61-
public LoginResource(@Nullable LbOAuthManager oauthManager, @Nullable LbFormAuthManager formAuthManager)
64+
public LoginResource(@Nullable LbOAuthManager oauthManager, @Nullable LbFormAuthManager formAuthManager, @Nullable LbJwtManager jwtManager)
6265
{
6366
this.oauthManager = oauthManager;
6467
this.formAuthManager = formAuthManager;
68+
this.jwtManager = jwtManager;
6569
}
6670

6771
@GET
@@ -174,7 +178,11 @@ else if (oauthManager != null) {
174178
public Response loginType()
175179
{
176180
String loginType;
177-
if (formAuthManager != null) {
181+
// Check for JWT authentication first (highest priority)
182+
if (jwtManager != null) {
183+
loginType = "jwt";
184+
}
185+
else if (formAuthManager != null) {
178186
loginType = "form";
179187
}
180188
else if (oauthManager != null) {
@@ -185,4 +193,18 @@ else if (oauthManager != null) {
185193
}
186194
return Response.ok(Result.ok("Ok", loginType)).build();
187195
}
196+
197+
@POST
198+
@Path("token")
199+
@Consumes(MediaType.APPLICATION_JSON)
200+
@Produces(MediaType.APPLICATION_JSON)
201+
public Response token(@Context HttpHeaders headers)
202+
{
203+
String authHeaderVal = headers.getHeaderString(HttpHeaders.AUTHORIZATION);
204+
String bearerToken = null;
205+
if (authHeaderVal != null && authHeaderVal.startsWith("Bearer ")) {
206+
bearerToken = authHeaderVal.substring("Bearer ".length());
207+
}
208+
return Response.ok(Result.ok("Ok", bearerToken)).build();
209+
}
188210
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.gateway.ha.security;
15+
16+
import com.auth0.jwt.interfaces.Claim;
17+
import io.airlift.log.Logger;
18+
import io.trino.gateway.ha.security.util.AuthenticationException;
19+
import io.trino.gateway.ha.security.util.IdTokenAuthenticator;
20+
21+
import java.util.Map;
22+
import java.util.Optional;
23+
24+
import static io.trino.gateway.ha.security.UserMapping.createUserMapping;
25+
26+
public class LbJwtAuthenticator
27+
implements IdTokenAuthenticator
28+
{
29+
private static final Logger log = Logger.get(LbJwtAuthenticator.class);
30+
31+
private final LbJwtManager jwtManager;
32+
private final AuthorizationManager authorizationManager;
33+
private final UserMapping userMapping;
34+
35+
public LbJwtAuthenticator(LbJwtManager jwtManager,
36+
AuthorizationManager authorizationManager)
37+
{
38+
this.jwtManager = jwtManager;
39+
this.authorizationManager = authorizationManager;
40+
this.userMapping = createUserMapping(jwtManager.getUserMappingPattern(), jwtManager.getUserMappingFile());
41+
}
42+
43+
/**
44+
* If the JWT token is valid and has the right claims, it returns the principal,
45+
* otherwise is returns an empty optional.
46+
*
47+
* @param token JWT token
48+
* @return an optional principal
49+
*/
50+
@Override
51+
public Optional<LbPrincipal> authenticate(String token)
52+
throws AuthenticationException
53+
{
54+
Optional<Map<String, Claim>> claimsOptional = jwtManager.getClaimsFromToken(token);
55+
if (claimsOptional.isEmpty()) {
56+
log.error("JWT token verification failed");
57+
throw new AuthenticationException("JWT token verification failed");
58+
}
59+
60+
Map<String, Claim> claims = claimsOptional.orElseThrow();
61+
62+
String principalField = jwtManager.getPrincipalField();
63+
if (!claims.containsKey(principalField)) {
64+
log.error("Required principal field %s not found in JWT token", principalField);
65+
throw new AuthenticationException("Principal field does not exist in JWT token");
66+
}
67+
68+
String originalUserId = claims.get(principalField).asString();
69+
// Apply user mapping to transform the username if configured
70+
String mappedUserId = userMapping.mapUser(originalUserId);
71+
if (log.isDebugEnabled()) {
72+
log.debug("JWT principal mapping: %s -> %s", originalUserId, mappedUserId);
73+
}
74+
75+
// Get privileges from the authorization manager
76+
Optional<String> privileges = authorizationManager.getPrivileges(mappedUserId);
77+
78+
return Optional.of(new LbPrincipal(mappedUserId, privileges));
79+
}
80+
}

0 commit comments

Comments
 (0)