From b43765d8ae9aae1c4317780bb539e8134de50ebc Mon Sep 17 00:00:00 2001 From: Krzysztof Sobolewski Date: Tue, 11 Jul 2023 16:48:28 +0200 Subject: [PATCH] Add caching group provider and enable for LDAP group provider This Guice Module can be used to enable caching in the group provider, by adding it to the list of modules in a Guice context in a group provider factory, or to any other Guice context as needed. Features: * Configurable configuration prefix * Ability to bind the final `GroupProvider` with a custom binding annotation * useful especially when the Guice context is not entirely isolated and there are other `GroupProvider` bindings in it * An `@Inject`-able hook for cache invalidation Author: Krzysztof Sobolewski Date: Tue Jul 11 16:48:28 2023 +0200 --- .../base/group/CachingGroupProvider.java | 61 ++++ .../group/CachingGroupProviderConfig.java | 55 ++++ .../group/CachingGroupProviderModule.java | 183 +++++++++++ .../GroupCacheInvalidationController.java | 21 ++ .../base/group/GroupProviderConfig.java | 35 ++ .../NoOpGroupCacheInvalidationController.java | 24 ++ .../base/group/TestCachingGroupProvider.java | 305 ++++++++++++++++++ .../group/TestCachingGroupProviderConfig.java | 50 +++ .../base/group/TestGroupProviderConfig.java | 44 +++ plugin/trino-ldap-group-provider/pom.xml | 8 +- .../ldapgroup/LdapFilteringGroupProvider.java | 4 +- .../ldapgroup/LdapGroupProviderFactory.java | 2 + .../LdapSingleQueryGroupProvider.java | 2 +- .../TestLdapGroupProviderIntegration.java | 18 +- 14 files changed, 797 insertions(+), 15 deletions(-) create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProvider.java create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderConfig.java create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderModule.java create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupCacheInvalidationController.java create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupProviderConfig.java create mode 100644 lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/NoOpGroupCacheInvalidationController.java create mode 100644 lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProvider.java create mode 100644 lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProviderConfig.java create mode 100644 lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestGroupProviderConfig.java diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProvider.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProvider.java new file mode 100644 index 000000000000..ea491fee1351 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProvider.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.inject.Inject; +import io.trino.cache.EvictableCacheBuilder; +import io.trino.plugin.base.group.CachingGroupProviderModule.ForCachingGroupProvider; +import io.trino.spi.security.GroupProvider; + +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +public class CachingGroupProvider + implements GroupProvider, GroupCacheInvalidationController +{ + private final LoadingCache> cache; + + @Inject + public CachingGroupProvider(CachingGroupProviderConfig config, @ForCachingGroupProvider GroupProvider delegate) + { + requireNonNull(delegate, "delegate is null"); + this.cache = EvictableCacheBuilder.newBuilder() + .maximumSize(config.getCacheMaximumSize()) + .expireAfterWrite(config.getTtl().toMillis(), MILLISECONDS) + .shareNothingWhenDisabled() + .build(CacheLoader.from(delegate::getGroups)); + } + + @Override + public Set getGroups(String user) + { + return cache.getUnchecked(user); + } + + @Override + public void invalidate(String user) + { + cache.invalidate(user); + } + + @Override + public void invalidateAll() + { + cache.invalidateAll(); + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderConfig.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderConfig.java new file mode 100644 index 000000000000..f07e22d0840f --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderConfig.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.units.Duration; +import jakarta.validation.constraints.Min; + +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class CachingGroupProviderConfig +{ + private Duration ttl = new Duration(5, SECONDS); + private long cacheMaximumSize = Long.MAX_VALUE; + + public Duration getTtl() + { + return ttl; + } + + @Config("cache.ttl") + @ConfigDescription("Determines how long group information will be cached for each user") + public CachingGroupProviderConfig setTtl(Duration ttl) + { + this.ttl = requireNonNull(ttl, "ttl is null"); + return this; + } + + @Min(1) + public long getCacheMaximumSize() + { + return cacheMaximumSize; + } + + @Config("cache.maximum-size") + @ConfigDescription("Maximum number of users for which groups are stored in the cache") + public CachingGroupProviderConfig setCacheMaximumSize(long cacheMaximumSize) + { + this.cacheMaximumSize = cacheMaximumSize; + return this; + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderModule.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderModule.java new file mode 100644 index 000000000000..542207ff7ad1 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderModule.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.inject.Binder; +import com.google.inject.BindingAnnotation; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Scopes; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.spi.security.GroupProvider; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.Optional; + +import static io.airlift.configuration.ConditionalModule.conditionalModule; +import static io.airlift.configuration.ConfigBinder.configBinder; +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; +import static java.util.Objects.requireNonNull; + +/** + * If added to the list of {@link com.google.inject.Module}s used in initialization of a Guice context in a + * {@link io.trino.spi.security.GroupProviderFactory}, it will (almost) automatically add caching capability to the + * group provider. Requirements: + *
    + *
  • The {@link GroupProvider} available in the Guice context must be bound annotated with + * {@link ForCachingGroupProvider} binding annotation
  • + *
+ * The module will make the following configuration options available (to be set in {@code etc/group-provider.properties}: + *
    + *
  • {@code cache.enabled} - the toggle to enable or disable caching
  • + *
  • {@code cache.ttl} - determines how long group information will be cached for each user
  • + *
  • {@code cache.maximum-size} - maximum number of users for which groups are stored in the cache
  • + *
+ * These properties can optionally have an arbitrary prefix ({@link Builder#withPrefix(String)}) + * and/or a binding annotation for the resulting binding of {@link GroupProvider} ({@link Builder#withBindingAnnotation(Class)}). + *

+ * An additional object of type {@link GroupCacheInvalidationController} will also be bound, with which one can invalidate + * all or part of the cache. + */ +public class CachingGroupProviderModule + extends AbstractConfigurationAwareModule +{ + private final Optional prefix; + private final Optional> bindingAnnotation; + + private CachingGroupProviderModule(Optional prefix, Optional> bindingAnnotation) + { + this.prefix = requireNonNull(prefix, "prefix is null"); + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + protected void setup(Binder binder) + { + configBinder(binder).bindConfig(GroupProviderConfig.class, prefix.orElse(null)); + prefix.ifPresentOrElse( + prefix -> install(conditionalModule( + GroupProviderConfig.class, + prefix, + GroupProviderConfig::isCachingEnabled, + new CacheModule(Optional.of(prefix), bindingAnnotation), + new NonCacheModule(bindingAnnotation))), + () -> install(conditionalModule( + GroupProviderConfig.class, + GroupProviderConfig::isCachingEnabled, + new CacheModule(Optional.empty(), bindingAnnotation), + new NonCacheModule(bindingAnnotation)))); + } + + private static class CacheModule + implements Module + { + private final Optional prefix; + private final Optional> bindingAnnotation; + + public CacheModule(Optional prefix, Optional> bindingAnnotation) + { + this.prefix = requireNonNull(prefix, "prefix is null"); + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(CachingGroupProviderConfig.class, prefix.orElse(null)); + binder.bind(CachingGroupProvider.class).in(Scopes.SINGLETON); + binder.bind(bindingAnnotation + .map(bindingAnnotation -> Key.get(GroupProvider.class, bindingAnnotation)) + .orElseGet(() -> Key.get(GroupProvider.class))) + .to(CachingGroupProvider.class) + .in(Scopes.SINGLETON); + binder.bind(GroupCacheInvalidationController.class) + .to(CachingGroupProvider.class) + .in(Scopes.SINGLETON); + } + } + + private static class NonCacheModule + implements Module + { + private final Optional> bindingAnnotation; + + public NonCacheModule(Optional> bindingAnnotation) + { + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + public void configure(Binder binder) + { + binder.bind(bindingAnnotation + .map(bindingAnnotation -> Key.get(GroupProvider.class, bindingAnnotation)) + .orElseGet(() -> Key.get(GroupProvider.class))) + .to(Key.get(GroupProvider.class, ForCachingGroupProvider.class)) + .in(Scopes.SINGLETON); + binder.bind(GroupCacheInvalidationController.class) + .to(NoOpGroupCacheInvalidationController.class) + .in(Scopes.SINGLETON); + } + } + + @Retention(RUNTIME) + @Target({FIELD, PARAMETER, METHOD}) + @BindingAnnotation + public @interface ForCachingGroupProvider + { + } + + public static CachingGroupProviderModule create() + { + return builder().build(); + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Optional prefix = Optional.empty(); + private Optional> bindingAnnotation = Optional.empty(); + + private Builder() {} + + @CanIgnoreReturnValue + public Builder withPrefix(String prefix) + { + this.prefix = Optional.of(prefix); + return this; + } + + @CanIgnoreReturnValue + public Builder withBindingAnnotation(Class bindingAnnotation) + { + this.bindingAnnotation = Optional.of(bindingAnnotation); + return this; + } + + public CachingGroupProviderModule build() + { + return new CachingGroupProviderModule(prefix, bindingAnnotation); + } + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupCacheInvalidationController.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupCacheInvalidationController.java new file mode 100644 index 000000000000..886bbad1861b --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupCacheInvalidationController.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +public interface GroupCacheInvalidationController +{ + void invalidate(String user); + + void invalidateAll(); +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupProviderConfig.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupProviderConfig.java new file mode 100644 index 000000000000..5a5569f6f4a3 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupProviderConfig.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; + +public class GroupProviderConfig +{ + private boolean isCachingEnabled; + + public boolean isCachingEnabled() + { + return isCachingEnabled; + } + + @Config("cache.enabled") + @ConfigDescription("Enables caching for the group provider") + public GroupProviderConfig setCachingEnabled(boolean isCachingEnabled) + { + this.isCachingEnabled = isCachingEnabled; + return this; + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/NoOpGroupCacheInvalidationController.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/NoOpGroupCacheInvalidationController.java new file mode 100644 index 000000000000..ad3626f1b04a --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/NoOpGroupCacheInvalidationController.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +public class NoOpGroupCacheInvalidationController + implements GroupCacheInvalidationController +{ + @Override + public void invalidate(String user) {} + + @Override + public void invalidateAll() {} +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProvider.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProvider.java new file mode 100644 index 000000000000..66ed4c770b30 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProvider.java @@ -0,0 +1,305 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.inject.BindingAnnotation; +import com.google.inject.Inject; +import com.google.inject.Injector; +import com.google.inject.Key; +import io.airlift.bootstrap.Bootstrap; +import io.trino.plugin.base.group.CachingGroupProviderModule.ForCachingGroupProvider; +import io.trino.spi.security.GroupProvider; +import io.trino.spi.security.GroupProviderFactory; +import org.junit.jupiter.api.Test; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestCachingGroupProvider +{ + @Test + public void testWithOutCaching() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.empty()).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithOutCachingWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.of(ForTesting.class)).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithOutCachingWithPrefix() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.empty()).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithOutCachingWithPrefixWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.of(ForTesting.class)).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + private static void innerTestWithOutCaching(CountingGroupProvider countingGroupProvider, TestingGroupProvider groupProvider) + { + assertThat(countingGroupProvider.getCount()).isEqualTo(0); + + // first batch + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(1); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + + // second batch + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(3); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(4); + + // invalidate user + groupProvider.invalidate("testUser1"); + // no effect: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(5); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(6); + + // invalidate all + groupProvider.invalidateAll(); + // no effect: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(7); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(8); + } + + @Test + public void testWithCaching() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "true", + "cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.empty()).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithCachingWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "true", + "cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.of(ForTesting.class)).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithCachingWithPrefix() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "true", + "group-provider.cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.empty()).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithCachingWithPrefixWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "true", + "group-provider.cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.of(ForTesting.class)).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + private static void innerTestWithCaching(CountingGroupProvider countingGroupProvider, TestingGroupProvider groupProvider) + { + assertThat(countingGroupProvider.getCount()).isEqualTo(0); + + // first batch + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(1); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + + // second batch is handled by the cache so delegate not invoked + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + + // invalidate user + groupProvider.invalidate("testUser1"); + // effect on testUser1 only: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(3); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(3); + + // invalidate all + groupProvider.invalidateAll(); + // effect on both: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(4); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(5); + } + + private static class CountingGroupProvider + implements GroupProvider + { + private final AtomicInteger counter = new AtomicInteger(0); + + @Override + public Set getGroups(String user) + { + counter.incrementAndGet(); + return ImmutableSet.of("test", user); + } + + public int getCount() + { + return counter.get(); + } + } + + private static class TestingGroupProviderFactory + implements GroupProviderFactory + { + private final CountingGroupProvider groupProvider; + private final Optional prefix; + private final Optional> bindingAnnotation; + + private TestingGroupProviderFactory(CountingGroupProvider groupProvider, Optional prefix, Optional> bindingAnnotation) + { + this.groupProvider = requireNonNull(groupProvider, "groupProvider is null"); + this.prefix = requireNonNull(prefix, "prefix is null"); + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + public String getName() + { + return "counting"; + } + + @Override + public TestingGroupProvider create(Map config) + { + CachingGroupProviderModule.Builder moduleBuilder = CachingGroupProviderModule.builder(); + prefix.ifPresent(moduleBuilder::withPrefix); + bindingAnnotation.ifPresent(moduleBuilder::withBindingAnnotation); + + Bootstrap app = new Bootstrap( + moduleBuilder.build(), + binder -> { + binder.bind(Key.get(GroupProvider.class, ForCachingGroupProvider.class)) + .toInstance(groupProvider); + bindingAnnotation.ifPresent(bindingAnnotation -> + binder.bind(GroupProvider.class).to(Key.get(GroupProvider.class, bindingAnnotation))); + binder.bind(TestingGroupProvider.class); + }); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(TestingGroupProvider.class); + } + } + + private static class TestingGroupProvider + implements GroupProvider, GroupCacheInvalidationController + { + private final GroupProvider delegate; + private final GroupCacheInvalidationController invalidationController; + + @Inject + public TestingGroupProvider(GroupProvider delegate, GroupCacheInvalidationController invalidationController) + { + this.delegate = requireNonNull(delegate, "delegate"); + this.invalidationController = requireNonNull(invalidationController, "invalidationController is null"); + } + + @Override + public Set getGroups(String user) + { + return delegate.getGroups(user); + } + + @Override + public void invalidate(String user) + { + invalidationController.invalidate(user); + } + + @Override + public void invalidateAll() + { + invalidationController.invalidateAll(); + } + } + + @Retention(RUNTIME) + @Target({FIELD, PARAMETER, METHOD}) + @BindingAnnotation + private @interface ForTesting + { + } +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProviderConfig.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProviderConfig.java new file mode 100644 index 000000000000..ee80de362f48 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProviderConfig.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class TestCachingGroupProviderConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(CachingGroupProviderConfig.class) + .setTtl(new Duration(5, SECONDS)) + .setCacheMaximumSize(Long.MAX_VALUE)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.of( + "cache.ttl", "10 s", + "cache.maximum-size", "10"); + + CachingGroupProviderConfig expected = new CachingGroupProviderConfig() + .setTtl(new Duration(10, SECONDS)) + .setCacheMaximumSize(10); + + assertFullMapping(properties, expected); + } +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestGroupProviderConfig.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestGroupProviderConfig.java new file mode 100644 index 000000000000..5621337643dd --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestGroupProviderConfig.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestGroupProviderConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(GroupProviderConfig.class) + .setCachingEnabled(false)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.of("cache.enabled", "true"); + + GroupProviderConfig expected = new GroupProviderConfig() + .setCachingEnabled(true); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-ldap-group-provider/pom.xml b/plugin/trino-ldap-group-provider/pom.xml index 5898c39b9250..9658b1718988 100644 --- a/plugin/trino-ldap-group-provider/pom.xml +++ b/plugin/trino-ldap-group-provider/pom.xml @@ -37,6 +37,12 @@ configuration + + io.airlift + junit-extensions + 2 + + io.airlift log @@ -117,7 +123,7 @@ org.junit.jupiter junit-jupiter-api - test + compile diff --git a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java index 07cc25db23bb..1589d8eab55f 100644 --- a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java +++ b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java @@ -91,7 +91,7 @@ public Set getGroups(String user) }); } catch (NamingException e) { - log.error("LDAP search for user [%s] failed", user, e); + log.error(e, "LDAP search for user [%s] failed", user); return ImmutableSet.of(); } @@ -124,7 +124,7 @@ public Set getGroups(String user) }); } catch (NamingException e) { - log.error("LDAP search for user [%s] groups failed", user, e); + log.error(e, "LDAP search for user [%s] groups failed", user); return ImmutableSet.of(); } }).orElse(ImmutableSet.of()); diff --git a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java index c77e9d402417..718dd1522d7f 100644 --- a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java +++ b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java @@ -15,6 +15,7 @@ import com.google.inject.Injector; import io.airlift.bootstrap.Bootstrap; +import io.trino.plugin.base.group.CachingGroupProviderModule; import io.trino.plugin.base.ldap.LdapClientModule; import io.trino.spi.security.GroupProvider; import io.trino.spi.security.GroupProviderFactory; @@ -38,6 +39,7 @@ public GroupProvider create(Map requiredConfig) requireNonNull(requiredConfig, "config is null"); Bootstrap app = new Bootstrap( + CachingGroupProviderModule.create(), new LdapClientModule(), new LdapGroupProviderModule()); diff --git a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java index f71820a0e112..f389a71702cc 100644 --- a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java +++ b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java @@ -99,7 +99,7 @@ public Set getGroups(String user) }); } catch (NamingException e) { - log.error("LDAP search for user [%s] failed", user, e); + log.error(e, "LDAP search for user [%s] failed", user); return ImmutableSet.of(); } } diff --git a/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java b/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java index b5253c3175a9..7a27076ad25a 100644 --- a/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java +++ b/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java @@ -13,10 +13,8 @@ */ package io.trino.plugin.ldapgroup; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ObjectArrays; import com.google.common.io.Closer; import io.trino.spi.security.GroupProvider; import io.trino.testing.containers.TestingOpenLdapServer; @@ -26,21 +24,15 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.parallel.Execution; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; import org.testcontainers.containers.Network; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.stream.Stream; -import static com.google.common.collect.ImmutableList.toImmutableList; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -82,7 +74,8 @@ public class TestLdapGroupProviderIntegration private DisposableSubContext qa; @BeforeAll - public void setup() throws Exception + public void setup() + throws Exception { closer = Closer.create(); Network network = Network.newNetwork(); @@ -224,7 +217,8 @@ public void testGetGroupsWithBadGroupMemberAttributeReturnsEmpty() } @Test - public void testGetGroupsWithBadGroupNameReturnsFullName() { + public void testGetGroupsWithBadGroupNameReturnsFullName() + { assertGetGroupsWithBadGroupNameReturnsFullName(cacheDisabledWithMemberOf); assertGetGroupsWithBadGroupNameReturnsFullName(cacheDisabledWithGroupFilter); assertGetGroupsWithBadGroupNameReturnsFullName(cacheEnabledWithMemberOf); @@ -246,7 +240,9 @@ private void assertGetGroupsWithBadGroupNameReturnsFullName(ConfigBuilder config } @Test - public void testGetGroupsConcurrently() throws InterruptedException { + public void testGetGroupsConcurrently() + throws InterruptedException + { assertGetGroupsConcurrently(cacheDisabledWithMemberOf); assertGetGroupsConcurrently(cacheDisabledWithGroupFilter); assertGetGroupsConcurrently(cacheEnabledWithMemberOf);