diff --git a/docs/security.md b/docs/security.md index 04299612b..e0f642ec5 100644 --- a/docs/security.md +++ b/docs/security.md @@ -36,8 +36,8 @@ for more details. The authentication would happen on https protocol only. Add the `authentication:` section in the config file. The default authentication type is -set using `defaultType: "form"` Following types of the authentications are -supported. +set using `defaultTypes: ["form"]`. The first authentication type in `defaultTypes` is prioritized and then falls back to following ones. +Following types of the authentications are supported. ### OAuth/OpenIDConnect @@ -45,7 +45,7 @@ It can be configured as below ```yaml authentication: - defaultType: "oauth" + defaultTypes: ["oauth"] oauth: issuer: clientId: @@ -81,6 +81,11 @@ Set the `privilegesField` to retrieve privileges from an OAuth claim. ``` - That also means you need to have a cluster with that routing group. - It's ok to replicate an existing Trino cluster record with a different name for that purpose. +- If you want to have all users who are authenticated via SSO and are not in the `presetUsers` list be able to view the dashboard and query history, you can set `defaultPrivilege` in the config file: +```yaml +authorization: + defaultPrivilege: "USER" +``` ### Form/Basic authentication @@ -102,7 +107,7 @@ Also provide a random key pair in RSA format. ```yaml authentication: - defaultType: "form" + defaultTypes: ["form"] form: selfSignKeyPair: privateKeyRsa: @@ -115,7 +120,7 @@ LDAP requires both random key pair and config path for LDAP ```yaml authentication: - defaultType: "form" + defaultTypes: ["form"] form: ldapConfigPath: selfSignKeyPair: diff --git a/gateway-ha/pom.xml b/gateway-ha/pom.xml index 9fe50911c..da9f71ba4 100644 --- a/gateway-ha/pom.xml +++ b/gateway-ha/pom.xml @@ -323,6 +323,13 @@ runtime + + com.github.docker-java + docker-java-api + 3.4.2 + test + + com.h2database diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthenticationConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthenticationConfiguration.java index f0f041588..5188805b9 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthenticationConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthenticationConfiguration.java @@ -13,29 +13,31 @@ */ package io.trino.gateway.ha.config; +import java.util.List; + public class AuthenticationConfiguration { - private String defaultType; + private List defaultTypes; private OAuthConfiguration oauth; private FormAuthConfiguration form; - public AuthenticationConfiguration(String defaultType, OAuthConfiguration oauth, FormAuthConfiguration form) + public AuthenticationConfiguration(List defaultTypes, OAuthConfiguration oauth, FormAuthConfiguration form) { - this.defaultType = defaultType; + this.defaultTypes = defaultTypes; this.oauth = oauth; this.form = form; } public AuthenticationConfiguration() {} - public String getDefaultType() + public List getDefaultTypes() { - return this.defaultType; + return this.defaultTypes; } - public void setDefaultType(String defaultType) + public void setDefaultTypes(List defaultTypes) { - this.defaultType = defaultType; + this.defaultTypes = defaultTypes; } public OAuthConfiguration getOauth() diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthorizationConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthorizationConfiguration.java index 2b3e85fc3..955e47ea8 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthorizationConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/AuthorizationConfiguration.java @@ -19,13 +19,15 @@ public class AuthorizationConfiguration private String user; private String api; private String ldapConfigPath; + private String defaultPrivilege; - public AuthorizationConfiguration(String admin, String user, String api, String ldapConfigPath) + public AuthorizationConfiguration(String admin, String user, String api, String ldapConfigPath, String defaultPrivilege) { this.admin = admin; this.user = user; this.api = api; this.ldapConfigPath = ldapConfigPath; + this.defaultPrivilege = defaultPrivilege; } public AuthorizationConfiguration() {} @@ -69,4 +71,14 @@ public void setLdapConfigPath(String ldapConfigPath) { this.ldapConfigPath = ldapConfigPath; } + + public String getDefaultPrivilege() + { + return this.defaultPrivilege; + } + + public void setDefaultPrivilege(String defaultPrivilege) + { + this.defaultPrivilege = defaultPrivilege; + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java index 91594f9f1..d930f9f07 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java @@ -145,28 +145,39 @@ private LbFormAuthManager getFormAuthManager(HaGatewayConfiguration configuratio private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration config, Authorizer authorizer) { ImmutableList.Builder authFilters = ImmutableList.builder(); - String defaultType = config.getDefaultType(); - if (oauthManager != null) { - authFilters.add(new LbFilter( - new LbAuthenticator(oauthManager, authorizationManager), - authorizer, - "Bearer", - new LbUnauthorizedHandler(defaultType))); + List authMethods = config.getDefaultTypes(); + if (authMethods == null || authMethods.isEmpty()) { + return new ChainedAuthFilter(authFilters.build()); } - if (formAuthManager != null) { - authFilters.add(new LbFilter( - new FormAuthenticator(formAuthManager, authorizationManager), - authorizer, - "Bearer", - new LbUnauthorizedHandler(defaultType))); + for (String authMethod : authMethods) { + switch (authMethod) { + case "oauth" -> { + if (oauthManager != null) { + authFilters.add(new LbFilter( + new LbAuthenticator(oauthManager, authorizationManager), + authorizer, + "Bearer", + new LbUnauthorizedHandler("oauth"))); + } + } + case "form" -> { + if (formAuthManager != null) { + authFilters.add(new LbFilter( + new FormAuthenticator(formAuthManager, authorizationManager), + authorizer, + "Bearer", + new LbUnauthorizedHandler("form"))); - authFilters.add(new BasicAuthFilter( - new ApiAuthenticator(formAuthManager, authorizationManager), - authorizer, - new LbUnauthorizedHandler(defaultType))); + authFilters.add(new BasicAuthFilter( + new ApiAuthenticator(formAuthManager, authorizationManager), + authorizer, + new LbUnauthorizedHandler("form"))); + } + } + default -> {} + } } - return new ChainedAuthFilter(authFilters.build()); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java b/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java index a5a89cbd4..7d4a40934 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/resource/LoginResource.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.inject.Inject; import io.airlift.log.Logger; +import io.trino.gateway.ha.config.HaGatewayConfiguration; import io.trino.gateway.ha.domain.Result; import io.trino.gateway.ha.domain.request.RestLoginRequest; import io.trino.gateway.ha.security.LbFormAuthManager; @@ -56,10 +57,12 @@ public class LoginResource private final LbOAuthManager oauthManager; private final LbFormAuthManager formAuthManager; + private final HaGatewayConfiguration haGatewayConfiguration; @Inject - public LoginResource(@Nullable LbOAuthManager oauthManager, @Nullable LbFormAuthManager formAuthManager) + public LoginResource(HaGatewayConfiguration haGatewayConfiguration, @Nullable LbOAuthManager oauthManager, @Nullable LbFormAuthManager formAuthManager) { + this.haGatewayConfiguration = haGatewayConfiguration; this.oauthManager = oauthManager; this.formAuthManager = formAuthManager; } @@ -173,16 +176,13 @@ else if (oauthManager != null) { @Produces(MediaType.APPLICATION_JSON) public Response loginType() { - String loginType; - if (formAuthManager != null) { - loginType = "form"; - } - else if (oauthManager != null) { - loginType = "oauth"; + List loginTypes; + if (haGatewayConfiguration.getAuthentication() != null) { + loginTypes = haGatewayConfiguration.getAuthentication().getDefaultTypes(); } else { - loginType = "none"; + loginTypes = List.of("none"); } - return Response.ok(Result.ok("Ok", loginType)).build(); + return Response.ok(Result.ok("Ok", loginTypes)).build(); } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java index ae89c862a..137f9a26e 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/security/AuthorizationManager.java @@ -24,11 +24,13 @@ public class AuthorizationManager { private final Map presetUsers; private final LbLdapClient lbLdapClient; + private final AuthorizationConfiguration authorizationConfiguration; public AuthorizationManager(AuthorizationConfiguration configuration, Map presetUsers) { this.presetUsers = presetUsers; + this.authorizationConfiguration = configuration; if (configuration != null && configuration.getLdapConfigPath() != null) { lbLdapClient = new LbLdapClient(LdapConfiguration.load(configuration.getLdapConfigPath())); } @@ -49,6 +51,13 @@ public Optional getPrivileges(String username) else if (lbLdapClient != null) { privs = lbLdapClient.getMemberOf(username); } - return Optional.ofNullable(privs); + + if (privs == null || privs.trim().isEmpty()) { + if (authorizationConfiguration != null && authorizationConfiguration.getDefaultPrivilege() != null) { + return Optional.of(authorizationConfiguration.getDefaultPrivilege()); + } + return Optional.empty(); // No default privilege if not configured + } + return Optional.of(privs); } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java b/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java index 051644436..2bef86bcf 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java @@ -18,6 +18,8 @@ import io.airlift.log.Logger; import io.trino.gateway.ha.clustermonitor.ClusterStats; import io.trino.gateway.ha.clustermonitor.TrinoStatus; +import okhttp3.CookieJar; +import okhttp3.JavaNetCookieJar; import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; @@ -28,16 +30,34 @@ import okhttp3.mockwebserver.MockWebServer; import org.jdbi.v3.core.Handle; import org.jdbi.v3.core.Jdbi; +import org.testcontainers.containers.FixedHostPortGenericContainer; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.containers.Network; import org.testcontainers.containers.OracleContainer; import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.containers.TrinoContainer; +import org.testcontainers.containers.startupcheck.OneShotStartupCheckStrategy; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.containers.wait.strategy.WaitAllStrategy; +import org.testcontainers.utility.DockerImageName; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; import java.io.BufferedWriter; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.net.CookieManager; +import java.net.CookiePolicy; import java.nio.file.Files; import java.nio.file.Path; +import java.security.SecureRandom; +import java.security.cert.X509Certificate; import java.util.Map; +import java.util.Optional; import java.util.Scanner; import java.util.concurrent.TimeUnit; @@ -46,17 +66,28 @@ import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8; import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.trino.gateway.ha.util.ConfigurationUtils.replaceEnvironmentVariables; +import static io.trino.gateway.ha.util.TestcontainersUtils.createPostgreSqlContainer; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.testcontainers.utility.MountableFile.forClasspathResource; public class HaGatewayTestUtils { private static final Logger log = Logger.get(HaGatewayTestUtils.class); private static final OkHttpClient httpClient = new OkHttpClient(); + // TODO: TRINO-7103 Set up TMC job to update Trino image version in trino-gateway + private static final String TRINO_IMAGE_VERSION = "406.186.12"; + + private static final String HYDRA_IMAGE = "oryd/hydra:v1.11.10"; + private static final String DSN = "mysql://hydra:mysecretpassword@tcp(hydra-db:3306)/hydra?parseTime=true"; + + private static final int TTL_ACCESS_TOKEN_IN_SECONDS = 5; + private static final int TTL_REFRESH_TOKEN_IN_SECONDS = 15; + private HaGatewayTestUtils() {} public static void prepareMockBackend( @@ -201,4 +232,151 @@ private static void verifyTrinoStatus(int port, String name) } throw new IllegalStateException("Trino cluster is not healthy"); } + + public static TrinoContainer getTrinoContainer() + { + TrinoContainer trino = new TrinoContainer(DockerImageName.parse("trinodb/trino:" + TRINO_IMAGE_VERSION)); + trino.withCopyFileToContainer(forClasspathResource("trino_container_init.sh"), "/"); + trino.withCreateContainerCmdModifier(cmd -> cmd.withEntrypoint("bash", "-c", "./trino_container_init.sh")); + return trino; + } + + public static void setupOidc(int routerPort, String configFile, String scopes) + throws Exception + { + PostgreSQLContainer postgresql = createPostgreSqlContainer(); + postgresql.start(); + setupOidcContainers(routerPort, scopes); + + Path localhostJks = Path.of("src", "test", "resources", "auth", "localhost.jks").toAbsolutePath(); + Map additionalVars = ImmutableMap.builder() + .put("REQUEST_ROUTER_PORT", String.valueOf(routerPort)) + .put("LOCALHOST_JKS", localhostJks.toString()) + .putAll(buildPostgresVars(postgresql)) + .buildOrThrow(); + File testConfigFile = HaGatewayTestUtils.buildGatewayConfig(configFile, additionalVars); + String[] args = {testConfigFile.getAbsolutePath()}; + HaGatewayLauncher.main(args); + } + + public static void setupOidcContainers(int routerPort, String scopes) + { + Network network = Network.newNetwork(); + MySQLContainer databaseContainer = new MySQLContainer<>("mysql:8.0.36") + .withNetwork(network) + .withNetworkAliases("hydra-db") + .withUsername("hydra") + .withPassword("mysecretpassword") + .withDatabaseName("hydra") + .waitingFor(Wait.forLogMessage(".*ready to accept connections.*", 1)); + databaseContainer.start(); + + GenericContainer migrationContainer = new GenericContainer<>(HYDRA_IMAGE) + .withNetwork(network) + .withCommand("migrate", "sql", "--yes", DSN) + .dependsOn(databaseContainer) + .withStartupCheckStrategy(new OneShotStartupCheckStrategy()); + migrationContainer.start(); + + FixedHostPortGenericContainer hydraConsent = new FixedHostPortGenericContainer<>("python:3.10.1-alpine") + .withFixedExposedPort(3000, 3000) + .withNetwork(network) + .withNetworkAliases("hydra-consent") + .withExposedPorts(3000) + .withCopyFileToContainer(forClasspathResource("auth/login_and_consent_server.py"), "/") + .withCommand("python", "/login_and_consent_server.py") + .waitingFor(Wait.forHttp("/healthz").forPort(3000).forStatusCode(200)); + hydraConsent.start(); + + FixedHostPortGenericContainer hydra = new FixedHostPortGenericContainer<>(HYDRA_IMAGE) + .withFixedExposedPort(4444, 4444) + .withFixedExposedPort(4445, 4445) + .withNetwork(network) + .withNetworkAliases("hydra") + .withEnv("LOG_LEVEL", "debug") + .withEnv("LOG_LEAK_SENSITIVE_VALUES", "true") + .withEnv("OAUTH2_EXPOSE_INTERNAL_ERRORS", "1") + .withEnv("GODEBUG", "http2debug=1") + .withEnv("DSN", DSN) + .withEnv("URLS_SELF_ISSUER", "http://localhost:4444/") + .withEnv("URLS_CONSENT", "http://localhost:3000/consent") + .withEnv("URLS_LOGIN", "http://localhost:3000/login") + .withEnv("STRATEGIES_ACCESS_TOKEN", "jwt") + .withEnv("TTL_ACCESS_TOKEN", TTL_ACCESS_TOKEN_IN_SECONDS + "s") + .withEnv("TTL_REFRESH_TOKEN", TTL_REFRESH_TOKEN_IN_SECONDS + "s") + .withEnv("OAUTH2_ALLOWED_TOP_LEVEL_CLAIMS", "groups") + .withCommand("serve", "all", "--dangerous-force-http") + .dependsOn(hydraConsent, migrationContainer) + .waitingFor(new WaitAllStrategy() + .withStrategy(Wait.forLogMessage(".*Setting up http server on :4444.*", 1)) + .withStrategy(Wait.forLogMessage(".*Setting up http server on :4445.*", 1))) + .withStartupTimeout(java.time.Duration.ofMinutes(3)); + hydra.start(); + + String clientId = "trino_client_id"; + String clientSecret = "trino_client_secret"; + String tokenEndpointAuthMethod = "client_secret_basic"; + String audience = "trino_client_id"; + String callbackUrl = format("https://localhost:%s/oidc/callback", routerPort); + GenericContainer clientCreatingContainer = new GenericContainer<>(HYDRA_IMAGE) + .withNetwork(network) + .dependsOn(hydra) + .withCommand("clients", "create", + "--endpoint", "http://hydra:4445", + "--skip-tls-verify", + "--id", clientId, + "--secret", clientSecret, + "--audience", audience, + "-g", "authorization_code,refresh_token,client_credentials", + "-r", "token,code,id_token", + "--scope", scopes, + "--token-endpoint-auth-method", tokenEndpointAuthMethod, + "--callbacks", callbackUrl); + clientCreatingContainer.start(); + } + + public static OkHttpClient createOkHttpClient(Optional cookieJar) + throws Exception + { + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder() + .followRedirects(true) + .cookieJar(cookieJar.orElseGet(() -> { + CookieManager cookieManager = new CookieManager(); + cookieManager.setCookiePolicy(CookiePolicy.ACCEPT_ALL); + return new JavaNetCookieJar(cookieManager); + })); + setupInsecureSsl(httpClientBuilder); + return httpClientBuilder.build(); + } + + private static void setupInsecureSsl(OkHttpClient.Builder clientBuilder) + throws Exception + { + X509TrustManager trustAllCerts = new X509TrustManager() + { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + { + throw new UnsupportedOperationException("checkClientTrusted should not be called"); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + { + // skip validation of server certificate + } + + @Override + public X509Certificate[] getAcceptedIssuers() + { + return new X509Certificate[0]; + } + }; + + SSLContext sslContext = SSLContext.getInstance("SSL"); + sslContext.init(null, new TrustManager[] {trustAllCerts}, new SecureRandom()); + + clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustAllCerts); + clientBuilder.hostnameVerifier((hostname, session) -> true); + } } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestAuthenticationFallbacks.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestAuthenticationFallbacks.java new file mode 100644 index 000000000..bc3cfa702 --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestAuthenticationFallbacks.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.security; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.gateway.ha.HaGatewayTestUtils; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.Optional; + +import static io.trino.gateway.ha.HaGatewayTestUtils.createOkHttpClient; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class TestAuthenticationFallbacks +{ + private final int routerPort = 21001 + (int) (Math.random() * 1000); + + @BeforeAll + void setup() + throws Exception + { + HaGatewayTestUtils.setupOidc(routerPort, "auth/oauth-and-form-test-config.yml", "openid"); + } + + @Test + void testPrimaryAuth() + throws Exception + { + OkHttpClient httpClient = createOkHttpClient(Optional.empty()); + try (Response response = httpClient.newCall(uiLoginTypeCall().build()).execute()) { + String body = response.body().string(); + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode jsonNode = objectMapper.readTree(body); + JsonNode dataNode = jsonNode.get("data"); + assertThat(dataNode.isArray()).isTrue(); + assertThat(dataNode.size()).isEqualTo(2); + assertThat(dataNode.get(0).asText()).isEqualTo("oauth"); + assertThat(dataNode.get(1).asText()).isEqualTo("form"); + } + } + + @Test + void testApiFallbackAuth() + throws Exception + { + OkHttpClient httpClient = createOkHttpClient(Optional.empty()); + Request request = new Request.Builder() + .url(format("https://localhost:%s/webapp/getAllBackends", routerPort)) + .post(RequestBody.create("{}", MediaType.parse("application/json"))) + .addHeader("Authorization", "Basic YWRtaW4xOmFkbWluMV9wYXNzd29yZA==") // admin1:admin1_password + .build(); + + try (Response response = httpClient.newCall(request).execute()) { + assertThat(response.isSuccessful()).isTrue(); + String body = response.body().string(); + assertThat(body).contains("Successful."); + } + } + + private Request.Builder uiLoginTypeCall() + { + return new Request.Builder() + .url(format("https://localhost:%s/loginType", routerPort)) + .post(RequestBody.create("{}", MediaType.parse("application/json"))); + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestAuthorizationManager.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestAuthorizationManager.java new file mode 100644 index 000000000..3fd494cf3 --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestAuthorizationManager.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.security; + +import io.trino.gateway.ha.config.AuthorizationConfiguration; +import io.trino.gateway.ha.config.UserConfiguration; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; + +public class TestAuthorizationManager +{ + @Test + public void testNoDefaultPrivilege() + { + AuthorizationConfiguration authConfig = new AuthorizationConfiguration(); + Map presetUsers = Collections.emptyMap(); + AuthorizationManager authManager = new AuthorizationManager(authConfig, presetUsers); + + Optional privileges = authManager.getPrivileges("newUser"); + assertThat(privileges).isEmpty(); + } + + @Test + public void testCustomDefaultPrivilege() + { + AuthorizationConfiguration authConfig = new AuthorizationConfiguration(); + authConfig.setDefaultPrivilege("CUSTOM_DEFAULT"); + Map presetUsers = Collections.emptyMap(); + AuthorizationManager authManager = new AuthorizationManager(authConfig, presetUsers); + + Optional privileges = authManager.getPrivileges("newUser"); + assertThat(privileges).isPresent(); + assertThat(privileges.orElseThrow()).isEqualTo("CUSTOM_DEFAULT"); + } + + @Test + public void testOauthDefaultUserRoleEnabled() + { + AuthorizationConfiguration authConfig = new AuthorizationConfiguration(); + authConfig.setDefaultPrivilege("USER"); + Map presetUsers = Collections.emptyMap(); + AuthorizationManager authManager = new AuthorizationManager(authConfig, presetUsers); + + Optional privileges = authManager.getPrivileges("newUser"); + assertThat(privileges).isPresent(); + assertThat(privileges.orElseThrow()).isEqualTo("USER"); + } + + @Test + public void testOauthDefaultUserRoleDisabled() + { + AuthorizationConfiguration authConfig = new AuthorizationConfiguration(); + Map presetUsers = Collections.emptyMap(); + AuthorizationManager authManager = new AuthorizationManager(authConfig, presetUsers); + + Optional privileges = authManager.getPrivileges("newUser"); + assertThat(privileges).isNotPresent(); + } + + @Test + public void testPresetUserRole() + { + AuthorizationConfiguration authConfig = new AuthorizationConfiguration(); + UserConfiguration presetUser = new UserConfiguration("ADMIN", "password"); + Map presetUsers = Map.of("adminUser", presetUser); + AuthorizationManager authManager = new AuthorizationManager(authConfig, presetUsers); + + Optional privileges = authManager.getPrivileges("adminUser"); + assertThat(privileges).isPresent(); + assertThat(privileges.orElseThrow()).isEqualTo("ADMIN"); + } + + @Test + public void testPresetUserWithEmptyPrivileges() + { + AuthorizationConfiguration authConfig = new AuthorizationConfiguration(); + UserConfiguration presetUser = new UserConfiguration("", "password"); // Empty privileges + Map presetUsers = Map.of("emptyPrivUser", presetUser); + AuthorizationManager authManager = new AuthorizationManager(authConfig, presetUsers); + + Optional privileges = authManager.getPrivileges("emptyPrivUser"); + assertThat(privileges).isEmpty(); + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java index 2ba71405e..ef965df92 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/security/TestOIDC.java @@ -16,8 +16,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableMap; -import io.trino.gateway.ha.HaGatewayLauncher; import io.trino.gateway.ha.HaGatewayTestUtils; import okhttp3.Cookie; import okhttp3.CookieJar; @@ -30,135 +28,27 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; -import org.testcontainers.containers.FixedHostPortGenericContainer; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.Network; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.containers.startupcheck.OneShotStartupCheckStrategy; -import org.testcontainers.containers.wait.strategy.Wait; -import org.testcontainers.containers.wait.strategy.WaitAllStrategy; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManager; -import javax.net.ssl.X509TrustManager; - -import java.io.File; import java.net.CookieManager; import java.net.CookiePolicy; -import java.net.URL; -import java.nio.file.Path; -import java.security.SecureRandom; -import java.security.cert.X509Certificate; import java.util.List; -import java.util.Map; import java.util.Optional; -import static io.trino.gateway.ha.HaGatewayTestUtils.buildPostgresVars; +import static io.trino.gateway.ha.HaGatewayTestUtils.createOkHttpClient; import static io.trino.gateway.ha.security.OidcCookie.OIDC_COOKIE; -import static io.trino.gateway.ha.util.TestcontainersUtils.createPostgreSqlContainer; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; -import static org.testcontainers.utility.MountableFile.forClasspathResource; @TestInstance(TestInstance.Lifecycle.PER_CLASS) final class TestOIDC { - private static final int TTL_ACCESS_TOKEN_IN_SECONDS = 5; - private static final int TTL_REFRESH_TOKEN_IN_SECONDS = 15; - - private static final String HYDRA_IMAGE = "oryd/hydra:v1.11.10"; - private static final String DSN = "postgres://hydra:mysecretpassword@hydra-db:5432/hydra?sslmode=disable"; private static final int ROUTER_PORT = 21001 + (int) (Math.random() * 1000); @BeforeAll void setup() throws Exception { - Network network = Network.newNetwork(); - - PostgreSQLContainer databaseContainer = createPostgreSqlContainer() - .withNetwork(network) - .withNetworkAliases("hydra-db") - .withUsername("hydra") - .withPassword("mysecretpassword") - .withDatabaseName("hydra"); - databaseContainer.start(); - - GenericContainer migrationContainer = new GenericContainer(HYDRA_IMAGE) - .withNetwork(network) - .withCommand("migrate", "sql", "--yes", DSN) - .dependsOn(databaseContainer) - .withStartupCheckStrategy(new OneShotStartupCheckStrategy()); - migrationContainer.start(); - - FixedHostPortGenericContainer hydraConsent = new FixedHostPortGenericContainer<>("python:3.10.1-alpine") - .withFixedExposedPort(3000, 3000) - .withNetwork(network) - .withNetworkAliases("hydra-consent") - .withExposedPorts(3000) - .withCopyFileToContainer(forClasspathResource("auth/login_and_consent_server.py"), "/") - .withCommand("python", "/login_and_consent_server.py") - .waitingFor(Wait.forHttp("/healthz").forPort(3000).forStatusCode(200)); - hydraConsent.start(); - - FixedHostPortGenericContainer hydra = new FixedHostPortGenericContainer<>(HYDRA_IMAGE) - .withFixedExposedPort(4444, 4444) - .withFixedExposedPort(4445, 4445) - .withNetwork(network) - .withNetworkAliases("hydra") - .withEnv("LOG_LEVEL", "debug") - .withEnv("LOG_LEAK_SENSITIVE_VALUES", "true") - .withEnv("OAUTH2_EXPOSE_INTERNAL_ERRORS", "1") - .withEnv("GODEBUG", "http2debug=1") - .withEnv("DSN", DSN) - .withEnv("URLS_SELF_ISSUER", "http://localhost:4444/") - .withEnv("URLS_CONSENT", "http://localhost:3000/consent") - .withEnv("URLS_LOGIN", "http://localhost:3000/login") - .withEnv("STRATEGIES_ACCESS_TOKEN", "jwt") - .withEnv("TTL_ACCESS_TOKEN", TTL_ACCESS_TOKEN_IN_SECONDS + "s") - .withEnv("TTL_REFRESH_TOKEN", TTL_REFRESH_TOKEN_IN_SECONDS + "s") - .withEnv("OAUTH2_ALLOWED_TOP_LEVEL_CLAIMS", "groups") - .withCommand("serve", "all", "--dangerous-force-http") - .dependsOn(hydraConsent, migrationContainer) - .waitingFor(new WaitAllStrategy() - .withStrategy(Wait.forLogMessage(".*Setting up http server on :4444.*", 1)) - .withStrategy(Wait.forLogMessage(".*Setting up http server on :4445.*", 1))) - .withStartupTimeout(java.time.Duration.ofMinutes(3)); - - String clientId = "trino_client_id"; - String clientSecret = "trino_client_secret"; - String tokenEndpointAuthMethod = "client_secret_basic"; - String audience = "trino_client_id"; - String callbackUrl = format("https://localhost:%s/oidc/callback", ROUTER_PORT); - GenericContainer clientCreatingContainer = new GenericContainer(HYDRA_IMAGE) - .withNetwork(network) - .dependsOn(hydra) - .withCommand("clients", "create", - "--endpoint", "http://hydra:4445", - "--skip-tls-verify", - "--id", clientId, - "--secret", clientSecret, - "--audience", audience, - "-g", "authorization_code,refresh_token,client_credentials", - "-r", "token,code,id_token", - "--scope", "openid,offline", - "--token-endpoint-auth-method", tokenEndpointAuthMethod, - "--callbacks", callbackUrl); - clientCreatingContainer.start(); - - PostgreSQLContainer gatewayBackendDatabase = createPostgreSqlContainer(); - gatewayBackendDatabase.start(); - - URL resource = HaGatewayTestUtils.class.getClassLoader().getResource("auth/localhost.jks"); - Map additionalVars = ImmutableMap.builder() - .put("REQUEST_ROUTER_PORT", String.valueOf(ROUTER_PORT)) - .put("LOCALHOST_JKS", Path.of(resource.toURI()).toString()) - .putAll(buildPostgresVars(gatewayBackendDatabase)) - .buildOrThrow(); - File testConfigFile = - HaGatewayTestUtils.buildGatewayConfig("auth/oauth-test-config.yml", additionalVars); - String[] args = {testConfigFile.getAbsolutePath()}; - HaGatewayLauncher.main(args); + HaGatewayTestUtils.setupOidc(ROUTER_PORT, "auth/oauth-test-config.yml", "openid,offline"); } @Test @@ -216,37 +106,6 @@ private Request.Builder uiCall() .post(RequestBody.create("", null)); } - public static void setupInsecureSsl(OkHttpClient.Builder clientBuilder) - throws Exception - { - X509TrustManager trustAllCerts = new X509TrustManager() - { - @Override - public void checkClientTrusted(X509Certificate[] chain, String authType) - { - throw new UnsupportedOperationException("checkClientTrusted should not be called"); - } - - @Override - public void checkServerTrusted(X509Certificate[] chain, String authType) - { - // skip validation of server certificate - } - - @Override - public X509Certificate[] getAcceptedIssuers() - { - return new X509Certificate[0]; - } - }; - - SSLContext sslContext = SSLContext.getInstance("SSL"); - sslContext.init(null, new TrustManager[] {trustAllCerts}, new SecureRandom()); - - clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustAllCerts); - clientBuilder.hostnameVerifier((hostname, session) -> true); - } - public static class BadCookieJar implements CookieJar { @@ -289,18 +148,4 @@ private static String extractRedirectURL(String body) JsonNode jsonNode = objectMapper.readTree(body); return jsonNode.get("data").asText(); } - - private static OkHttpClient createOkHttpClient(Optional cookieJar) - throws Exception - { - OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder() - .followRedirects(true) - .cookieJar(cookieJar.orElseGet(() -> { - CookieManager cookieManager = new CookieManager(); - cookieManager.setCookiePolicy(CookiePolicy.ACCEPT_ALL); - return new JavaNetCookieJar(cookieManager); - })); - setupInsecureSsl(httpClientBuilder); - return httpClientBuilder.build(); - } } diff --git a/gateway-ha/src/test/resources/auth/auth-test-config.yml b/gateway-ha/src/test/resources/auth/auth-test-config.yml index f1ca251d5..a57aa1918 100644 --- a/gateway-ha/src/test/resources/auth/auth-test-config.yml +++ b/gateway-ha/src/test/resources/auth/auth-test-config.yml @@ -25,7 +25,7 @@ presetUsers: privileges: BAR_BAZ authentication: - defaultType: "form" + defaultTypes: ["form"] form: selfSignKeyPair: privateKeyRsa: src/test/resources/auth/test_private_key.pem diff --git a/gateway-ha/src/test/resources/auth/oauth-and-form-test-config.yml b/gateway-ha/src/test/resources/auth/oauth-and-form-test-config.yml new file mode 100644 index 000000000..b386c996c --- /dev/null +++ b/gateway-ha/src/test/resources/auth/oauth-and-form-test-config.yml @@ -0,0 +1,47 @@ +serverConfig: + node.environment: test + http-server.http.port: 8080 + http-server.https.enabled: true + http-server.https.port: ${ENV:REQUEST_ROUTER_PORT} + http-server.https.keystore.path: ${ENV:LOCALHOST_JKS} + http-server.https.keystore.key: 123456 + +dataStore: + jdbcUrl: ${ENV:POSTGRESQL_JDBC_URL} + user: ${ENV:POSTGRESQL_USER} + password: ${ENV:POSTGRESQL_PASSWORD} + driver: org.postgresql.Driver + +extraWhitelistPaths: + - '/v1/custom.*' + +authorization: + admin: .*FOO.* + user: .*BAR.* + api: .*BAZ.* + +presetUsers: + foo@bar.com: + privileges: FOO_BAR + admin1: + password: admin1_password + privileges: FOO_BAR + +authentication: + defaultTypes: ["oauth", "form"] + oauth: + issuer: http://localhost:4444/ + clientId: trino_client_id + clientSecret: trino_client_secret + tokenEndpoint: http://localhost:4444/oauth2/token + authorizationEndpoint: http://localhost:4444/oauth2/auth + jwkEndpoint: http://localhost:4444/.well-known/jwks.json + redirectUrl: https://localhost:${ENV:REQUEST_ROUTER_PORT}/oidc/callback + redirectWebUrl: https://localhost:${ENV:REQUEST_ROUTER_PORT}/ + userIdField: sub + scopes: + - openid + form: + selfSignKeyPair: + privateKeyRsa: src/test/resources/auth/test_private_key.pem + publicKeyRsa: src/test/resources/auth/test_public_key.pem diff --git a/gateway-ha/src/test/resources/auth/oauth-test-config.yml b/gateway-ha/src/test/resources/auth/oauth-test-config.yml index d25adc88f..ddaa691f7 100644 --- a/gateway-ha/src/test/resources/auth/oauth-test-config.yml +++ b/gateway-ha/src/test/resources/auth/oauth-test-config.yml @@ -25,7 +25,7 @@ presetUsers: privileges: FOO_BAR authentication: - defaultType: "oauth" + defaultTypes: ["oauth"] oauth: issuer: http://localhost:4444/ clientId: trino_client_id diff --git a/webapp/src/components/login.tsx b/webapp/src/components/login.tsx index 3499ebca7..4a4101f91 100644 --- a/webapp/src/components/login.tsx +++ b/webapp/src/components/login.tsx @@ -10,11 +10,13 @@ export function Login() { const access = useAccessStore(); const [formApi, setFormApi] = useState>(); const [loginBo, setLoginBo] = useState>({}); - const [loginType, setLoginType] = useState<'form' | 'oauth' | 'none'>(); + const [loginType, setLoginType] = useState(); useEffect(() => { loginTypeApi().then(data => { - setLoginType(data); + if (data && data.length > 0) { + setLoginType(data[0]); + } }).catch(() => { }); }, [])