|
45 | 45 | #include "arrow/table.h"
|
46 | 46 | #include "arrow/table_builder.h"
|
47 | 47 | #include "arrow/testing/gtest_util.h"
|
| 48 | +#include "arrow/util/align_util.h" |
48 | 49 | #include "arrow/util/checked_cast.h"
|
49 | 50 | #include "arrow/util/string.h"
|
50 | 51 | #include "arrow/util/value_parsing.h"
|
@@ -281,6 +282,137 @@ class MiddlewareScenario : public Scenario {
|
281 | 282 | std::shared_ptr<TestClientMiddlewareFactory> client_middleware_;
|
282 | 283 | };
|
283 | 284 |
|
| 285 | +/// \brief The server used for testing FlightClient data alignment. |
| 286 | +/// |
| 287 | +/// The server always returns the same data of various byte widths. |
| 288 | +/// The client should return data that is aligned according to the data type |
| 289 | +/// if FlightCallOptions.read_options.ensure_memory_alignment is true. |
| 290 | +/// |
| 291 | +/// This scenario is passed only when the client returns aligned data. |
| 292 | +class AlignmentServer : public FlightServerBase { |
| 293 | + Status GetFlightInfo(const ServerCallContext& context, |
| 294 | + const FlightDescriptor& descriptor, |
| 295 | + std::unique_ptr<FlightInfo>* result) override { |
| 296 | + auto schema = BuildSchema(); |
| 297 | + std::vector<FlightEndpoint> endpoints{ |
| 298 | + FlightEndpoint{{"align-data"}, {}, std::nullopt, ""}}; |
| 299 | + ARROW_ASSIGN_OR_RAISE( |
| 300 | + auto info, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, false)); |
| 301 | + *result = std::make_unique<FlightInfo>(info); |
| 302 | + return Status::OK(); |
| 303 | + } |
| 304 | + |
| 305 | + Status DoGet(const ServerCallContext& context, const Ticket& request, |
| 306 | + std::unique_ptr<FlightDataStream>* stream) override { |
| 307 | + if (request.ticket != "align-data") { |
| 308 | + return Status::KeyError("Could not find flight: ", request.ticket); |
| 309 | + } |
| 310 | + auto record_batch = RecordBatchFromJSON(BuildSchema(), R"([ |
| 311 | + [1, 1, false], |
| 312 | + [2, 2, true], |
| 313 | + [3, 3, false] |
| 314 | + ])"); |
| 315 | + std::vector<std::shared_ptr<RecordBatch>> record_batches{record_batch}; |
| 316 | + ARROW_ASSIGN_OR_RAISE(auto record_batch_reader, |
| 317 | + RecordBatchReader::Make(record_batches)); |
| 318 | + *stream = std::make_unique<RecordBatchStream>(record_batch_reader); |
| 319 | + return Status::OK(); |
| 320 | + } |
| 321 | + |
| 322 | + private: |
| 323 | + std::shared_ptr<Schema> BuildSchema() { |
| 324 | + return arrow::schema({ |
| 325 | + arrow::field("int32", arrow::int32(), false), |
| 326 | + arrow::field("int64", arrow::int64(), false), |
| 327 | + arrow::field("bool", arrow::boolean(), false), |
| 328 | + }); |
| 329 | + } |
| 330 | +}; |
| 331 | + |
| 332 | +/// \brief The alignment scenario. |
| 333 | +/// |
| 334 | +/// This tests that the client provides aligned data if requested. |
| 335 | +class AlignmentScenario : public Scenario { |
| 336 | + Status MakeServer(std::unique_ptr<FlightServerBase>* server, |
| 337 | + FlightServerOptions* options) override { |
| 338 | + server->reset(new AlignmentServer()); |
| 339 | + return Status::OK(); |
| 340 | + } |
| 341 | + |
| 342 | + Status MakeClient(FlightClientOptions* options) override { return Status::OK(); } |
| 343 | + |
| 344 | + arrow::Result<std::shared_ptr<Table>> GetTable(FlightClient* client, |
| 345 | + const FlightCallOptions& call_options) { |
| 346 | + ARROW_ASSIGN_OR_RAISE(auto info, |
| 347 | + client->GetFlightInfo(FlightDescriptor::Command("alignment"))); |
| 348 | + std::vector<std::shared_ptr<arrow::Table>> tables; |
| 349 | + for (const auto& endpoint : info->endpoints()) { |
| 350 | + if (!endpoint.locations.empty()) { |
| 351 | + std::stringstream ss; |
| 352 | + ss << "["; |
| 353 | + for (const auto& location : endpoint.locations) { |
| 354 | + if (ss.str().size() != 1) { |
| 355 | + ss << ", "; |
| 356 | + } |
| 357 | + ss << location.ToString(); |
| 358 | + } |
| 359 | + ss << "]"; |
| 360 | + return Status::Invalid( |
| 361 | + "Expected to receive empty locations to use the original service: ", |
| 362 | + ss.str()); |
| 363 | + } |
| 364 | + ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(call_options, endpoint.ticket)); |
| 365 | + ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable()); |
| 366 | + tables.push_back(table); |
| 367 | + } |
| 368 | + return ConcatenateTables(tables); |
| 369 | + } |
| 370 | + |
| 371 | + Status RunClient(std::unique_ptr<FlightClient> client) override { |
| 372 | + for (ipc::Alignment ensure_alignment : |
| 373 | + {ipc::Alignment::kAnyAlignment, ipc::Alignment::kDataTypeSpecificAlignment, |
| 374 | + ipc::Alignment::k64ByteAlignment}) { |
| 375 | + auto call_options = FlightCallOptions(); |
| 376 | + call_options.read_options.ensure_alignment = ensure_alignment; |
| 377 | + ARROW_ASSIGN_OR_RAISE(auto table, GetTable(client.get(), call_options)); |
| 378 | + |
| 379 | + // Check read data |
| 380 | + auto expected_row_count = 3; |
| 381 | + if (table->num_rows() != expected_row_count) { |
| 382 | + return Status::Invalid("Read table size isn't expected\n", "Expected rows:\n", |
| 383 | + expected_row_count, "Actual rows:\n", table->num_rows()); |
| 384 | + } |
| 385 | + auto expected_column_count = 3; |
| 386 | + if (table->num_columns() != expected_column_count) { |
| 387 | + return Status::Invalid("Read table size isn't expected\n", "Expected columns:\n", |
| 388 | + expected_column_count, "Actual columns:\n", |
| 389 | + table->num_columns()); |
| 390 | + } |
| 391 | + // Check data alignment |
| 392 | + std::vector<bool> needs_alignment; |
| 393 | + if (ensure_alignment == ipc::Alignment::kAnyAlignment) { |
| 394 | + // this is not a requirement but merely an observation: |
| 395 | + // with ensure_alignment=false, flight client returns mis-aligned data |
| 396 | + // if this is not the case any more, feel free to remove this assertion |
| 397 | + if (util::CheckAlignment(*table, arrow::util::kValueAlignment, |
| 398 | + &needs_alignment)) { |
| 399 | + return Status::Invalid( |
| 400 | + "Read table has aligned data, which is good, but unprecedented"); |
| 401 | + } |
| 402 | + } else { |
| 403 | + // with ensure_alignment != kValueAlignment, we require data to be aligned |
| 404 | + // the value of the Alignment enum provides us with the byte alignment value |
| 405 | + if (!util::CheckAlignment(*table, static_cast<int64_t>(ensure_alignment), |
| 406 | + &needs_alignment)) { |
| 407 | + return Status::Invalid("Read table has unaligned data"); |
| 408 | + } |
| 409 | + } |
| 410 | + } |
| 411 | + |
| 412 | + return Status::OK(); |
| 413 | + } |
| 414 | +}; |
| 415 | + |
284 | 416 | /// \brief The server used for testing FlightInfo.ordered.
|
285 | 417 | ///
|
286 | 418 | /// If the given command is "ordered", the server sets
|
@@ -316,25 +448,16 @@ class OrderedServer : public FlightServerBase {
|
316 | 448 |
|
317 | 449 | Status DoGet(const ServerCallContext& context, const Ticket& request,
|
318 | 450 | std::unique_ptr<FlightDataStream>* stream) override {
|
319 |
| - ARROW_ASSIGN_OR_RAISE(auto builder, RecordBatchBuilder::Make( |
320 |
| - BuildSchema(), arrow::default_memory_pool())); |
321 |
| - auto number_builder = builder->GetFieldAs<Int32Builder>(0); |
| 451 | + std::shared_ptr<RecordBatch> record_batch; |
322 | 452 | if (request.ticket == "1") {
|
323 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(1)); |
324 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(2)); |
325 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(3)); |
| 453 | + record_batch = RecordBatchFromJSON(BuildSchema(), "[[1], [2], [3]]"); |
326 | 454 | } else if (request.ticket == "2") {
|
327 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(10)); |
328 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(20)); |
329 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(30)); |
| 455 | + record_batch = RecordBatchFromJSON(BuildSchema(), "[[10], [20], [30]]"); |
330 | 456 | } else if (request.ticket == "3") {
|
331 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(100)); |
332 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(200)); |
333 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(300)); |
| 457 | + record_batch = RecordBatchFromJSON(BuildSchema(), "[[100], [200], [300]]"); |
334 | 458 | } else {
|
335 | 459 | return Status::KeyError("Could not find flight: ", request.ticket);
|
336 | 460 | }
|
337 |
| - ARROW_ASSIGN_OR_RAISE(auto record_batch, builder->Flush()); |
338 | 461 | std::vector<std::shared_ptr<RecordBatch>> record_batches{record_batch};
|
339 | 462 | ARROW_ASSIGN_OR_RAISE(auto record_batch_reader,
|
340 | 463 | RecordBatchReader::Make(record_batches));
|
@@ -390,19 +513,9 @@ class OrderedScenario : public Scenario {
|
390 | 513 |
|
391 | 514 | // Build expected table
|
392 | 515 | auto schema = arrow::schema({arrow::field("number", arrow::int32(), false)});
|
393 |
| - ARROW_ASSIGN_OR_RAISE(auto builder, |
394 |
| - RecordBatchBuilder::Make(schema, arrow::default_memory_pool())); |
395 |
| - auto number_builder = builder->GetFieldAs<Int32Builder>(0); |
396 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(1)); |
397 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(2)); |
398 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(3)); |
399 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(10)); |
400 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(20)); |
401 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(30)); |
402 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(100)); |
403 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(200)); |
404 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(300)); |
405 |
| - ARROW_ASSIGN_OR_RAISE(auto expected_record_batch, builder->Flush()); |
| 516 | + auto expected_record_batch = RecordBatchFromJSON(schema, R"([ |
| 517 | + [1], [2], [3], [10], [20], [30], [100], [200], [300] |
| 518 | + ])"); |
406 | 519 | std::vector<std::shared_ptr<RecordBatch>> expected_record_batches{
|
407 | 520 | expected_record_batch};
|
408 | 521 | ARROW_ASSIGN_OR_RAISE(auto expected_table,
|
@@ -490,11 +603,8 @@ class ExpirationTimeServer : public FlightServerBase {
|
490 | 603 | }
|
491 | 604 | }
|
492 | 605 | status.num_gets++;
|
493 |
| - ARROW_ASSIGN_OR_RAISE(auto builder, RecordBatchBuilder::Make( |
494 |
| - BuildSchema(), arrow::default_memory_pool())); |
495 |
| - auto number_builder = builder->GetFieldAs<UInt32Builder>(0); |
496 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(index)); |
497 |
| - ARROW_ASSIGN_OR_RAISE(auto record_batch, builder->Flush()); |
| 606 | + auto record_batch = |
| 607 | + RecordBatchFromJSON(BuildSchema(), "[[" + std::to_string(index) + "]]"); |
498 | 608 | std::vector<std::shared_ptr<RecordBatch>> record_batches{record_batch};
|
499 | 609 | ARROW_ASSIGN_OR_RAISE(auto record_batch_reader,
|
500 | 610 | RecordBatchReader::Make(record_batches));
|
@@ -621,13 +731,7 @@ class ExpirationTimeDoGetScenario : public Scenario {
|
621 | 731 |
|
622 | 732 | // Build expected table
|
623 | 733 | auto schema = arrow::schema({arrow::field("number", arrow::uint32(), false)});
|
624 |
| - ARROW_ASSIGN_OR_RAISE(auto builder, |
625 |
| - RecordBatchBuilder::Make(schema, arrow::default_memory_pool())); |
626 |
| - auto number_builder = builder->GetFieldAs<UInt32Builder>(0); |
627 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(0)); |
628 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(1)); |
629 |
| - ARROW_RETURN_NOT_OK(number_builder->Append(2)); |
630 |
| - ARROW_ASSIGN_OR_RAISE(auto expected_record_batch, builder->Flush()); |
| 734 | + auto expected_record_batch = RecordBatchFromJSON(schema, "[[0], [1], [2]]"); |
631 | 735 | std::vector<std::shared_ptr<RecordBatch>> expected_record_batches{
|
632 | 736 | expected_record_batch};
|
633 | 737 | ARROW_ASSIGN_OR_RAISE(auto expected_table,
|
@@ -2382,6 +2486,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>*
|
2382 | 2486 | } else if (scenario_name == "middleware") {
|
2383 | 2487 | *out = std::make_shared<MiddlewareScenario>();
|
2384 | 2488 | return Status::OK();
|
| 2489 | + } else if (scenario_name == "alignment") { |
| 2490 | + *out = std::make_shared<AlignmentScenario>(); |
| 2491 | + return Status::OK(); |
2385 | 2492 | } else if (scenario_name == "ordered") {
|
2386 | 2493 | *out = std::make_shared<OrderedScenario>();
|
2387 | 2494 | return Status::OK();
|
|
0 commit comments