
Commit abfa7dc

sfc-gh-alhuang authored and khsoneji committed
SNOW-858328 Add cache for schema evolution permission query (snowflakedb#683)
1 parent b9b0c61 commit abfa7dc

10 files changed: 211 additions, 20 deletions
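
What the commit does, in brief: previously each TopicPartitionChannel issued its own schema evolution permission query against Snowflake when it was constructed, so a table served by many partitions triggered the same query once per partition. This change moves the check up into SnowflakeSinkServiceV2, memoizes the result per table in tableName2SchemaEvolutionPermission, and batches partition startup behind a new startPartitions method. A minimal, hypothetical sketch of the caching pattern follows; SchemaEvolutionPermissionCache and permissionQuery are illustrative names only, with permissionQuery standing in for SnowflakeConnectionService#hasSchemaEvolutionPermission:

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiPredicate;

// Illustrative only: memoizes a per-table permission check the way this commit does.
final class SchemaEvolutionPermissionCache {
  private final Map<String, Boolean> cache = new HashMap<>();
  private final boolean schematizationEnabled;
  private final String role;
  private final BiPredicate<String, String> permissionQuery; // (tableName, role) -> allowed

  SchemaEvolutionPermissionCache(
      boolean schematizationEnabled, String role, BiPredicate<String, String> permissionQuery) {
    this.schematizationEnabled = schematizationEnabled;
    this.role = role;
    this.permissionQuery = permissionQuery;
  }

  // Runs the (potentially expensive) permission query at most once per table name.
  boolean hasPermission(String tableName) {
    return cache.computeIfAbsent(
        tableName, t -> schematizationEnabled && permissionQuery.test(t, role));
  }
}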

src/main/java/com/snowflake/kafka/connector/SnowflakeSinkTask.java

Lines changed: 1 addition & 2 deletions
@@ -258,8 +258,7 @@ public void stop() {
   @Override
   public void open(final Collection<TopicPartition> partitions) {
     long startTime = System.currentTimeMillis();
-    partitions.forEach(
-        tp -> this.sink.startTask(Utils.tableName(tp.topic(), this.topic2table), tp));
+    this.sink.startPartitions(partitions, this.topic2table);
     this.DYNAMIC_LOGGER.info(
         "task opened with {} partitions, execution time: {} milliseconds",
         partitions.size(),

src/main/java/com/snowflake/kafka/connector/internal/SnowflakeSinkService.java

Lines changed: 10 additions & 1 deletion
@@ -20,7 +20,16 @@ public interface SnowflakeSinkService {
    * @param tableName destination table name
    * @param topicPartition TopicPartition passed from Kafka
    */
-  void startTask(String tableName, TopicPartition topicPartition);
+  void startPartition(String tableName, TopicPartition topicPartition);
+
+  /**
+   * Start a collection of TopicPartition. This should handle any configuration parsing and one-time
+   * setup of the task.
+   *
+   * @param partitions collection of topic partitions
+   * @param topic2Table a mapping from topic to table
+   */
+  void startPartitions(Collection<TopicPartition> partitions, Map<String, String> topic2Table);
 
   /**
    * call pipe to insert a collections of JSON records will trigger time based flush
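
A hedged usage sketch of the widened interface, showing the old per-partition loop next to the new single call. PartitionOpener is a made-up caller, not part of the commit; SnowflakeSinkService, Utils, and TopicPartition are the real types used in the hunks on this page, with import paths assumed from the file layout above:

import com.snowflake.kafka.connector.Utils;
import com.snowflake.kafka.connector.internal.SnowflakeSinkService;
import java.util.Collection;
import java.util.Map;
import org.apache.kafka.common.TopicPartition;

// Hypothetical caller, not part of the commit.
final class PartitionOpener {
  private final SnowflakeSinkService sink;
  private final Map<String, String> topic2table;

  PartitionOpener(SnowflakeSinkService sink, Map<String, String> topic2table) {
    this.sink = sink;
    this.topic2table = topic2table;
  }

  void openOldStyle(Collection<TopicPartition> partitions) {
    // Pre-change pattern (the method was called startTask): the caller resolves
    // each table name and starts partitions one by one.
    partitions.forEach(
        tp -> sink.startPartition(Utils.tableName(tp.topic(), topic2table), tp));
  }

  void openNewStyle(Collection<TopicPartition> partitions) {
    // Post-change pattern: a single call hands the whole assignment plus the
    // topic-to-table mapping to the service.
    sink.startPartitions(partitions, topic2table);
  }
}

Either path ends up at the same cache: SnowflakeSinkServiceV2.createTableIfNotExists calls populateSchemaEvolutionPermissions, which only queries Snowflake the first time a table name is seen.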

src/main/java/com/snowflake/kafka/connector/internal/SnowflakeSinkServiceFactory.java

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ private SnowflakeSinkServiceBuilder(SnowflakeConnectionService conn) {
    * @return Builder instance
    */
   public SnowflakeSinkServiceBuilder addTask(String tableName, TopicPartition topicPartition) {
-    this.service.startTask(tableName, topicPartition);
+    this.service.startPartition(tableName, topicPartition);
     LOGGER.info(
         "create new task in {} - table: {}, topicPartition: {}",
         SnowflakeSinkService.class.getName(),

src/main/java/com/snowflake/kafka/connector/internal/SnowflakeSinkServiceV1.java

Lines changed: 8 additions & 2 deletions
@@ -104,7 +104,7 @@ class SnowflakeSinkServiceV1 implements SnowflakeSinkService {
    * @param topicPartition TopicPartition passed from Kafka
    */
   @Override
-  public void startTask(final String tableName, final TopicPartition topicPartition) {
+  public void startPartition(final String tableName, final TopicPartition topicPartition) {
     String stageName = Utils.stageName(conn.getConnectorName(), tableName);
     String nameIndex = getNameIndex(topicPartition.topic(), topicPartition.partition());
     if (pipes.containsKey(nameIndex)) {
@@ -119,6 +119,12 @@ public void startTask(final String tableName, final TopicPartition topicPartition) {
     }
   }
 
+  @Override
+  public void startPartitions(
+      Collection<TopicPartition> partitions, Map<String, String> topic2Table) {
+    partitions.forEach(tp -> this.startPartition(Utils.tableName(tp.topic(), topic2Table), tp));
+  }
+
   @Override
   public void insert(final Collection<SinkRecord> records) {
     // note that records can be empty
@@ -148,7 +154,7 @@ public void insert(SinkRecord record) {
           "Topic: {} Partition: {} hasn't been initialized by OPEN " + "function",
           record.topic(),
           record.kafkaPartition());
-      startTask(
+      startPartition(
           Utils.tableName(record.topic(), this.topic2TableMap),
           new TopicPartition(record.topic(), record.kafkaPartition()));
     }

src/main/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2.java

Lines changed: 59 additions & 4 deletions
@@ -1,6 +1,7 @@
 package com.snowflake.kafka.connector.internal.streaming;
 
 import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.BUFFER_SIZE_BYTES_DEFAULT;
+import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_ROLE;
 import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.STREAMING_BUFFER_COUNT_RECORDS_DEFAULT;
 import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.STREAMING_BUFFER_FLUSH_TIME_DEFAULT_SEC;
 import static com.snowflake.kafka.connector.internal.streaming.TopicPartitionChannel.NO_OFFSET_TOKEN_REGISTERED_IN_SNOWFLAKE;
@@ -94,6 +95,9 @@ public class SnowflakeSinkServiceV2 implements SnowflakeSinkService {
    */
   private final Map<String, TopicPartitionChannel> partitionsToChannel;
 
+  // Cache for schema evolution
+  private final Map<String, Boolean> tableName2SchemaEvolutionPermission;
+
   public SnowflakeSinkServiceV2(
       SnowflakeConnectionService conn, Map<String, String> connectorConfig) {
     if (conn == null || conn.isClosed()) {
@@ -122,6 +126,8 @@ public SnowflakeSinkServiceV2(
             .getClient(this.connectorConfig);
 
     this.partitionsToChannel = new HashMap<>();
+
+    this.tableName2SchemaEvolutionPermission = new HashMap<>();
   }
 
   @VisibleForTesting
@@ -159,6 +165,14 @@ public SnowflakeSinkServiceV2(
             .getClient(this.connectorConfig);
     this.enableSchematization = enableSchematization;
     this.partitionsToChannel = partitionsToChannel;
+
+    this.tableName2SchemaEvolutionPermission = new HashMap<>();
+    if (this.topicToTableMap != null) {
+      this.topicToTableMap.forEach(
+          (topic, tableName) -> {
+            populateSchemaEvolutionPermissions(tableName);
+          });
+    }
   }
 
   /**
@@ -171,12 +185,33 @@ public SnowflakeSinkServiceV2(
    * @param topicPartition TopicPartition passed from Kafka
    */
   @Override
-  public void startTask(String tableName, TopicPartition topicPartition) {
+  public void startPartition(String tableName, TopicPartition topicPartition) {
     // the table should be present before opening a channel so let's do a table existence check here
     createTableIfNotExists(tableName);
 
     // Create channel for the given partition
-    createStreamingChannelForTopicPartition(tableName, topicPartition);
+    createStreamingChannelForTopicPartition(
+        tableName, topicPartition, tableName2SchemaEvolutionPermission.get(tableName));
+  }
+
+  /**
+   * Initializes multiple Channels and partitionsToChannel maps with new instances of {@link
+   * TopicPartitionChannel}
+   *
+   * @param partitions collection of topic partition
+   * @param topic2Table map of topic to table name
+   */
+  @Override
+  public void startPartitions(
+      Collection<TopicPartition> partitions, Map<String, String> topic2Table) {
+    partitions.forEach(
+        tp -> {
+          String tableName = Utils.tableName(tp.topic(), topic2Table);
+          createTableIfNotExists(tableName);
+
+          createStreamingChannelForTopicPartition(
+              tableName, tp, tableName2SchemaEvolutionPermission.get(tableName));
+        });
   }
 
   /**
@@ -186,7 +221,9 @@ public void startTask(String tableName, TopicPartition topicPartition) {
    * presented or not.
    */
   private void createStreamingChannelForTopicPartition(
-      final String tableName, final TopicPartition topicPartition) {
+      final String tableName,
+      final TopicPartition topicPartition,
+      boolean hasSchemaEvolutionPermission) {
     final String partitionChannelKey =
         partitionChannelKey(topicPartition.topic(), topicPartition.partition());
     // Create new instance of TopicPartitionChannel which will always open the channel.
@@ -197,6 +234,7 @@ private void createStreamingChannelForTopicPartition(
             topicPartition,
             partitionChannelKey, // Streaming channel name
             tableName,
+            hasSchemaEvolutionPermission,
            new StreamingBufferThreshold(this.flushTimeSeconds, this.fileSizeBytes, this.recordNum),
            this.connectorConfig,
            this.kafkaRecordErrorReporter,
@@ -252,7 +290,7 @@ public void insert(SinkRecord record) {
           "Topic: {} Partition: {} hasn't been initialized by OPEN function",
           record.topic(),
           record.kafkaPartition());
-      startTask(
+      startPartition(
           Utils.tableName(record.topic(), this.topicToTableMap),
           new TopicPartition(record.topic(), record.kafkaPartition()));
     }
@@ -512,5 +550,22 @@ private void createTableIfNotExists(final String tableName) {
         this.conn.createTable(tableName);
       }
     }
+
+    // Populate schema evolution cache if needed
+    populateSchemaEvolutionPermissions(tableName);
+  }
+
+  private void populateSchemaEvolutionPermissions(String tableName) {
+    if (!tableName2SchemaEvolutionPermission.containsKey(tableName)) {
+      if (enableSchematization) {
+        tableName2SchemaEvolutionPermission.put(
+            tableName,
+            conn != null
+                && conn.hasSchemaEvolutionPermission(
+                    tableName, connectorConfig.get(SNOWFLAKE_ROLE)));
+      } else {
+        tableName2SchemaEvolutionPermission.put(tableName, false);
+      }
+    }
   }
 }
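
To make the saving concrete, here is a small hypothetical demo that reuses the SchemaEvolutionPermissionCache sketch from the top of this page; the lambda stands in for the round trip that conn.hasSchemaEvolutionPermission would make, and "KAFKA_CONNECTOR_ROLE" and "MY_TABLE" are invented names:

import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo: three partitions of one table trigger a single permission query.
public final class SchemaEvolutionCacheDemo {
  public static void main(String[] args) {
    AtomicInteger queries = new AtomicInteger();
    SchemaEvolutionPermissionCache cache =
        new SchemaEvolutionPermissionCache(
            /* schematizationEnabled= */ true,
            "KAFKA_CONNECTOR_ROLE",
            (table, role) -> {
              queries.incrementAndGet(); // stands in for the Snowflake round trip
              return true;
            });

    for (int partition = 0; partition < 3; partition++) {
      cache.hasPermission("MY_TABLE"); // same table for every partition
    }
    System.out.println("permission queries issued: " + queries.get()); // prints 1
  }
}

In the real code the same effect comes from populateSchemaEvolutionPermissions guarding on containsKey before calling conn.hasSchemaEvolutionPermission.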

src/main/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannel.java

Lines changed: 6 additions & 6 deletions
@@ -2,7 +2,6 @@
 
 import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG;
 import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ERRORS_TOLERANCE_CONFIG;
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_ROLE;
 import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.DURATION_BETWEEN_GET_OFFSET_TOKEN_RETRY;
 import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.MAX_GET_OFFSET_TOKEN_RETRIES;
 import static java.time.temporal.ChronoUnit.SECONDS;
@@ -194,6 +193,7 @@ public TopicPartitionChannel(
         topicPartition,
         channelName,
         tableName,
+        false, /* No schema evolution permission */
         streamingBufferThreshold,
         sfConnectorConfig,
         kafkaRecordErrorReporter,
@@ -209,6 +209,8 @@ public TopicPartitionChannel(
    *     (TopicPartitionChannel)
    * @param channelName channel Name which is deterministic for topic and partition
    * @param tableName table to ingest in snowflake
+   * @param hasSchemaEvolutionPermission if the role has permission to perform schema evolution on
+   *     the table
    * @param streamingBufferThreshold bytes, count of records and flush time thresholds.
    * @param sfConnectorConfig configuration set for snowflake connector
    * @param kafkaRecordErrorReporter kafka errpr reporter for sending records to DLQ
@@ -223,6 +225,7 @@ public TopicPartitionChannel(
       TopicPartition topicPartition,
       final String channelName,
       final String tableName,
+      boolean hasSchemaEvolutionPermission,
       final BufferThreshold streamingBufferThreshold,
       final Map<String, String> sfConnectorConfig,
       KafkaRecordErrorReporter kafkaRecordErrorReporter,
@@ -257,11 +260,8 @@ public TopicPartitionChannel(
     /* Schematization related properties */
     this.enableSchematization =
         this.recordService.setAndGetEnableSchematizationFromConfig(sfConnectorConfig);
-    this.enableSchemaEvolution =
-        this.enableSchematization
-            && this.conn != null
-            && this.conn.hasSchemaEvolutionPermission(
-                tableName, sfConnectorConfig.get(SNOWFLAKE_ROLE));
+
+    this.enableSchemaEvolution = this.enableSchematization && hasSchemaEvolutionPermission;
 
     // Open channel and reset the offset in kafka
     this.channel = Preconditions.checkNotNull(openChannelForTable());

src/test/java/com/snowflake/kafka/connector/SnowflakeSinkTaskForStreamingIT.java

Lines changed: 2 additions & 3 deletions
@@ -312,16 +312,15 @@ private static void testTopicToTableRegexRunner(
     sinkTask.open(testPartitions);
 
     // verify expected num tasks opened
-    Mockito.verify(serviceSpy, Mockito.times(expectedTopic2TableConfig.size()))
-        .startTask(Mockito.anyString(), Mockito.any(TopicPartition.class));
+    Mockito.verify(serviceSpy, Mockito.times(1))
+        .startPartitions(Mockito.anyCollection(), Mockito.anyMap());
 
     for (String topicStr : expectedTopic2TableConfig.keySet()) {
       TopicPartition topic = null;
       String table = expectedTopic2TableConfig.get(topicStr);
       for (TopicPartition currTp : testPartitions) {
         if (currTp.topic().equals(topicStr)) {
           topic = currTp;
-          Mockito.verify(serviceSpy, Mockito.times(1)).startTask(table, topic);
         }
       }
       Assert.assertNotNull("Expected topic partition was not opened by the tast", topic);

src/test/java/com/snowflake/kafka/connector/internal/SinkServiceIT.java

Lines changed: 1 addition & 1 deletion
@@ -888,7 +888,7 @@ public void testSinkServiceNegative() {
     SchemaAndValue input =
         converter.toConnectData(topic, "{\"name\":\"test\"}".getBytes(StandardCharsets.UTF_8));
     service.insert(new SinkRecord(topic, partition, null, null, input.schema(), input.value(), 0));
-    service.startTask(table, new TopicPartition(topic, partition));
+    service.startPartition(table, new TopicPartition(topic, partition));
   }
 
   @Test

src/test/java/com/snowflake/kafka/connector/internal/streaming/SnowflakeSinkServiceV2IT.java

Lines changed: 120 additions & 0 deletions
@@ -369,6 +369,126 @@ public void testStreamingIngest_multipleChannelPartitions() throws Exception {
     service.closeAll();
   }
 
+  @Test
+  public void testStreamingIngest_multipleChannelPartitionsWithTopic2Table() throws Exception {
+    final int partitionCount = 3;
+    final int recordsInEachPartition = 2;
+    final int topicCount = 3;
+
+    Map<String, String> config = TestUtils.getConfForStreaming();
+    SnowflakeSinkConnectorConfig.setDefaultValues(config);
+
+    ArrayList<String> topics = new ArrayList<>();
+    for (int topic = 0; topic < topicCount; topic++) {
+      topics.add(TestUtils.randomTableName());
+    }
+
+    // only insert fist topic to topicTable
+    Map<String, String> topic2Table = new HashMap<>();
+    topic2Table.put(topics.get(0), table);
+
+    SnowflakeSinkService service =
+        SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config)
+            .setRecordNumber(5)
+            .setFlushTime(5)
+            .setErrorReporter(new InMemoryKafkaRecordErrorReporter())
+            .setTopic2TableMap(topic2Table)
+            .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition)))
+            .build();
+
+    for (int topic = 0; topic < topicCount; topic++) {
+      for (int partition = 0; partition < partitionCount; partition++) {
+        service.startPartition(topics.get(topic), new TopicPartition(topics.get(topic), partition));
+      }
+
+      List<SinkRecord> records = new ArrayList<>();
+      for (int partition = 0; partition < partitionCount; partition++) {
+        records.addAll(
+            TestUtils.createJsonStringSinkRecords(
+                0, recordsInEachPartition, topics.get(topic), partition));
+      }
+
+      service.insert(records);
+    }
+
+    for (int topic = 0; topic < topicCount; topic++) {
+      int finalTopic = topic;
+      TestUtils.assertWithRetry(
+          () -> {
+            service.insert(new ArrayList<>()); // trigger time based flush
+            return TestUtils.tableSize(topics.get(finalTopic))
+                == recordsInEachPartition * partitionCount;
+          },
+          10,
+          20);
+
+      for (int partition = 0; partition < partitionCount; partition++) {
+        int finalPartition = partition;
+        TestUtils.assertWithRetry(
+            () ->
+                service.getOffset(new TopicPartition(topics.get(finalTopic), finalPartition))
+                    == recordsInEachPartition,
+            20,
+            5);
+      }
+    }
+
+    service.closeAll();
+  }
+
+  @Test
+  public void testStreamingIngest_startPartitionsWithMultipleChannelPartitions() throws Exception {
+    final int partitionCount = 5;
+    final int recordsInEachPartition = 2;
+
+    Map<String, String> config = TestUtils.getConfForStreaming();
+    SnowflakeSinkConnectorConfig.setDefaultValues(config);
+
+    SnowflakeSinkService service =
+        SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config)
+            .setRecordNumber(5)
+            .setFlushTime(5)
+            .setErrorReporter(new InMemoryKafkaRecordErrorReporter())
+            .setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition)))
+            .build();
+
+    ArrayList<TopicPartition> topicPartitions = new ArrayList<>();
+    for (int partition = 0; partition < partitionCount; partition++) {
+      topicPartitions.add(new TopicPartition(topic, partition));
+    }
+    Map<String, String> topic2Table = new HashMap<>();
+    topic2Table.put(topic, table);
+    service.startPartitions(topicPartitions, topic2Table);
+
+    List<SinkRecord> records = new ArrayList<>();
+    for (int partition = 0; partition < partitionCount; partition++) {
+      records.addAll(
+          TestUtils.createJsonStringSinkRecords(0, recordsInEachPartition, topic, partition));
+    }
+
+    service.insert(records);
+
+    TestUtils.assertWithRetry(
+        () -> {
+          service.insert(new ArrayList<>()); // trigger time based flush
+          return TestUtils.tableSize(table) == recordsInEachPartition * partitionCount;
+        },
+        10,
+        20);
+
+    for (int partition = 0; partition < partitionCount; partition++) {
+      int finalPartition = partition;
+      TestUtils.assertWithRetry(
+          () ->
+              service.getOffset(new TopicPartition(topic, finalPartition))
                  == recordsInEachPartition,
+          20,
+          5);
+    }
+
+    service.closeAll();
+  }
+
   @Test
   public void testStreamingIngestion_timeBased() throws Exception {
     Map<String, String> config = TestUtils.getConfForStreaming();
