Skip to content

Commit b0a4f0d

Browse files
authored
[vpj][test] Fix Spark Catalyst Row Construction for Spark Raw PubSub Source (#2300)
Introduce a correctness fix in SparkPubSubInputPartitionReader to align internal row construction with Catalyst expectations. Region is now written as UTF8String instead of a Java String, and replication metadata fields are reordered to match RAW_PUBSUB_INPUT_TABLE_SCHEMA. This prevents Spark from casting String to UTF8String at runtime, which previously triggered a ClassCastException during projection and codegen. Row construction now properly matches Spark internal types, avoiding downstream DataFrame failures.
1 parent f2b82f6 commit b0a4f0d

File tree

3 files changed

+82
-24
lines changed

3 files changed

+82
-24
lines changed

clients/venice-push-job/src/main/java/com/linkedin/venice/spark/input/pubsub/SparkPubSubInputFormat.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,9 @@ public PartitionReaderFactory createReaderFactory() {
6262
public StructType readSchema() {
6363
return null;
6464
}
65+
66+
@Override
67+
public Batch toBatch() {
68+
return this;
69+
}
6570
}

clients/venice-push-job/src/main/java/com/linkedin/venice/spark/input/pubsub/SparkPubSubInputPartitionReader.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.spark.sql.catalyst.InternalRow;
2222
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
2323
import org.apache.spark.sql.connector.read.PartitionReader;
24+
import org.apache.spark.unsafe.types.UTF8String;
2425

2526

2627
/**
@@ -105,6 +106,7 @@ public boolean next() throws IOException {
105106
Delete delete = (Delete) pubSubMessageValue.getPayloadUnion();
106107
schemaId = delete.getSchemaId();
107108
value = EMPTY_BYTE_BUFFER;
109+
108110
replicationMetadataPayload = delete.getReplicationMetadataPayload();
109111
replicationMetadataVersionId = delete.getReplicationMetadataVersionId();
110112
break;
@@ -116,11 +118,13 @@ public boolean next() throws IOException {
116118

117119
/**
118120
* See {@link com.linkedin.venice.spark.SparkConstants#RAW_PUBSUB_INPUT_TABLE_SCHEMA} for the schema definition.
121+
* Enforce the region to be UTF8String for Spark compatibility and additionally handle ordering of columns per
122+
* the schema.
119123
*/
120124
currentRow = new GenericInternalRow(
121-
new Object[] { region, topicPartition.getPartitionNumber(), messageType, rec.getOffset(), schemaId,
122-
ByteUtils.extractByteArray(key), ByteUtils.extractByteArray(value),
123-
ByteUtils.extractByteArray(replicationMetadataPayload), replicationMetadataVersionId });
125+
new Object[] { UTF8String.fromString(region), topicPartition.getPartitionNumber(), rec.getOffset(), messageType,
126+
schemaId, ByteUtils.extractByteArray(key), ByteUtils.extractByteArray(value), replicationMetadataVersionId,
127+
ByteUtils.extractByteArray(replicationMetadataPayload) });
124128

125129
logProgressPercent();
126130
return true;

clients/venice-push-job/src/test/java/com/linkedin/venice/spark/input/pubsub/SparkPubSubInputPartitionReaderTest.java

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.linkedin.venice.vpj.pubsub.input.PubSubSplitIterator;
2727
import java.io.IOException;
2828
import java.nio.ByteBuffer;
29+
import java.util.Arrays;
2930
import org.apache.spark.sql.catalyst.InternalRow;
3031
import org.testng.annotations.BeforeMethod;
3132
import org.testng.annotations.Test;
@@ -105,13 +106,12 @@ public void testNextWithPutMessageScenarios() throws IOException {
105106
assertTrue(reader.next(), "Reader should successfully process PUT message");
106107
InternalRow row = reader.get();
107108
assertNotNull(row, "Row should not be null after processing PUT message");
108-
// Note: The actual implementation has field order: region, partition, messageType, offset, schemaId, key, value,
109-
// replicationMetadataPayload, replicationMetadataVersionId
110-
// The real implementation incorrectly puts String instead of UTF8String, so we access it as Object
111-
assertEquals(row.get(0, org.apache.spark.sql.types.DataTypes.StringType), TEST_REGION, "Region should match");
109+
// Field order: region, partition, offset, messageType, schemaId, key, value,
110+
// replicationMetadataVersionId, replicationMetadataPayload
111+
assertEquals(row.getUTF8String(0).toString(), TEST_REGION, "Region should match");
112112
assertEquals(row.getInt(1), TEST_PARTITION_NUMBER, "Partition number should match");
113-
assertEquals(row.getInt(2), MessageType.PUT.getValue(), "Message type should be PUT");
114-
assertEquals(row.getLong(3), 100L, "Offset should match");
113+
assertEquals(row.getLong(2), 100L, "Offset should match");
114+
assertEquals(row.getInt(3), MessageType.PUT.getValue(), "Message type should be PUT");
115115
assertEquals(row.getInt(4), 1, "Schema ID should match");
116116

117117
assertFalse(reader.next(), "Reader should return false when no more messages");
@@ -127,7 +127,7 @@ public void testNextWithPutMessageScenarios() throws IOException {
127127
assertTrue(reader.next(), "Reader should successfully process PUT message with metadata");
128128
InternalRow rowWithMetadata = reader.get();
129129
assertNotNull(rowWithMetadata, "Row should not be null");
130-
assertEquals(rowWithMetadata.getInt(8), 5, "Replication metadata version ID should match");
130+
assertEquals(rowWithMetadata.getInt(7), 5, "Replication metadata version ID should match");
131131

132132
assertFalse(reader.next(), "Reader should return false when no more messages");
133133
reader.close();
@@ -157,13 +157,12 @@ public void testNextWithDeleteMessageScenarios() throws IOException {
157157
assertTrue(reader.next(), "Reader should successfully process DELETE message");
158158
InternalRow row = reader.get();
159159
assertNotNull(row, "Row should not be null after processing DELETE message");
160-
// Note: The actual implementation has field order: region, partition, messageType, offset, schemaId, key, value,
161-
// replicationMetadataPayload, replicationMetadataVersionId
162-
// The real implementation incorrectly puts String instead of UTF8String, so we access it as Object
163-
assertEquals(row.get(0, org.apache.spark.sql.types.DataTypes.StringType), TEST_REGION, "Region should match");
160+
// Field order: region, partition, offset, messageType, schemaId, key, value,
161+
// replicationMetadataVersionId, replicationMetadataPayload
162+
assertEquals(row.getUTF8String(0).toString(), TEST_REGION, "Region should match");
164163
assertEquals(row.getInt(1), TEST_PARTITION_NUMBER, "Partition number should match");
165-
assertEquals(row.getInt(2), MessageType.DELETE.getValue(), "Message type should be DELETE");
166-
assertEquals(row.getLong(3), 200L, "Offset should match");
164+
assertEquals(row.getLong(2), 200L, "Offset should match");
165+
assertEquals(row.getInt(3), MessageType.DELETE.getValue(), "Message type should be DELETE");
167166
assertEquals(row.getInt(4), 10, "Schema ID should match");
168167
assertEquals(row.getBinary(6).length, 0, "DELETE message should have empty value");
169168

@@ -180,7 +179,7 @@ public void testNextWithDeleteMessageScenarios() throws IOException {
180179
assertTrue(reader.next(), "Reader should successfully process DELETE message with metadata");
181180
InternalRow rowWithMetadata = reader.get();
182181
assertNotNull(rowWithMetadata, "Row should not be null");
183-
assertEquals(rowWithMetadata.getInt(8), 7, "Replication metadata version ID should match");
182+
assertEquals(rowWithMetadata.getInt(7), 7, "Replication metadata version ID should match");
184183

185184
assertFalse(reader.next(), "Reader should return false when no more messages");
186185
reader.close();
@@ -193,11 +192,11 @@ public void testNextWithDeleteMessageScenarios() throws IOException {
193192

194193
assertTrue(reader.next(), "Reader should process first DELETE message");
195194
InternalRow firstRow = reader.get();
196-
assertEquals(firstRow.getLong(3), 202L, "First DELETE offset should match");
195+
assertEquals(firstRow.getLong(2), 202L, "First DELETE offset should match");
197196

198197
assertTrue(reader.next(), "Reader should process second DELETE message");
199198
InternalRow secondRow = reader.get();
200-
assertEquals(secondRow.getLong(3), 203L, "Second DELETE offset should match");
199+
assertEquals(secondRow.getLong(2), 203L, "Second DELETE offset should match");
201200

202201
assertFalse(reader.next(), "Reader should return false when no more messages");
203202
reader.close();
@@ -261,7 +260,7 @@ public void testMessageTypeHandlingScenarios() throws IOException {
261260

262261
assertTrue(reader.next(), "Reader should handle PUT message");
263262
InternalRow putRow = reader.get();
264-
assertEquals(putRow.getInt(2), MessageType.PUT.getValue(), "Message type should be PUT");
263+
assertEquals(putRow.getInt(3), MessageType.PUT.getValue(), "Message type should be PUT");
265264
reader.close();
266265

267266
// Case 2: Reader handles DELETE message type correctly
@@ -271,7 +270,7 @@ public void testMessageTypeHandlingScenarios() throws IOException {
271270

272271
assertTrue(reader.next(), "Reader should handle DELETE message");
273272
InternalRow deleteRow = reader.get();
274-
assertEquals(deleteRow.getInt(2), MessageType.DELETE.getValue(), "Message type should be DELETE");
273+
assertEquals(deleteRow.getInt(3), MessageType.DELETE.getValue(), "Message type should be DELETE");
275274
reader.close();
276275

277276
// Case 3: Message type validation
@@ -283,10 +282,10 @@ public void testMessageTypeHandlingScenarios() throws IOException {
283282
when(mockSplitIterator.next()).thenReturn(putRecord2).thenReturn(deleteRecord2).thenReturn(null);
284283

285284
assertTrue(reader.next(), "Reader should process first message");
286-
assertEquals(reader.get().getInt(2), MessageType.PUT.getValue(), "First message should be PUT");
285+
assertEquals(reader.get().getInt(3), MessageType.PUT.getValue(), "First message should be PUT");
287286

288287
assertTrue(reader.next(), "Reader should process second message");
289-
assertEquals(reader.get().getInt(2), MessageType.DELETE.getValue(), "Second message should be DELETE");
288+
assertEquals(reader.get().getInt(3), MessageType.DELETE.getValue(), "Second message should be DELETE");
290289

291290
reader.close();
292291
}
@@ -339,7 +338,7 @@ public void testEdgeCaseScenarios() throws IOException {
339338

340339
assertTrue(reader2.next(), "Reader should handle large offset values");
341340
InternalRow largeOffsetRow = reader2.get();
342-
assertEquals(largeOffsetRow.getLong(3), Long.MAX_VALUE, "Large offset should be preserved");
341+
assertEquals(largeOffsetRow.getLong(2), Long.MAX_VALUE, "Large offset should be preserved");
343342
reader2.close();
344343

345344
// Case 3: Null region handling
@@ -349,6 +348,56 @@ public void testEdgeCaseScenarios() throws IOException {
349348
reader3.close();
350349
}
351350

351+
@Test
352+
public void testRawPubsubInternalRowOrdering() throws IOException {
353+
SparkPubSubInputPartitionReader reader = createReaderWithMockIterator();
354+
355+
// Setup: a PUT record with non-trivial values to assert per field
356+
ByteBuffer replicationMetadata = ByteBuffer.wrap("rm-payload".getBytes());
357+
PubSubSplitIterator.PubSubInputRecord record =
358+
createMockPutRecord(123L, "ordering-key", "ordering-value", 42, replicationMetadata, 7);
359+
when(mockSplitIterator.next()).thenReturn(record).thenReturn(null);
360+
361+
assertTrue(reader.next(), "Reader should process message for ordering validation");
362+
InternalRow row = reader.get();
363+
assertNotNull(row, "Row should not be null");
364+
365+
// Field by field ordering matches RAW_PUBSUB_INPUT_TABLE_SCHEMA
366+
// 0: __region__ (StringType / UTF8String)
367+
assertEquals(row.getUTF8String(0).toString(), TEST_REGION, "Region should match");
368+
369+
// 1: __partition__ (IntegerType)
370+
assertEquals(row.getInt(1), TEST_PARTITION_NUMBER, "Partition number should match");
371+
372+
// 2: __offset__ (LongType)
373+
assertEquals(row.getLong(2), 123L, "Offset should match");
374+
375+
// 3: __message_type__ (IntegerType)
376+
assertEquals(row.getInt(3), MessageType.PUT.getValue(), "Message type should be PUT");
377+
378+
// 4: __schema_id__ (IntegerType)
379+
assertEquals(row.getInt(4), 42, "Schema ID should match");
380+
381+
// 5: key (BinaryType)
382+
assertTrue(Arrays.equals(row.getBinary(5), "ordering-key".getBytes()), "Key bytes should match expected value");
383+
384+
// 6: value (BinaryType)
385+
assertTrue(Arrays.equals(row.getBinary(6), "ordering-value".getBytes()), "Value bytes should match expected value");
386+
387+
// 7: __replication_metadata_version_id__ (IntegerType)
388+
assertEquals(row.getInt(7), 7, "Replication metadata version ID should match");
389+
390+
// 8: __replication_metadata_payload__ (BinaryType)
391+
assertTrue(
392+
Arrays.equals(row.getBinary(8), "rm-payload".getBytes()),
393+
"Replication metadata payload bytes should match");
394+
395+
// No more records
396+
assertFalse(reader.next(), "Reader should return false when no more messages");
397+
398+
reader.close();
399+
}
400+
352401
/**
353402
* Helper method to create a SparkPubSubInputPartitionReader with a mocked PubSubSplitIterator.
354403
* Uses the test-only constructor to inject the mock iterator.

0 commit comments

Comments (0)