Skip to content

Commit f5b3fc7

Browse files
authored
GH-48076: [C++][Flight] fix GeneratorStream for Tables (#48082)
### Rationale for this change After the changes in #47115, GeneratorStreams backed by anything else than RecordBatches failed. This includes Tables and RecordBatchReaders. This was caused by a too strict assumption that the RecordBatchStream#GetSchemaPayload would always get called, which is not the case when the GeneratorStream is backed by a Table or a RecordBatchReader. ### What changes are included in this PR? Removal of the problematic assertion and initialization of the writer object when it is needed first. Also, to accommodate for this case, drop the incoming message when initializing the writer in Next, as the message there is of the SCHEMA type and we want RECORD_BATCH or DICTIONARY_BATCH one. ### Are these changes tested? Yes, via CI. Tests for the GeneratorStreams were extended so that they test GeneratorStreams backed by Tables and RecordBatchReaders, not just RecordBatches. ### Are there any user-facing changes? No, just a fix for a regression restoring the functionality from version 21.0.0 and earlier. * GitHub Issue: #48076 Authored-by: Dan Homola <[email protected]> Signed-off-by: David Li <[email protected]>
1 parent 94b9bb6 commit f5b3fc7

File tree

2 files changed

+149
-9
lines changed

2 files changed

+149
-9
lines changed

cpp/src/arrow/flight/server.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,7 @@ class RecordBatchStream::RecordBatchStreamImpl {
285285

286286
Status GetSchemaPayload(FlightPayload* payload) {
287287
if (!writer_) {
288-
// Create the IPC writer on first call
289-
auto payload_writer =
290-
std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
291-
ARROW_ASSIGN_OR_RAISE(
292-
writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
293-
reader_->schema(), options_));
288+
RETURN_NOT_OK(InitializeWriter());
294289
}
295290

296291
// Return the expected schema payload.
@@ -317,8 +312,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
317312
return Status::OK();
318313
}
319314
if (!writer_) {
320-
return Status::UnknownError(
321-
"Writer should be initialized before reading Next batches");
315+
RETURN_NOT_OK(InitializeWriter());
316+
// If the writer has not been initialized yet, the first batch in the payload
317+
// queue is going to be a SCHEMA one. In this context, that is
318+
// unexpected, so drop it from the queue so that there is a RECORD_BATCH
319+
// message on the top (same as would be if the writer had been initialized
320+
// in GetSchemaPayload).
321+
if (payload_deque_.front().ipc_message.type == ipc::MessageType::SCHEMA) {
322+
payload_deque_.pop_front();
323+
}
322324
}
323325
// One WriteRecordBatch call might generate multiple payloads, so we
324326
// need to collect them in a deque.
@@ -370,6 +372,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
370372
ipc::IpcWriteOptions options_;
371373
std::unique_ptr<ipc::RecordBatchWriter> writer_;
372374
std::deque<FlightPayload> payload_deque_;
375+
376+
Status InitializeWriter() {
377+
auto payload_writer =
378+
std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
379+
ARROW_ASSIGN_OR_RAISE(
380+
writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
381+
reader_->schema(), options_));
382+
return Status::OK();
383+
}
373384
};
374385

375386
FlightMetadataWriter::~FlightMetadataWriter() = default;

python/pyarrow/tests/test_flight.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,40 @@ def do_action(self, context, action):
246246
raise NotImplementedError
247247

248248

249+
class EchoTableStreamFlightServer(EchoFlightServer):
250+
"""An echo server that streams the whole table."""
251+
252+
def do_get(self, context, ticket):
253+
return flight.GeneratorStream(
254+
self.last_message.schema,
255+
[self.last_message])
256+
257+
def list_actions(self, context):
258+
return []
259+
260+
def do_action(self, context, action):
261+
if action.type == "who-am-i":
262+
return [context.peer_identity(), context.peer().encode("utf-8")]
263+
raise NotImplementedError
264+
265+
266+
class EchoRecordBatchReaderStreamFlightServer(EchoFlightServer):
267+
"""An echo server that streams the whole table as a RecordBatchReader."""
268+
269+
def do_get(self, context, ticket):
270+
return flight.GeneratorStream(
271+
self.last_message.schema,
272+
[self.last_message.to_reader()])
273+
274+
def list_actions(self, context):
275+
return []
276+
277+
def do_action(self, context, action):
278+
if action.type == "who-am-i":
279+
return [context.peer_identity(), context.peer().encode("utf-8")]
280+
raise NotImplementedError
281+
282+
249283
class GetInfoFlightServer(FlightServerBase):
250284
"""A Flight server that tests GetFlightInfo."""
251285

@@ -1362,7 +1396,7 @@ def test_flight_large_message():
13621396
assert result.equals(data)
13631397

13641398

1365-
def test_flight_generator_stream():
1399+
def test_flight_generator_stream_of_batches():
13661400
"""Try downloading a flight of RecordBatches in a GeneratorStream."""
13671401
data = pa.Table.from_arrays([
13681402
pa.array(range(0, 10 * 1024))
@@ -1378,6 +1412,101 @@ def test_flight_generator_stream():
13781412
assert result.equals(data)
13791413

13801414

1415+
def test_flight_generator_stream_of_batches_with_dict():
1416+
"""
1417+
Try downloading a flight of RecordBatches with dictionaries
1418+
in a GeneratorStream.
1419+
"""
1420+
data = pa.Table.from_arrays([
1421+
pa.array(["foo", "bar", "baz", "foo", "foo"],
1422+
pa.dictionary(pa.int64(), pa.utf8())),
1423+
pa.array([123, 234, 345, 456, 567])
1424+
], names=['a', 'b'])
1425+
1426+
with EchoRecordBatchReaderStreamFlightServer() as server, \
1427+
FlightClient(('localhost', server.port)) as client:
1428+
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
1429+
data.schema)
1430+
writer.write_table(data)
1431+
writer.close()
1432+
result = client.do_get(flight.Ticket(b'')).read_all()
1433+
assert result.equals(data)
1434+
1435+
1436+
def test_flight_generator_stream_of_table():
1437+
"""Try downloading a flight of Table in a GeneratorStream."""
1438+
data = pa.Table.from_arrays([
1439+
pa.array(range(0, 10 * 1024))
1440+
], names=['a'])
1441+
1442+
with EchoTableStreamFlightServer() as server, \
1443+
FlightClient(('localhost', server.port)) as client:
1444+
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
1445+
data.schema)
1446+
writer.write_table(data)
1447+
writer.close()
1448+
result = client.do_get(flight.Ticket(b'')).read_all()
1449+
assert result.equals(data)
1450+
1451+
1452+
def test_flight_generator_stream_of_table_with_dict():
1453+
"""
1454+
Try downloading a flight of Table with dictionaries
1455+
in a GeneratorStream.
1456+
"""
1457+
data = pa.Table.from_arrays([
1458+
pa.array(["foo", "bar", "baz", "foo", "foo"],
1459+
pa.dictionary(pa.int64(), pa.utf8())),
1460+
pa.array([123, 234, 345, 456, 567])
1461+
], names=['a', 'b'])
1462+
1463+
with EchoRecordBatchReaderStreamFlightServer() as server, \
1464+
FlightClient(('localhost', server.port)) as client:
1465+
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
1466+
data.schema)
1467+
writer.write_table(data)
1468+
writer.close()
1469+
result = client.do_get(flight.Ticket(b'')).read_all()
1470+
assert result.equals(data)
1471+
1472+
1473+
def test_flight_generator_stream_of_record_batch_reader():
1474+
"""Try downloading a flight of RecordBatchReader in a GeneratorStream."""
1475+
data = pa.Table.from_arrays([
1476+
pa.array(range(0, 10 * 1024))
1477+
], names=['a'])
1478+
1479+
with EchoRecordBatchReaderStreamFlightServer() as server, \
1480+
FlightClient(('localhost', server.port)) as client:
1481+
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
1482+
data.schema)
1483+
writer.write_table(data)
1484+
writer.close()
1485+
result = client.do_get(flight.Ticket(b'')).read_all()
1486+
assert result.equals(data)
1487+
1488+
1489+
def test_flight_generator_stream_of_record_batch_reader_with_dict():
1490+
"""
1491+
Try downloading a flight of RecordBatchReader with dictionaries
1492+
in a GeneratorStream.
1493+
"""
1494+
data = pa.Table.from_arrays([
1495+
pa.array(["foo", "bar", "baz", "foo", "foo"],
1496+
pa.dictionary(pa.int64(), pa.utf8())),
1497+
pa.array([123, 234, 345, 456, 567])
1498+
], names=['a', 'b'])
1499+
1500+
with EchoRecordBatchReaderStreamFlightServer() as server, \
1501+
FlightClient(('localhost', server.port)) as client:
1502+
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
1503+
data.schema)
1504+
writer.write_table(data)
1505+
writer.close()
1506+
result = client.do_get(flight.Ticket(b'')).read_all()
1507+
assert result.equals(data)
1508+
1509+
13811510
def test_flight_invalid_generator_stream():
13821511
"""Try streaming data with mismatched schemas."""
13831512
with InvalidStreamFlightServer() as server, \

0 commit comments

Comments
 (0)