Skip to content

Commit 560a1f2

Browse files
committed
Enable advanced shard awareness
Makes `addMissingChannels` use advanced shard awareness: it now specifies the target shard when adding missing channels for specific shards. If a returned channel does not match the requested shard, a warning is logged. The initial connection to a node still follows the previous rules, meaning it uses an arbitrary local port and connects to an arbitrary shard. Adds AdvancedShardAwarenessIT, which has several methods demonstrating the difference between establishing connections with the option enabled and disabled.
1 parent 1342eac commit 560a1f2

File tree

3 files changed

+325
-3
lines changed

3 files changed

+325
-3
lines changed

core/src/main/java/com/datastax/oss/driver/internal/core/pool/ChannelPool.java

+35-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import io.netty.util.concurrent.EventExecutor;
5959
import io.netty.util.concurrent.Future;
6060
import io.netty.util.concurrent.GenericFutureListener;
61+
import java.net.InetSocketAddress;
6162
import java.util.ArrayList;
6263
import java.util.Arrays;
6364
import java.util.HashSet;
@@ -489,9 +490,23 @@ private CompletionStage<Boolean> addMissingChannels() {
489490
channels.length * wantedCount - Arrays.stream(channels).mapToInt(ChannelSet::size).sum();
490491
LOG.debug("[{}] Trying to create {} missing channels", logPrefix, missing);
491492
DriverChannelOptions options = buildDriverOptions();
492-
for (int i = 0; i < missing; i++) {
493-
CompletionStage<DriverChannel> channelFuture = channelFactory.connect(node, options);
494-
pendingChannels.add(channelFuture);
493+
for (int shard = 0; shard < channels.length; shard++) {
494+
LOG.trace(
495+
"[{}] Missing {} channels for shard {}",
496+
logPrefix,
497+
wantedCount - channels[shard].size(),
498+
shard);
499+
for (int p = channels[shard].size(); p < wantedCount; p++) {
500+
CompletionStage<DriverChannel> channelFuture;
501+
if (config
502+
.getDefaultProfile()
503+
.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) {
504+
channelFuture = channelFactory.connect(node, shard, options);
505+
} else {
506+
channelFuture = channelFactory.connect(node, options);
507+
}
508+
pendingChannels.add(channelFuture);
509+
}
495510
}
496511
return CompletableFutures.allDone(pendingChannels)
497512
.thenApplyAsync(this::onAllConnected, adminExecutor);
@@ -551,6 +566,23 @@ private boolean onAllConnected(@SuppressWarnings("unused") Void v) {
551566
channel);
552567
channel.forceClose();
553568
} else {
569+
if (config
570+
.getDefaultProfile()
571+
.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)
572+
&& channel.localAddress() instanceof InetSocketAddress
573+
&& channel.getShardingInfo() != null) {
574+
int port = ((InetSocketAddress) channel.localAddress()).getPort();
575+
int actualShard = channel.getShardId();
576+
int targetShard = port % channel.getShardingInfo().getShardsCount();
577+
if (actualShard != targetShard) {
578+
LOG.warn(
579+
"[{}] New channel {} connected to shard {}, but shard {} was requested. If this is not transient check your driver AND cluster configuration of shard aware port.",
580+
logPrefix,
581+
channel,
582+
actualShard,
583+
targetShard);
584+
}
585+
}
554586
LOG.debug("[{}] New channel added {}", logPrefix, channel);
555587
if (channels[channel.getShardId()].size() < wantedCount) {
556588
addChannel(channel);

core/src/test/java/com/datastax/oss/driver/internal/core/pool/ChannelPoolTestBase.java

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.mockito.Mockito.when;
2525

2626
import com.datastax.oss.driver.api.core.CqlIdentifier;
27+
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
2728
import com.datastax.oss.driver.api.core.config.DriverConfig;
2829
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
2930
import com.datastax.oss.driver.api.core.connection.ReconnectionPolicy;
@@ -77,6 +78,8 @@ public void setup() {
7778
when(nettyOptions.adminEventExecutorGroup()).thenReturn(adminEventLoopGroup);
7879
when(context.getConfig()).thenReturn(config);
7980
when(config.getDefaultProfile()).thenReturn(defaultProfile);
81+
when(defaultProfile.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED))
82+
.thenReturn(false);
8083
this.eventBus = spy(new EventBus("test"));
8184
when(context.getEventBus()).thenReturn(eventBus);
8285
when(context.getChannelFactory()).thenReturn(channelFactory);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
package com.datastax.oss.driver.core.pool;
2+
3+
import static junit.framework.TestCase.fail;
4+
5+
import ch.qos.logback.classic.Level;
6+
import ch.qos.logback.classic.Logger;
7+
import ch.qos.logback.classic.spi.ILoggingEvent;
8+
import ch.qos.logback.core.read.ListAppender;
9+
import com.datastax.oss.driver.api.core.CqlSession;
10+
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
11+
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
12+
import com.datastax.oss.driver.api.core.config.DriverConfigLoader;
13+
import com.datastax.oss.driver.api.core.session.Session;
14+
import com.datastax.oss.driver.api.testinfra.CassandraSkip;
15+
import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule;
16+
import com.datastax.oss.driver.api.testinfra.session.SessionUtils;
17+
import com.datastax.oss.driver.internal.core.pool.ChannelPool;
18+
import com.datastax.oss.driver.internal.core.util.concurrent.CompletableFutures;
19+
import com.datastax.oss.driver.internal.core.util.concurrent.Reconnection;
20+
import com.google.common.collect.ImmutableMap;
21+
import com.google.common.collect.ImmutableSet;
22+
import com.google.common.util.concurrent.Uninterruptibles;
23+
import java.net.InetSocketAddress;
24+
import java.time.Duration;
25+
import java.util.List;
26+
import java.util.Map;
27+
import java.util.Set;
28+
import java.util.concurrent.CompletionStage;
29+
import java.util.concurrent.TimeUnit;
30+
import java.util.regex.Pattern;
31+
import org.junit.After;
32+
import org.junit.Before;
33+
import org.junit.ClassRule;
34+
import org.junit.Test;
35+
import org.slf4j.LoggerFactory;
36+
37+
@CassandraSkip(description = "Advanced shard awareness relies on ScyllaDB's shard aware port")
38+
public class AdvancedShardAwarenessIT {
39+
40+
@ClassRule
41+
public static final CustomCcmRule CCM_RULE =
42+
CustomCcmRule.builder().withNodes(2).withJvmArgs("--smp=3").build();
43+
44+
public static ch.qos.logback.classic.Logger channelPoolLogger =
45+
(ch.qos.logback.classic.Logger) LoggerFactory.getLogger(ChannelPool.class);
46+
public static ch.qos.logback.classic.Logger reconnectionLogger =
47+
(ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Reconnection.class);
48+
ListAppender<ILoggingEvent> appender;
49+
Level originalLevelChannelPool;
50+
Level originalLevelReconnection;
51+
private final Pattern shardMismatchPattern =
52+
Pattern.compile(".*r configuration of shard aware port.*");
53+
private final Pattern reconnectionPattern =
54+
Pattern.compile(".*Scheduling next reconnection in.*");
55+
Set<Pattern> forbiddenOccurences = ImmutableSet.of(shardMismatchPattern, reconnectionPattern);
56+
57+
@Before
58+
public void startCapturingLogs() {
59+
originalLevelChannelPool = channelPoolLogger.getLevel();
60+
originalLevelReconnection = reconnectionLogger.getLevel();
61+
channelPoolLogger.setLevel(Level.DEBUG);
62+
reconnectionLogger.setLevel(Level.DEBUG);
63+
appender = new ListAppender<>();
64+
appender.setContext(
65+
((Logger) LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME)).getLoggerContext());
66+
channelPoolLogger.addAppender(appender);
67+
reconnectionLogger.addAppender(appender);
68+
appender.list.clear();
69+
appender.start();
70+
}
71+
72+
@After
73+
public void stopCapturingLogs() {
74+
appender.stop();
75+
appender.list.clear();
76+
channelPoolLogger.setLevel(originalLevelChannelPool);
77+
reconnectionLogger.setLevel(originalLevelReconnection);
78+
channelPoolLogger.detachAppender(appender);
79+
reconnectionLogger.detachAppender(appender);
80+
}
81+
82+
@Test
83+
public void should_initialize_all_channels() {
84+
Map<Pattern, Integer> expectedOccurences =
85+
ImmutableMap.of(
86+
Pattern.compile(
87+
".*127\\.0\\.0\\.2:19042.*Reconnection attempt complete, 6/6 channels.*"),
88+
1,
89+
Pattern.compile(
90+
".*127\\.0\\.0\\.1:19042.*Reconnection attempt complete, 6/6 channels.*"),
91+
1,
92+
Pattern.compile(".*Reconnection attempt complete.*"), 2,
93+
Pattern.compile(".*127\\.0\\.0\\.1:19042.*New channel added \\[.*"), 5,
94+
Pattern.compile(".*127\\.0\\.0\\.2:19042.*New channel added \\[.*"), 5,
95+
Pattern.compile(".*127\\.0\\.0\\.1:19042\\] Trying to create 5 missing channels.*"), 1,
96+
Pattern.compile(".*127\\.0\\.0\\.2:19042\\] Trying to create 5 missing channels.*"), 1);
97+
DriverConfigLoader loader =
98+
SessionUtils.configLoaderBuilder()
99+
.withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true)
100+
.withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000)
101+
.withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000)
102+
// Due to rounding up the connections per shard this will result in 6 connections per
103+
// node
104+
.withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 4)
105+
.build();
106+
try (Session session =
107+
CqlSession.builder()
108+
.addContactPoint(
109+
new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 19042))
110+
.withConfigLoader(loader)
111+
.build()) {
112+
Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);
113+
expectedOccurences.forEach(
114+
(pattern, times) -> assertMatchesExactly(pattern, times, appender.list));
115+
forbiddenOccurences.forEach(pattern -> assertNoLogMatches(pattern, appender.list));
116+
}
117+
}
118+
119+
@Test
120+
public void should_see_mismatched_shard() {
121+
DriverConfigLoader loader =
122+
SessionUtils.configLoaderBuilder()
123+
.withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true)
124+
.withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000)
125+
.withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000)
126+
.withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64)
127+
.build();
128+
try (Session session =
129+
CqlSession.builder()
130+
.addContactPoint(
131+
new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 9042))
132+
.withConfigLoader(loader)
133+
.build()) {
134+
Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);
135+
assertMatchesAtLeast(shardMismatchPattern, 5, appender.list);
136+
}
137+
}
138+
139+
// There is no need to run this as a test, but it serves as a comparison
140+
@SuppressWarnings("unused")
141+
public void should_struggle_to_fill_pools() {
142+
DriverConfigLoader loader =
143+
SessionUtils.configLoaderBuilder()
144+
.withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, false)
145+
.withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64)
146+
.withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(200))
147+
.withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(4000))
148+
.build();
149+
CqlSessionBuilder builder =
150+
CqlSession.builder()
151+
.addContactPoint(
152+
new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 9042))
153+
.withConfigLoader(loader);
154+
CompletionStage<CqlSession> stage1 = builder.buildAsync();
155+
CompletionStage<CqlSession> stage2 = builder.buildAsync();
156+
CompletionStage<CqlSession> stage3 = builder.buildAsync();
157+
CompletionStage<CqlSession> stage4 = builder.buildAsync();
158+
try (CqlSession session1 = CompletableFutures.getUninterruptibly(stage1);
159+
CqlSession session2 = CompletableFutures.getUninterruptibly(stage2);
160+
CqlSession session3 = CompletableFutures.getUninterruptibly(stage3);
161+
CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) {
162+
Uninterruptibles.sleepUninterruptibly(20, TimeUnit.SECONDS);
163+
assertNoLogMatches(shardMismatchPattern, appender.list);
164+
assertMatchesAtLeast(reconnectionPattern, 8, appender.list);
165+
}
166+
}
167+
168+
@Test
169+
public void should_not_struggle_to_fill_pools() {
170+
DriverConfigLoader loader =
171+
SessionUtils.configLoaderBuilder()
172+
.withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true)
173+
.withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 66)
174+
.withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(10))
175+
.withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(20))
176+
.build();
177+
CqlSessionBuilder builder =
178+
CqlSession.builder()
179+
.addContactPoint(
180+
new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 19042))
181+
.withConfigLoader(loader);
182+
CompletionStage<CqlSession> stage1 = builder.buildAsync();
183+
CompletionStage<CqlSession> stage2 = builder.buildAsync();
184+
CompletionStage<CqlSession> stage3 = builder.buildAsync();
185+
CompletionStage<CqlSession> stage4 = builder.buildAsync();
186+
int sessions = 4;
187+
try (CqlSession session1 = CompletableFutures.getUninterruptibly(stage1);
188+
CqlSession session2 = CompletableFutures.getUninterruptibly(stage2);
189+
CqlSession session3 = CompletableFutures.getUninterruptibly(stage3);
190+
CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) {
191+
Uninterruptibles.sleepUninterruptibly(8, TimeUnit.SECONDS);
192+
int tolerance = 2; // Sometimes socket ends up already in use
193+
Map<Pattern, Integer> expectedOccurences =
194+
ImmutableMap.of(
195+
Pattern.compile(
196+
".*127\\.0\\.0\\.2:19042.*Reconnection attempt complete, 66/66 channels.*"),
197+
1 * sessions,
198+
Pattern.compile(
199+
".*127\\.0\\.0\\.1:19042.*Reconnection attempt complete, 66/66 channels.*"),
200+
1 * sessions,
201+
Pattern.compile(".*Reconnection attempt complete.*"), 2 * sessions,
202+
Pattern.compile(".*127\\.0\\.0\\.1:19042.*New channel added \\[.*"),
203+
65 * sessions - tolerance,
204+
Pattern.compile(".*127\\.0\\.0\\.2:19042.*New channel added \\[.*"),
205+
65 * sessions - tolerance,
206+
Pattern.compile(".*127\\.0\\.0\\.1:19042\\] Trying to create 65 missing channels.*"),
207+
1 * sessions,
208+
Pattern.compile(".*127\\.0\\.0\\.2:19042\\] Trying to create 65 missing channels.*"),
209+
1 * sessions);
210+
expectedOccurences.forEach(
211+
(pattern, times) -> assertMatchesAtLeast(pattern, times, appender.list));
212+
assertNoLogMatches(shardMismatchPattern, appender.list);
213+
assertMatchesAtMost(reconnectionPattern, tolerance, appender.list);
214+
}
215+
}
216+
217+
private void assertNoLogMatches(Pattern pattern, List<ILoggingEvent> logs) {
218+
for (ILoggingEvent log : logs) {
219+
if (pattern.matcher(log.getFormattedMessage()).matches()) {
220+
fail(
221+
"Logs should not contain pattern ["
222+
+ pattern.toString()
223+
+ "] but found in ["
224+
+ log.getFormattedMessage()
225+
+ "]");
226+
}
227+
}
228+
}
229+
230+
private void assertMatchesExactly(Pattern pattern, Integer times, List<ILoggingEvent> logs) {
231+
int occurences = 0;
232+
for (ILoggingEvent log : logs) {
233+
if (pattern.matcher(log.getFormattedMessage()).matches()) {
234+
occurences++;
235+
}
236+
}
237+
if (occurences != times) {
238+
fail(
239+
"Expected to find pattern exactly "
240+
+ times
241+
+ " times but found it "
242+
+ occurences
243+
+ " times. Pattern: ["
244+
+ pattern.toString()
245+
+ "]");
246+
}
247+
}
248+
249+
private void assertMatchesAtLeast(Pattern pattern, Integer times, List<ILoggingEvent> logs) {
250+
int occurences = 0;
251+
for (ILoggingEvent log : logs) {
252+
if (pattern.matcher(log.getFormattedMessage()).matches()) {
253+
occurences++;
254+
if (occurences >= times) {
255+
return;
256+
}
257+
}
258+
}
259+
fail(
260+
"Expected to find pattern at least "
261+
+ times
262+
+ " times but found only "
263+
+ occurences
264+
+ " times. Pattern: ["
265+
+ pattern.toString()
266+
+ "]");
267+
}
268+
269+
private void assertMatchesAtMost(Pattern pattern, Integer times, List<ILoggingEvent> logs) {
270+
int occurences = 0;
271+
for (ILoggingEvent log : logs) {
272+
if (pattern.matcher(log.getFormattedMessage()).matches()) {
273+
occurences++;
274+
if (occurences > times) {
275+
fail(
276+
"Expected to find pattern at most "
277+
+ times
278+
+ " times but found it "
279+
+ occurences
280+
+ " times. Pattern: ["
281+
+ pattern.toString()
282+
+ "]");
283+
}
284+
}
285+
}
286+
}
287+
}

0 commit comments

Comments
 (0)