SNOW-858328 Add cache for schema evolution permission query (snowflak…
sfc-gh-alhuang authored and khsoneji committed Oct 12, 2023
1 parent b9b0c61 commit abfa7dc
Showing 10 changed files with 211 additions and 20 deletions.
@@ -258,8 +258,7 @@ public void stop() {
@Override
public void open(final Collection<TopicPartition> partitions) {
long startTime = System.currentTimeMillis();
partitions.forEach(
tp -> this.sink.startTask(Utils.tableName(tp.topic(), this.topic2table), tp));
this.sink.startPartitions(partitions, this.topic2table);
this.DYNAMIC_LOGGER.info(
"task opened with {} partitions, execution time: {} milliseconds",
partitions.size(),
@@ -20,7 +20,16 @@ public interface SnowflakeSinkService {
* @param tableName destination table name
* @param topicPartition TopicPartition passed from Kafka
*/
void startTask(String tableName, TopicPartition topicPartition);
void startPartition(String tableName, TopicPartition topicPartition);

/**
* Start a collection of TopicPartitions. This should handle any configuration parsing and one-time
* setup of the task.
*
* @param partitions collection of topic partitions
* @param topic2Table a mapping from topic to table
*/
void startPartitions(Collection<TopicPartition> partitions, Map<String, String> topic2Table);

/**
call pipe to insert a collection of JSON records; will trigger time-based flush
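As a usage sketch: with the batch method in place, a sink task can hand its whole partition assignment to the service in one call. This mirrors the SnowflakeSinkTask.open change in the first hunk above (sink and topic2table are that class's fields; the snippet is illustrative, not the full task class):

@Override
public void open(final Collection<TopicPartition> partitions) {
  // Before: partitions.forEach(tp -> sink.startTask(Utils.tableName(tp.topic(), topic2table), tp));
  // Now a single call lets the service resolve table names and do any per-table work once.
  this.sink.startPartitions(partitions, this.topic2table);
}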
@@ -68,7 +68,7 @@ private SnowflakeSinkServiceBuilder(SnowflakeConnectionService conn) {
* @return Builder instance
*/
public SnowflakeSinkServiceBuilder addTask(String tableName, TopicPartition topicPartition) {
this.service.startTask(tableName, topicPartition);
this.service.startPartition(tableName, topicPartition);
LOGGER.info(
"create new task in {} - table: {}, topicPartition: {}",
SnowflakeSinkService.class.getName(),
@@ -104,7 +104,7 @@ class SnowflakeSinkServiceV1 implements SnowflakeSinkService {
* @param topicPartition TopicPartition passed from Kafka
*/
@Override
public void startTask(final String tableName, final TopicPartition topicPartition) {
public void startPartition(final String tableName, final TopicPartition topicPartition) {
String stageName = Utils.stageName(conn.getConnectorName(), tableName);
String nameIndex = getNameIndex(topicPartition.topic(), topicPartition.partition());
if (pipes.containsKey(nameIndex)) {
@@ -119,6 +119,12 @@ public void startTask(final String tableName, final TopicPartition topicPartition) {
}
}

@Override
public void startPartitions(
Collection<TopicPartition> partitions, Map<String, String> topic2Table) {
partitions.forEach(tp -> this.startPartition(Utils.tableName(tp.topic(), topic2Table), tp));
}

@Override
public void insert(final Collection<SinkRecord> records) {
// note that records can be empty
@@ -148,7 +154,7 @@ public void insert(SinkRecord record) {
"Topic: {} Partition: {} hasn't been initialized by OPEN " + "function",
record.topic(),
record.kafkaPartition());
startTask(
startPartition(
Utils.tableName(record.topic(), this.topic2TableMap),
new TopicPartition(record.topic(), record.kafkaPartition()));
}
@@ -1,6 +1,7 @@
package com.snowflake.kafka.connector.internal.streaming;

import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.BUFFER_SIZE_BYTES_DEFAULT;
import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_ROLE;
import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.STREAMING_BUFFER_COUNT_RECORDS_DEFAULT;
import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.STREAMING_BUFFER_FLUSH_TIME_DEFAULT_SEC;
import static com.snowflake.kafka.connector.internal.streaming.TopicPartitionChannel.NO_OFFSET_TOKEN_REGISTERED_IN_SNOWFLAKE;
@@ -94,6 +95,9 @@ public class SnowflakeSinkServiceV2 implements SnowflakeSinkService {
*/
private final Map<String, TopicPartitionChannel> partitionsToChannel;

// Cache of the schema evolution permission for each table
private final Map<String, Boolean> tableName2SchemaEvolutionPermission;

public SnowflakeSinkServiceV2(
SnowflakeConnectionService conn, Map<String, String> connectorConfig) {
if (conn == null || conn.isClosed()) {
@@ -122,6 +126,8 @@ public SnowflakeSinkServiceV2(
.getClient(this.connectorConfig);

this.partitionsToChannel = new HashMap<>();

this.tableName2SchemaEvolutionPermission = new HashMap<>();
}

@VisibleForTesting
@@ -159,6 +165,14 @@ public SnowflakeSinkServiceV2(
.getClient(this.connectorConfig);
this.enableSchematization = enableSchematization;
this.partitionsToChannel = partitionsToChannel;

this.tableName2SchemaEvolutionPermission = new HashMap<>();
if (this.topicToTableMap != null) {
this.topicToTableMap.forEach(
(topic, tableName) -> {
populateSchemaEvolutionPermissions(tableName);
});
}
}

/**
@@ -171,12 +185,33 @@ public SnowflakeSinkServiceV2(
* @param topicPartition TopicPartition passed from Kafka
*/
@Override
public void startTask(String tableName, TopicPartition topicPartition) {
public void startPartition(String tableName, TopicPartition topicPartition) {
// the table should be present before opening a channel so let's do a table existence check here
createTableIfNotExists(tableName);

// Create channel for the given partition
createStreamingChannelForTopicPartition(tableName, topicPartition);
createStreamingChannelForTopicPartition(
tableName, topicPartition, tableName2SchemaEvolutionPermission.get(tableName));
}

/**
* Initializes multiple channels and the partitionsToChannel map with new instances of {@link
* TopicPartitionChannel}
*
* @param partitions collection of topic partitions
* @param topic2Table map of topic to table name
*/
@Override
public void startPartitions(
Collection<TopicPartition> partitions, Map<String, String> topic2Table) {
partitions.forEach(
tp -> {
String tableName = Utils.tableName(tp.topic(), topic2Table);
createTableIfNotExists(tableName);

createStreamingChannelForTopicPartition(
tableName, tp, tableName2SchemaEvolutionPermission.get(tableName));
});
}

/**
@@ -186,7 +221,9 @@ public void startTask(String tableName, TopicPartition topicPartition) {
* presented or not.
*/
private void createStreamingChannelForTopicPartition(
final String tableName, final TopicPartition topicPartition) {
final String tableName,
final TopicPartition topicPartition,
boolean hasSchemaEvolutionPermission) {
final String partitionChannelKey =
partitionChannelKey(topicPartition.topic(), topicPartition.partition());
// Create new instance of TopicPartitionChannel which will always open the channel.
@@ -197,6 +234,7 @@ private void createStreamingChannelForTopicPartition(
topicPartition,
partitionChannelKey, // Streaming channel name
tableName,
hasSchemaEvolutionPermission,
new StreamingBufferThreshold(this.flushTimeSeconds, this.fileSizeBytes, this.recordNum),
this.connectorConfig,
this.kafkaRecordErrorReporter,
@@ -252,7 +290,7 @@ public void insert(SinkRecord record) {
"Topic: {} Partition: {} hasn't been initialized by OPEN function",
record.topic(),
record.kafkaPartition());
startTask(
startPartition(
Utils.tableName(record.topic(), this.topicToTableMap),
new TopicPartition(record.topic(), record.kafkaPartition()));
}
@@ -512,5 +550,22 @@ private void createTableIfNotExists(final String tableName) {
this.conn.createTable(tableName);
}
}

// Populate schema evolution cache if needed
populateSchemaEvolutionPermissions(tableName);
}

private void populateSchemaEvolutionPermissions(String tableName) {
if (!tableName2SchemaEvolutionPermission.containsKey(tableName)) {
if (enableSchematization) {
tableName2SchemaEvolutionPermission.put(
tableName,
conn != null
&& conn.hasSchemaEvolutionPermission(
tableName, connectorConfig.get(SNOWFLAKE_ROLE)));
} else {
tableName2SchemaEvolutionPermission.put(tableName, false);
}
}
}
}
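The tableName2SchemaEvolutionPermission map above is what gives the commit its name: the hasSchemaEvolutionPermission query now runs at most once per table, instead of once per channel as it did when TopicPartitionChannel issued it in its constructor. A minimal standalone sketch of the same memoization pattern (the class name and the permissionQuery parameter are illustrative stand-ins, not the connector's API; the real code uses an explicit containsKey/put check rather than computeIfAbsent):

import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;

final class SchemaEvolutionPermissionCache {
  private final Map<String, Boolean> permissionByTable = new HashMap<>();
  private final boolean enableSchematization;

  SchemaEvolutionPermissionCache(boolean enableSchematization) {
    this.enableSchematization = enableSchematization;
  }

  /** Issues the (potentially expensive) permission query at most once per table name. */
  boolean hasPermission(String tableName, Predicate<String> permissionQuery) {
    return permissionByTable.computeIfAbsent(
        tableName, t -> enableSchematization && permissionQuery.test(t));
  }
}

Opening several partitions that all map to the same table therefore results in a single permission query rather than one per partition.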
@@ -2,7 +2,6 @@

import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG;
import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ERRORS_TOLERANCE_CONFIG;
import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_ROLE;
import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.DURATION_BETWEEN_GET_OFFSET_TOKEN_RETRY;
import static com.snowflake.kafka.connector.internal.streaming.StreamingUtils.MAX_GET_OFFSET_TOKEN_RETRIES;
import static java.time.temporal.ChronoUnit.SECONDS;
@@ -194,6 +193,7 @@ public TopicPartitionChannel(
topicPartition,
channelName,
tableName,
false, /* No schema evolution permission */
streamingBufferThreshold,
sfConnectorConfig,
kafkaRecordErrorReporter,
@@ -209,6 +209,8 @@
* (TopicPartitionChannel)
* @param channelName channel Name which is deterministic for topic and partition
* @param tableName table to ingest in snowflake
* @param hasSchemaEvolutionPermission whether the role has permission to perform schema evolution on
* the table
* @param streamingBufferThreshold bytes, count of records and flush time thresholds.
* @param sfConnectorConfig configuration set for snowflake connector
* @param kafkaRecordErrorReporter kafka error reporter for sending records to DLQ
@@ -223,6 +225,7 @@ public TopicPartitionChannel(
TopicPartition topicPartition,
final String channelName,
final String tableName,
boolean hasSchemaEvolutionPermission,
final BufferThreshold streamingBufferThreshold,
final Map<String, String> sfConnectorConfig,
KafkaRecordErrorReporter kafkaRecordErrorReporter,
@@ -257,11 +260,8 @@ public TopicPartitionChannel(
/* Schematization related properties */
this.enableSchematization =
this.recordService.setAndGetEnableSchematizationFromConfig(sfConnectorConfig);
this.enableSchemaEvolution =
this.enableSchematization
&& this.conn != null
&& this.conn.hasSchemaEvolutionPermission(
tableName, sfConnectorConfig.get(SNOWFLAKE_ROLE));

this.enableSchemaEvolution = this.enableSchematization && hasSchemaEvolutionPermission;

// Open channel and reset the offset in kafka
this.channel = Preconditions.checkNotNull(openChannelForTable());
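The constructor change above is the other half of the caching: previously every TopicPartitionChannel called conn.hasSchemaEvolutionPermission(tableName, role) itself, while now the already-resolved boolean is injected by the sink service and simply combined with the schematization flag. A reduced sketch of that wiring (a standalone toy class, not the connector's actual constructor, which takes many more parameters):

final class ChannelSchemaEvolutionWiring {
  private final boolean enableSchemaEvolution; // derived once at construction time

  ChannelSchemaEvolutionWiring(boolean enableSchematization, boolean hasSchemaEvolutionPermission) {
    // Old behavior: the permission query ran here for every channel.
    // New behavior: the caller passes in the answer it already cached per table.
    this.enableSchemaEvolution = enableSchematization && hasSchemaEvolutionPermission;
  }

  boolean schemaEvolutionEnabled() {
    return enableSchemaEvolution;
  }
}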
@@ -312,16 +312,15 @@ private static void testTopicToTableRegexRunner(
sinkTask.open(testPartitions);

// verify expected num tasks opened
Mockito.verify(serviceSpy, Mockito.times(expectedTopic2TableConfig.size()))
.startTask(Mockito.anyString(), Mockito.any(TopicPartition.class));
Mockito.verify(serviceSpy, Mockito.times(1))
.startPartitions(Mockito.anyCollection(), Mockito.anyMap());

for (String topicStr : expectedTopic2TableConfig.keySet()) {
TopicPartition topic = null;
String table = expectedTopic2TableConfig.get(topicStr);
for (TopicPartition currTp : testPartitions) {
if (currTp.topic().equals(topicStr)) {
topic = currTp;
Mockito.verify(serviceSpy, Mockito.times(1)).startTask(table, topic);
}
}
Assert.assertNotNull("Expected topic partition was not opened by the task", topic);
@@ -888,7 +888,7 @@ public void testSinkServiceNegative() {
SchemaAndValue input =
converter.toConnectData(topic, "{\"name\":\"test\"}".getBytes(StandardCharsets.UTF_8));
service.insert(new SinkRecord(topic, partition, null, null, input.schema(), input.value(), 0));
service.startTask(table, new TopicPartition(topic, partition));
service.startPartition(table, new TopicPartition(topic, partition));
}

@Test
@@ -369,6 +369,126 @@ public void testStreamingIngest_multipleChannelPartitions() throws Exception {
service.closeAll();
}

@Test
public void testStreamingIngest_multipleChannelPartitionsWithTopic2Table() throws Exception {
final int partitionCount = 3;
final int recordsInEachPartition = 2;
final int topicCount = 3;

Map<String, String> config = TestUtils.getConfForStreaming();
SnowflakeSinkConnectorConfig.setDefaultValues(config);

ArrayList<String> topics = new ArrayList<>();
for (int topic = 0; topic < topicCount; topic++) {
topics.add(TestUtils.randomTableName());
}

// only insert the first topic into topic2Table
Map<String, String> topic2Table = new HashMap<>();
topic2Table.put(topics.get(0), table);

SnowflakeSinkService service =
SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config)
.setRecordNumber(5)
.setFlushTime(5)
.setErrorReporter(new InMemoryKafkaRecordErrorReporter())
.setTopic2TableMap(topic2Table)
.setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition)))
.build();

for (int topic = 0; topic < topicCount; topic++) {
for (int partition = 0; partition < partitionCount; partition++) {
service.startPartition(topics.get(topic), new TopicPartition(topics.get(topic), partition));
}

List<SinkRecord> records = new ArrayList<>();
for (int partition = 0; partition < partitionCount; partition++) {
records.addAll(
TestUtils.createJsonStringSinkRecords(
0, recordsInEachPartition, topics.get(topic), partition));
}

service.insert(records);
}

for (int topic = 0; topic < topicCount; topic++) {
int finalTopic = topic;
TestUtils.assertWithRetry(
() -> {
service.insert(new ArrayList<>()); // trigger time based flush
return TestUtils.tableSize(topics.get(finalTopic))
== recordsInEachPartition * partitionCount;
},
10,
20);

for (int partition = 0; partition < partitionCount; partition++) {
int finalPartition = partition;
TestUtils.assertWithRetry(
() ->
service.getOffset(new TopicPartition(topics.get(finalTopic), finalPartition))
== recordsInEachPartition,
20,
5);
}
}

service.closeAll();
}

@Test
public void testStreamingIngest_startPartitionsWithMultipleChannelPartitions() throws Exception {
final int partitionCount = 5;
final int recordsInEachPartition = 2;

Map<String, String> config = TestUtils.getConfForStreaming();
SnowflakeSinkConnectorConfig.setDefaultValues(config);

SnowflakeSinkService service =
SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config)
.setRecordNumber(5)
.setFlushTime(5)
.setErrorReporter(new InMemoryKafkaRecordErrorReporter())
.setSinkTaskContext(new InMemorySinkTaskContext(Collections.singleton(topicPartition)))
.build();

ArrayList<TopicPartition> topicPartitions = new ArrayList<>();
for (int partition = 0; partition < partitionCount; partition++) {
topicPartitions.add(new TopicPartition(topic, partition));
}
Map<String, String> topic2Table = new HashMap<>();
topic2Table.put(topic, table);
service.startPartitions(topicPartitions, topic2Table);

List<SinkRecord> records = new ArrayList<>();
for (int partition = 0; partition < partitionCount; partition++) {
records.addAll(
TestUtils.createJsonStringSinkRecords(0, recordsInEachPartition, topic, partition));
}

service.insert(records);

TestUtils.assertWithRetry(
() -> {
service.insert(new ArrayList<>()); // trigger time based flush
return TestUtils.tableSize(table) == recordsInEachPartition * partitionCount;
},
10,
20);

for (int partition = 0; partition < partitionCount; partition++) {
int finalPartition = partition;
TestUtils.assertWithRetry(
() ->
service.getOffset(new TopicPartition(topic, finalPartition))
== recordsInEachPartition,
20,
5);
}

service.closeAll();
}

@Test
public void testStreamingIngestion_timeBased() throws Exception {
Map<String, String> config = TestUtils.getConfForStreaming();