diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannel.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannel.java
index d6b2b8213..29b3ac10c 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannel.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannel.java
@@ -507,8 +507,6 @@ public InsertRowsResponse insertRecords(StreamingBuffer streamingBufferToInsert)
     InsertRowsResponse response = null;
     try {
       response = insertRowsWithFallback(streamingBufferToInsert);
-      // Updates the flush time (last time we called insertRows API)
-      this.previousFlushTimeStampMs = System.currentTimeMillis();

       LOGGER.info(
           "Successfully called insertRows for channel:{}, buffer:{}, insertResponseHasErrors:{},"
@@ -517,10 +515,6 @@ public InsertRowsResponse insertRecords(StreamingBuffer streamingBufferToInsert)
           streamingBufferToInsert,
           response.hasErrors(),
           response.needToResetOffset());
-      if (response.hasErrors()) {
-        handleInsertRowsFailures(
-            response.getInsertErrors(), streamingBufferToInsert.getSinkRecords());
-      }

       // Due to schema evolution, we may need to reopen the channel and reset the offset in kafka
       // since it's possible that not all rows are ingested
@@ -528,6 +522,18 @@ public InsertRowsResponse insertRecords(StreamingBuffer streamingBufferToInsert)
         streamingApiFallbackSupplier(
             StreamingApiFallbackInvoker.INSERT_ROWS_SCHEMA_EVOLUTION_FALLBACK);
       }
+      // If there are errors other than schema mismatch, we need to handle them and reinsert the
+      // good rows
+      if (response.hasErrors()) {
+        handleInsertRowsFailures(
+            response.getInsertErrors(), streamingBufferToInsert.getSinkRecords());
+        insertRecords(
+            rebuildBufferWithoutErrorRows(streamingBufferToInsert, response.getInsertErrors()));
+      }
+
+      // Updates the flush time (last time we successfully insert some rows)
+      this.previousFlushTimeStampMs = System.currentTimeMillis();
+
       return response;
     } catch (TopicPartitionChannelInsertionException ex) {
       // Suppressing the exception because other channels might still continue to ingest
@@ -540,6 +546,22 @@ public InsertRowsResponse insertRecords(StreamingBuffer streamingBufferToInsert)
     return response;
   }

+  /** Building a new buffer which contains only the good rows from the original buffer */
+  private StreamingBuffer rebuildBufferWithoutErrorRows(
+      StreamingBuffer streamingBufferToInsert,
+      List<InsertValidationResponse.InsertError> insertErrors) {
+    StreamingBuffer buffer = new StreamingBuffer();
+    int errorIdx = 0;
+    for (long rowIdx = 0; rowIdx < streamingBufferToInsert.getNumOfRecords(); rowIdx++) {
+      if (errorIdx < insertErrors.size() && rowIdx == insertErrors.get(errorIdx).getRowIndex()) {
+        errorIdx++;
+      } else {
+        buffer.insert(streamingBufferToInsert.getSinkRecord(rowIdx));
+      }
+    }
+    return buffer;
+  }
+
   /**
    * Uses {@link Fallback} API to reopen the channel if insertRows throws {@link SFException}.
    *
@@ -620,65 +642,40 @@ public InsertRowsResponse get() throws Throwable {
           this.insertRowsStreamingBuffer);
       Pair<List<Map<String, Object>>, List<Long>> recordsAndOffsets =
           this.insertRowsStreamingBuffer.getData();
-      List<Map<String, Object>> records = recordsAndOffsets.getKey();
-      List<Long> offsets = recordsAndOffsets.getValue();
       InsertValidationResponse finalResponse = new InsertValidationResponse();
+      List<Map<String, Object>> records = recordsAndOffsets.getKey();
       boolean needToResetOffset = false;
+      InsertValidationResponse response =
+          this.channel.insertRows(
+              records,
+              Long.toString(this.insertRowsStreamingBuffer.getFirstOffset()),
+              Long.toString(this.insertRowsStreamingBuffer.getLastOffset()));
       if (!enableSchemaEvolution) {
-        finalResponse =
-            this.channel.insertRows(
-                records,
-                Long.toString(this.insertRowsStreamingBuffer.getFirstOffset()),
-                Long.toString(this.insertRowsStreamingBuffer.getLastOffset()));
+        finalResponse = response;
       } else {
-        for (int idx = 0; idx < records.size(); idx++) {
-          // For schema evolution, we need to call the insertRows API row by row in order to
-          // preserve the original order, for anything after the first schema mismatch error we will
-          // retry after the evolution
-          InsertValidationResponse response =
-              this.channel.insertRow(records.get(idx), Long.toString(offsets.get(idx)));
-          if (response.hasErrors()) {
-            InsertValidationResponse.InsertError insertError = response.getInsertErrors().get(0);
-            List<String> extraColNames = insertError.getExtraColNames();
-
-            List<String> missingNotNullColNames = insertError.getMissingNotNullColNames();
-            Set<String> nonNullableColumns =
-                new HashSet<>(
-                    missingNotNullColNames != null
-                        ? missingNotNullColNames
-                        : Collections.emptySet());
-
-            List<String> nullValueForNotNullColNames = insertError.getNullValueForNotNullColNames();
-            nonNullableColumns.addAll(
-                nullValueForNotNullColNames != null
-                    ? nullValueForNotNullColNames
-                    : Collections.emptySet());
-
-            long originalSinkRecordIdx =
-                offsets.get(idx) - this.insertRowsStreamingBuffer.getFirstOffset();
-
-            if (extraColNames == null && nonNullableColumns.isEmpty()) {
-              InsertValidationResponse.InsertError newInsertError =
-                  new InsertValidationResponse.InsertError(
-                      insertError.getRowContent(), originalSinkRecordIdx);
-              newInsertError.setException(insertError.getException());
-              newInsertError.setExtraColNames(insertError.getExtraColNames());
-              newInsertError.setMissingNotNullColNames(insertError.getMissingNotNullColNames());
-              newInsertError.setNullValueForNotNullColNames(
-                  insertError.getNullValueForNotNullColNames());
-              // Simply added to the final response if it's not schema related errors
-              finalResponse.addError(insertError);
-            } else {
-              SchematizationUtils.evolveSchemaIfNeeded(
-                  this.conn,
-                  this.channel.getTableName(),
-                  new ArrayList<>(nonNullableColumns),
-                  extraColNames,
-                  this.insertRowsStreamingBuffer.getSinkRecord(originalSinkRecordIdx));
-              // Offset reset needed since it's possible that we successfully ingested partial batch
-              needToResetOffset = true;
-              break;
-            }
+        for (InsertValidationResponse.InsertError insertError : response.getInsertErrors()) {
+          List<String> extraColNames = insertError.getExtraColNames();
+          List<String> missingNotNullColNames = insertError.getMissingNotNullColNames();
+          Set<String> nonNullableColumns =
+              new HashSet<>(
+                  missingNotNullColNames != null ? missingNotNullColNames : Collections.emptySet());
+
+          List<String> nullValueForNotNullColNames = insertError.getNullValueForNotNullColNames();
+          nonNullableColumns.addAll(
+              nullValueForNotNullColNames != null
+                  ? nullValueForNotNullColNames
+                  : Collections.emptySet());
+          if (extraColNames != null || !nonNullableColumns.isEmpty()) {
+            SchematizationUtils.evolveSchemaIfNeeded(
+                this.conn,
+                this.channel.getTableName(),
+                new ArrayList<>(nonNullableColumns),
+                extraColNames,
+                this.insertRowsStreamingBuffer.getSinkRecord(insertError.getRowIndex()));
+            needToResetOffset = true;
+          } else {
+            // Simply added to the final response if it's not schema related errors
+            finalResponse.addError(insertError);
           }
         }
       }
@@ -998,7 +995,7 @@ private SnowflakeStreamingIngestChannel openChannelForTable() {
         .setDBName(this.sfConnectorConfig.get(Utils.SF_DATABASE))
         .setSchemaName(this.sfConnectorConfig.get(Utils.SF_SCHEMA))
         .setTableName(this.tableName)
-        .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE)
+        .setOnErrorOption(OpenChannelRequest.OnErrorOption.SKIP_BATCH)
         .setOffsetTokenVerificationFunction(StreamingUtils.offsetTokenVerificationFunction)
         .build();
     LOGGER.info(
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannelTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannelTest.java
index 479b7deb8..9b3ffbf6e 100644
--- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannelTest.java
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/BufferedTopicPartitionChannelTest.java
@@ -10,10 +10,12 @@
 import com.snowflake.kafka.connector.internal.BufferThreshold;
 import com.snowflake.kafka.connector.internal.SnowflakeConnectionService;
 import com.snowflake.kafka.connector.internal.TestUtils;
+import com.snowflake.kafka.connector.internal.streaming.channel.TopicPartitionChannel;
 import com.snowflake.kafka.connector.internal.telemetry.SnowflakeTelemetryService;
 import com.snowflake.kafka.connector.records.RecordService;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -27,6 +29,7 @@
 import org.apache.kafka.connect.errors.DataException;
 import org.apache.kafka.connect.sink.SinkRecord;
 import org.apache.kafka.connect.sink.SinkTaskContext;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -326,4 +329,107 @@ public void testInsertRows_ValidationResponseHasErrors_ErrorTolerance_ALL() thro

     assert kafkaRecordErrorReporter.getReportedRecords().size() == 1;
   }
+
+  @Test
+  public void testInsertRowsWithSchemaEvolution() throws Exception {
+    if (this.sfConnectorConfig
+        .get(SnowflakeSinkConnectorConfig.ENABLE_SCHEMATIZATION_CONFIG)
+        .equals("true")) {
+      InsertValidationResponse notSchemaEvolutionErrorResponse = new InsertValidationResponse();
+      InsertValidationResponse.InsertError notSchemaEvolutionError =
+          new InsertValidationResponse.InsertError("CONTENT", 0);
+      notSchemaEvolutionError.setException(SF_EXCEPTION);
+      notSchemaEvolutionErrorResponse.addError(notSchemaEvolutionError);
+
+      InsertValidationResponse schemaEvolutionRecoverableErrorResponse =
+          new InsertValidationResponse();
+      InsertValidationResponse.InsertError schemaEvolutionRecoverableError =
+          new InsertValidationResponse.InsertError("CONTENT", 0);
+      schemaEvolutionRecoverableError.setException(SF_EXCEPTION);
+      schemaEvolutionRecoverableError.setExtraColNames(Collections.singletonList("gender"));
+      schemaEvolutionRecoverableErrorResponse.addError(schemaEvolutionRecoverableError);
+
+      Mockito.when(
+              mockStreamingChannel.insertRows(
+                  ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any()))
+          .thenReturn(schemaEvolutionRecoverableErrorResponse)
+          .thenReturn(notSchemaEvolutionErrorResponse)
+          .thenReturn(new InsertValidationResponse()); // last insert with correct batch
+
+      Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenReturn("0");
+
+      SnowflakeConnectionService conn = Mockito.mock(SnowflakeConnectionService.class);
+      Mockito.when(
+              conn.hasSchemaEvolutionPermission(ArgumentMatchers.any(), ArgumentMatchers.any()))
+          .thenReturn(true);
+      Mockito.doNothing()
+          .when(conn)
+          .appendColumnsToTable(ArgumentMatchers.any(), ArgumentMatchers.any());
+
+      long bufferFlushTimeSeconds = 5L;
+      StreamingBufferThreshold bufferThreshold =
+          new StreamingBufferThreshold(bufferFlushTimeSeconds, 1_000 /* < 1KB */, 10000000L);
+
+      Map<String, String> sfConnectorConfigWithErrors = new HashMap<>(sfConnectorConfig);
+      sfConnectorConfigWithErrors.put(
+          ERRORS_TOLERANCE_CONFIG, SnowflakeSinkConnectorConfig.ErrorTolerance.ALL.toString());
+      sfConnectorConfigWithErrors.put(ERRORS_DEAD_LETTER_QUEUE_TOPIC_NAME_CONFIG, "test_DLQ");
+      InMemoryKafkaRecordErrorReporter kafkaRecordErrorReporter =
+          new InMemoryKafkaRecordErrorReporter();
+
+      TopicPartitionChannel topicPartitionChannel =
+          new BufferedTopicPartitionChannel(
+              mockStreamingClient,
+              topicPartition,
+              TEST_CHANNEL_NAME,
+              TEST_TABLE_NAME,
+              this.enableSchematization,
+              bufferThreshold,
+              sfConnectorConfigWithErrors,
+              kafkaRecordErrorReporter,
+              mockSinkTaskContext,
+              conn,
+              new RecordService(),
+              mockTelemetryService,
+              false,
+              null);
+
+      final int noOfRecords = 3;
+      List<SinkRecord> records =
+          TestUtils.createNativeJsonSinkRecords(0, noOfRecords, TOPIC, PARTITION);
+
+      for (int idx = 0; idx < records.size(); idx++) {
+        topicPartitionChannel.insertRecord(records.get(idx), idx == 0);
+      }
+
+      // In an ideal world, put API is going to invoke this to check if flush time threshold has
+      // reached.
+      // We are mimicking that call.
+      // Will wait for 10 seconds.
+      Thread.sleep(bufferFlushTimeSeconds * 1000 + 10);
+
+      topicPartitionChannel.insertBufferedRecordsIfFlushTimeThresholdReached();
+
+      // Verify that the buffer is cleaned up and nothing is in DLQ because of schematization error
+      Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty());
+      Assert.assertEquals(0, kafkaRecordErrorReporter.getReportedRecords().size());
+
+      // Do it again without any schematization error, and we should have row in DLQ
+      for (int idx = 0; idx < records.size(); idx++) {
+        topicPartitionChannel.insertRecord(records.get(idx), idx == 0);
+      }
+
+      // In an ideal world, put API is going to invoke this to check if flush time threshold has
+      // reached.
+      // We are mimicking that call.
+      // Will wait for 10 seconds.
+      Thread.sleep(bufferFlushTimeSeconds * 1000 + 10);
+
+      topicPartitionChannel.insertBufferedRecordsIfFlushTimeThresholdReached();
+
+      // Verify that the buffer is cleaned up and one record is in the DLQ
+      Assert.assertTrue(topicPartitionChannel.isPartitionBufferEmpty());
+      Assert.assertEquals(1, kafkaRecordErrorReporter.getReportedRecords().size());
+    }
+  }
 }
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java
index 9ab661530..7a6a58ee0 100644
--- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/TopicPartitionChannelTest.java
@@ -806,10 +806,11 @@ public void testInsertRows_SuccessAfterReopenChannel() throws Exception {
   }

   @Test
-  public void testInsertRowsWithSchemaEvolution() throws Exception {
+  public void testInsertRowsWithSchemaEvolution_onlySingleBuffer() throws Exception {
     if (this.sfConnectorConfig
-        .get(SnowflakeSinkConnectorConfig.ENABLE_SCHEMATIZATION_CONFIG)
-        .equals("true")) {
+            .get(SnowflakeSinkConnectorConfig.ENABLE_SCHEMATIZATION_CONFIG)
+            .equals("true")
+        && !useDoubleBuffer) {
       InsertValidationResponse validationResponse1 = new InsertValidationResponse();
       InsertValidationResponse.InsertError insertError1 =
           new InsertValidationResponse.InsertError("CONTENT", 0);
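
Reviewer sketch of the new retry path: with the channel now opened using ON_ERROR=SKIP_BATCH, a batch containing any bad row is rejected as a whole, so the connector reports the failed rows to the DLQ via handleInsertRowsFailures and then re-inserts only the surviving rows through rebuildBufferWithoutErrorRows. The snippet below is an illustrative, self-contained sketch of that row-filtering step only; Row, InsertError, and withoutErrorRows are simplified stand-ins for SinkRecord, InsertValidationResponse.InsertError, and the buffer rebuild, and it assumes (as the production code does) that the reported errors are ordered by ascending row index.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Illustrative sketch only, not the connector's real types. Mirrors the single-pass filtering
 * idea of rebuildBufferWithoutErrorRows: walk the rejected batch once and drop every row index
 * that came back in the insert errors.
 */
public class SkipBatchRetrySketch {

  record Row(int rowIndex, String payload) {}

  record InsertError(int rowIndex, String reason) {}

  // Keep every row whose index was not reported as failed; failed rows are assumed to have been
  // routed to the DLQ already, so they are simply skipped here rather than retried.
  static List<Row> withoutErrorRows(List<Row> batch, List<InsertError> insertErrors) {
    List<Row> good = new ArrayList<>();
    int errorIdx = 0;
    for (int rowIdx = 0; rowIdx < batch.size(); rowIdx++) {
      if (errorIdx < insertErrors.size() && rowIdx == insertErrors.get(errorIdx).rowIndex()) {
        errorIdx++; // this row failed validation; do not re-insert it
      } else {
        good.add(batch.get(rowIdx));
      }
    }
    return good;
  }

  public static void main(String[] args) {
    List<Row> batch = List.of(new Row(0, "ok"), new Row(1, "bad"), new Row(2, "ok"));
    List<InsertError> errors = List.of(new InsertError(1, "extra column \"gender\""));
    // With ON_ERROR=SKIP_BATCH the whole batch was rejected, so only rows 0 and 2 are re-inserted.
    System.out.println(withoutErrorRows(batch, errors)); // rows with index 0 and 2
  }
}
```

The retry is what the switch from CONTINUE makes necessary: under CONTINUE the SDK ingests the good rows and only reports the bad ones, whereas SKIP_BATCH rejects the entire batch whenever validation fails, so the connector must re-submit the good rows itself; the recursive insertRecords call should terminate because the rebuilt buffer no longer contains the rows that failed.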