diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProvider.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProvider.java new file mode 100644 index 000000000000..ea491fee1351 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProvider.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.inject.Inject; +import io.trino.cache.EvictableCacheBuilder; +import io.trino.plugin.base.group.CachingGroupProviderModule.ForCachingGroupProvider; +import io.trino.spi.security.GroupProvider; + +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +public class CachingGroupProvider + implements GroupProvider, GroupCacheInvalidationController +{ + private final LoadingCache> cache; + + @Inject + public CachingGroupProvider(CachingGroupProviderConfig config, @ForCachingGroupProvider GroupProvider delegate) + { + requireNonNull(delegate, "delegate is null"); + this.cache = EvictableCacheBuilder.newBuilder() + .maximumSize(config.getCacheMaximumSize()) + .expireAfterWrite(config.getTtl().toMillis(), MILLISECONDS) + .shareNothingWhenDisabled() + .build(CacheLoader.from(delegate::getGroups)); + } + + @Override + public Set getGroups(String user) + { + return cache.getUnchecked(user); + } + + @Override + public void invalidate(String user) + { + cache.invalidate(user); + } + + @Override + public void invalidateAll() + { + cache.invalidateAll(); + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderConfig.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderConfig.java new file mode 100644 index 000000000000..f07e22d0840f --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderConfig.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.units.Duration; +import jakarta.validation.constraints.Min; + +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class CachingGroupProviderConfig +{ + private Duration ttl = new Duration(5, SECONDS); + private long cacheMaximumSize = Long.MAX_VALUE; + + public Duration getTtl() + { + return ttl; + } + + @Config("cache.ttl") + @ConfigDescription("Determines how long group information will be cached for each user") + public CachingGroupProviderConfig setTtl(Duration ttl) + { + this.ttl = requireNonNull(ttl, "ttl is null"); + return this; + } + + @Min(1) + public long getCacheMaximumSize() + { + return cacheMaximumSize; + } + + @Config("cache.maximum-size") + @ConfigDescription("Maximum number of users for which groups are stored in the cache") + public CachingGroupProviderConfig setCacheMaximumSize(long cacheMaximumSize) + { + this.cacheMaximumSize = cacheMaximumSize; + return this; + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderModule.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderModule.java new file mode 100644 index 000000000000..542207ff7ad1 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/CachingGroupProviderModule.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.inject.Binder; +import com.google.inject.BindingAnnotation; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Scopes; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.spi.security.GroupProvider; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.Optional; + +import static io.airlift.configuration.ConditionalModule.conditionalModule; +import static io.airlift.configuration.ConfigBinder.configBinder; +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; +import static java.util.Objects.requireNonNull; + +/** + * If added to the list of {@link com.google.inject.Module}s used in initialization of a Guice context in a + * {@link io.trino.spi.security.GroupProviderFactory}, it will (almost) automatically add caching capability to the + * group provider. Requirements: + * + * The module will make the following configuration options available (to be set in {@code etc/group-provider.properties}: + * + * These properties can optionally have an arbitrary prefix ({@link Builder#withPrefix(String)}) + * and/or a binding annotation for the resulting binding of {@link GroupProvider} ({@link Builder#withBindingAnnotation(Class)}). + *

+ * An additional object of type {@link GroupCacheInvalidationController} will also be bound, with which one can invalidate + * all or part of the cache. + */ +public class CachingGroupProviderModule + extends AbstractConfigurationAwareModule +{ + private final Optional prefix; + private final Optional> bindingAnnotation; + + private CachingGroupProviderModule(Optional prefix, Optional> bindingAnnotation) + { + this.prefix = requireNonNull(prefix, "prefix is null"); + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + protected void setup(Binder binder) + { + configBinder(binder).bindConfig(GroupProviderConfig.class, prefix.orElse(null)); + prefix.ifPresentOrElse( + prefix -> install(conditionalModule( + GroupProviderConfig.class, + prefix, + GroupProviderConfig::isCachingEnabled, + new CacheModule(Optional.of(prefix), bindingAnnotation), + new NonCacheModule(bindingAnnotation))), + () -> install(conditionalModule( + GroupProviderConfig.class, + GroupProviderConfig::isCachingEnabled, + new CacheModule(Optional.empty(), bindingAnnotation), + new NonCacheModule(bindingAnnotation)))); + } + + private static class CacheModule + implements Module + { + private final Optional prefix; + private final Optional> bindingAnnotation; + + public CacheModule(Optional prefix, Optional> bindingAnnotation) + { + this.prefix = requireNonNull(prefix, "prefix is null"); + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(CachingGroupProviderConfig.class, prefix.orElse(null)); + binder.bind(CachingGroupProvider.class).in(Scopes.SINGLETON); + binder.bind(bindingAnnotation + .map(bindingAnnotation -> Key.get(GroupProvider.class, bindingAnnotation)) + .orElseGet(() -> Key.get(GroupProvider.class))) + .to(CachingGroupProvider.class) + .in(Scopes.SINGLETON); + binder.bind(GroupCacheInvalidationController.class) + .to(CachingGroupProvider.class) + .in(Scopes.SINGLETON); + } + } + + private static class NonCacheModule + implements Module + { + private final Optional> bindingAnnotation; + + public NonCacheModule(Optional> bindingAnnotation) + { + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + public void configure(Binder binder) + { + binder.bind(bindingAnnotation + .map(bindingAnnotation -> Key.get(GroupProvider.class, bindingAnnotation)) + .orElseGet(() -> Key.get(GroupProvider.class))) + .to(Key.get(GroupProvider.class, ForCachingGroupProvider.class)) + .in(Scopes.SINGLETON); + binder.bind(GroupCacheInvalidationController.class) + .to(NoOpGroupCacheInvalidationController.class) + .in(Scopes.SINGLETON); + } + } + + @Retention(RUNTIME) + @Target({FIELD, PARAMETER, METHOD}) + @BindingAnnotation + public @interface ForCachingGroupProvider + { + } + + public static CachingGroupProviderModule create() + { + return builder().build(); + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private Optional prefix = Optional.empty(); + private Optional> bindingAnnotation = Optional.empty(); + + private Builder() {} + + @CanIgnoreReturnValue + public Builder withPrefix(String prefix) + { + this.prefix = Optional.of(prefix); + return this; + } + + @CanIgnoreReturnValue + public Builder withBindingAnnotation(Class bindingAnnotation) + { + this.bindingAnnotation = Optional.of(bindingAnnotation); + return this; + } + + public CachingGroupProviderModule build() + { + return new CachingGroupProviderModule(prefix, bindingAnnotation); + } + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupCacheInvalidationController.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupCacheInvalidationController.java new file mode 100644 index 000000000000..886bbad1861b --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupCacheInvalidationController.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +public interface GroupCacheInvalidationController +{ + void invalidate(String user); + + void invalidateAll(); +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupProviderConfig.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupProviderConfig.java new file mode 100644 index 000000000000..5a5569f6f4a3 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/GroupProviderConfig.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; + +public class GroupProviderConfig +{ + private boolean isCachingEnabled; + + public boolean isCachingEnabled() + { + return isCachingEnabled; + } + + @Config("cache.enabled") + @ConfigDescription("Enables caching for the group provider") + public GroupProviderConfig setCachingEnabled(boolean isCachingEnabled) + { + this.isCachingEnabled = isCachingEnabled; + return this; + } +} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/NoOpGroupCacheInvalidationController.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/NoOpGroupCacheInvalidationController.java new file mode 100644 index 000000000000..ad3626f1b04a --- /dev/null +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/group/NoOpGroupCacheInvalidationController.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +public class NoOpGroupCacheInvalidationController + implements GroupCacheInvalidationController +{ + @Override + public void invalidate(String user) {} + + @Override + public void invalidateAll() {} +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProvider.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProvider.java new file mode 100644 index 000000000000..66ed4c770b30 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProvider.java @@ -0,0 +1,305 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.inject.BindingAnnotation; +import com.google.inject.Inject; +import com.google.inject.Injector; +import com.google.inject.Key; +import io.airlift.bootstrap.Bootstrap; +import io.trino.plugin.base.group.CachingGroupProviderModule.ForCachingGroupProvider; +import io.trino.spi.security.GroupProvider; +import io.trino.spi.security.GroupProviderFactory; +import org.junit.jupiter.api.Test; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestCachingGroupProvider +{ + @Test + public void testWithOutCaching() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.empty()).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithOutCachingWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.of(ForTesting.class)).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithOutCachingWithPrefix() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.empty()).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithOutCachingWithPrefixWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "false"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.of(ForTesting.class)).create(properties); + + innerTestWithOutCaching(countingGroupProvider, groupProvider); + } + + private static void innerTestWithOutCaching(CountingGroupProvider countingGroupProvider, TestingGroupProvider groupProvider) + { + assertThat(countingGroupProvider.getCount()).isEqualTo(0); + + // first batch + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(1); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + + // second batch + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(3); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(4); + + // invalidate user + groupProvider.invalidate("testUser1"); + // no effect: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(5); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(6); + + // invalidate all + groupProvider.invalidateAll(); + // no effect: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(7); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(8); + } + + @Test + public void testWithCaching() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "true", + "cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.empty()).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithCachingWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "cache.enabled", "true", + "cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.empty(), Optional.of(ForTesting.class)).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithCachingWithPrefix() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "true", + "group-provider.cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.empty()).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + @Test + public void testWithCachingWithPrefixWithBindingAnnotation() + { + CountingGroupProvider countingGroupProvider = new CountingGroupProvider(); + Map properties = ImmutableMap.of( + "group-provider.cache.enabled", "true", + "group-provider.cache.ttl", "1 h"); + TestingGroupProvider groupProvider = new TestingGroupProviderFactory(countingGroupProvider, Optional.of("group-provider"), Optional.of(ForTesting.class)).create(properties); + + innerTestWithCaching(countingGroupProvider, groupProvider); + } + + private static void innerTestWithCaching(CountingGroupProvider countingGroupProvider, TestingGroupProvider groupProvider) + { + assertThat(countingGroupProvider.getCount()).isEqualTo(0); + + // first batch + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(1); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + + // second batch is handled by the cache so delegate not invoked + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(2); + + // invalidate user + groupProvider.invalidate("testUser1"); + // effect on testUser1 only: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(3); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(3); + + // invalidate all + groupProvider.invalidateAll(); + // effect on both: + assertThat(groupProvider.getGroups("testUser1")).containsOnly("test", "testUser1"); + assertThat(countingGroupProvider.getCount()).isEqualTo(4); + assertThat(groupProvider.getGroups("testUser2")).containsOnly("test", "testUser2"); + assertThat(countingGroupProvider.getCount()).isEqualTo(5); + } + + private static class CountingGroupProvider + implements GroupProvider + { + private final AtomicInteger counter = new AtomicInteger(0); + + @Override + public Set getGroups(String user) + { + counter.incrementAndGet(); + return ImmutableSet.of("test", user); + } + + public int getCount() + { + return counter.get(); + } + } + + private static class TestingGroupProviderFactory + implements GroupProviderFactory + { + private final CountingGroupProvider groupProvider; + private final Optional prefix; + private final Optional> bindingAnnotation; + + private TestingGroupProviderFactory(CountingGroupProvider groupProvider, Optional prefix, Optional> bindingAnnotation) + { + this.groupProvider = requireNonNull(groupProvider, "groupProvider is null"); + this.prefix = requireNonNull(prefix, "prefix is null"); + this.bindingAnnotation = requireNonNull(bindingAnnotation, "bindingAnnotation is null"); + } + + @Override + public String getName() + { + return "counting"; + } + + @Override + public TestingGroupProvider create(Map config) + { + CachingGroupProviderModule.Builder moduleBuilder = CachingGroupProviderModule.builder(); + prefix.ifPresent(moduleBuilder::withPrefix); + bindingAnnotation.ifPresent(moduleBuilder::withBindingAnnotation); + + Bootstrap app = new Bootstrap( + moduleBuilder.build(), + binder -> { + binder.bind(Key.get(GroupProvider.class, ForCachingGroupProvider.class)) + .toInstance(groupProvider); + bindingAnnotation.ifPresent(bindingAnnotation -> + binder.bind(GroupProvider.class).to(Key.get(GroupProvider.class, bindingAnnotation))); + binder.bind(TestingGroupProvider.class); + }); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(TestingGroupProvider.class); + } + } + + private static class TestingGroupProvider + implements GroupProvider, GroupCacheInvalidationController + { + private final GroupProvider delegate; + private final GroupCacheInvalidationController invalidationController; + + @Inject + public TestingGroupProvider(GroupProvider delegate, GroupCacheInvalidationController invalidationController) + { + this.delegate = requireNonNull(delegate, "delegate"); + this.invalidationController = requireNonNull(invalidationController, "invalidationController is null"); + } + + @Override + public Set getGroups(String user) + { + return delegate.getGroups(user); + } + + @Override + public void invalidate(String user) + { + invalidationController.invalidate(user); + } + + @Override + public void invalidateAll() + { + invalidationController.invalidateAll(); + } + } + + @Retention(RUNTIME) + @Target({FIELD, PARAMETER, METHOD}) + @BindingAnnotation + private @interface ForTesting + { + } +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProviderConfig.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProviderConfig.java new file mode 100644 index 000000000000..ee80de362f48 --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestCachingGroupProviderConfig.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class TestCachingGroupProviderConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(CachingGroupProviderConfig.class) + .setTtl(new Duration(5, SECONDS)) + .setCacheMaximumSize(Long.MAX_VALUE)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.of( + "cache.ttl", "10 s", + "cache.maximum-size", "10"); + + CachingGroupProviderConfig expected = new CachingGroupProviderConfig() + .setTtl(new Duration(10, SECONDS)) + .setCacheMaximumSize(10); + + assertFullMapping(properties, expected); + } +} diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestGroupProviderConfig.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestGroupProviderConfig.java new file mode 100644 index 000000000000..5621337643dd --- /dev/null +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/group/TestGroupProviderConfig.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.base.group; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestGroupProviderConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(GroupProviderConfig.class) + .setCachingEnabled(false)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.of("cache.enabled", "true"); + + GroupProviderConfig expected = new GroupProviderConfig() + .setCachingEnabled(true); + + assertFullMapping(properties, expected); + } +} diff --git a/plugin/trino-ldap-group-provider/pom.xml b/plugin/trino-ldap-group-provider/pom.xml index 5898c39b9250..9658b1718988 100644 --- a/plugin/trino-ldap-group-provider/pom.xml +++ b/plugin/trino-ldap-group-provider/pom.xml @@ -37,6 +37,12 @@ configuration + + io.airlift + junit-extensions + 2 + + io.airlift log @@ -117,7 +123,7 @@ org.junit.jupiter junit-jupiter-api - test + compile diff --git a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java index 07cc25db23bb..1589d8eab55f 100644 --- a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java +++ b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapFilteringGroupProvider.java @@ -91,7 +91,7 @@ public Set getGroups(String user) }); } catch (NamingException e) { - log.error("LDAP search for user [%s] failed", user, e); + log.error(e, "LDAP search for user [%s] failed", user); return ImmutableSet.of(); } @@ -124,7 +124,7 @@ public Set getGroups(String user) }); } catch (NamingException e) { - log.error("LDAP search for user [%s] groups failed", user, e); + log.error(e, "LDAP search for user [%s] groups failed", user); return ImmutableSet.of(); } }).orElse(ImmutableSet.of()); diff --git a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java index c77e9d402417..718dd1522d7f 100644 --- a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java +++ b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapGroupProviderFactory.java @@ -15,6 +15,7 @@ import com.google.inject.Injector; import io.airlift.bootstrap.Bootstrap; +import io.trino.plugin.base.group.CachingGroupProviderModule; import io.trino.plugin.base.ldap.LdapClientModule; import io.trino.spi.security.GroupProvider; import io.trino.spi.security.GroupProviderFactory; @@ -38,6 +39,7 @@ public GroupProvider create(Map requiredConfig) requireNonNull(requiredConfig, "config is null"); Bootstrap app = new Bootstrap( + CachingGroupProviderModule.create(), new LdapClientModule(), new LdapGroupProviderModule()); diff --git a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java index f71820a0e112..f389a71702cc 100644 --- a/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java +++ b/plugin/trino-ldap-group-provider/src/main/java/io/trino/plugin/ldapgroup/LdapSingleQueryGroupProvider.java @@ -99,7 +99,7 @@ public Set getGroups(String user) }); } catch (NamingException e) { - log.error("LDAP search for user [%s] failed", user, e); + log.error(e, "LDAP search for user [%s] failed", user); return ImmutableSet.of(); } } diff --git a/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java b/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java index b5253c3175a9..7a27076ad25a 100644 --- a/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java +++ b/plugin/trino-ldap-group-provider/src/test/java/io/trino/plugin/ldapgroup/TestLdapGroupProviderIntegration.java @@ -13,10 +13,8 @@ */ package io.trino.plugin.ldapgroup; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ObjectArrays; import com.google.common.io.Closer; import io.trino.spi.security.GroupProvider; import io.trino.testing.containers.TestingOpenLdapServer; @@ -26,21 +24,15 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.parallel.Execution; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; import org.testcontainers.containers.Network; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.stream.Stream; -import static com.google.common.collect.ImmutableList.toImmutableList; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -82,7 +74,8 @@ public class TestLdapGroupProviderIntegration private DisposableSubContext qa; @BeforeAll - public void setup() throws Exception + public void setup() + throws Exception { closer = Closer.create(); Network network = Network.newNetwork(); @@ -224,7 +217,8 @@ public void testGetGroupsWithBadGroupMemberAttributeReturnsEmpty() } @Test - public void testGetGroupsWithBadGroupNameReturnsFullName() { + public void testGetGroupsWithBadGroupNameReturnsFullName() + { assertGetGroupsWithBadGroupNameReturnsFullName(cacheDisabledWithMemberOf); assertGetGroupsWithBadGroupNameReturnsFullName(cacheDisabledWithGroupFilter); assertGetGroupsWithBadGroupNameReturnsFullName(cacheEnabledWithMemberOf); @@ -246,7 +240,9 @@ private void assertGetGroupsWithBadGroupNameReturnsFullName(ConfigBuilder config } @Test - public void testGetGroupsConcurrently() throws InterruptedException { + public void testGetGroupsConcurrently() + throws InterruptedException + { assertGetGroupsConcurrently(cacheDisabledWithMemberOf); assertGetGroupsConcurrently(cacheDisabledWithGroupFilter); assertGetGroupsConcurrently(cacheEnabledWithMemberOf);