Skip to content

Commit 8b75373

Browse files
apacheGH-44923: [MATLAB] Add IPC RecordBatchStreamReader MATLAB class (apache#45068)
### Rationale for this change To enable support for the IPC Streaming format in the MATLAB interface, we should add a `RecordBatchStreamReader` class. This is a followup to apache#44922 ### What changes are included in this PR? 1. Added a new `arrow.io.ipc.RecordBatchStreamReader` MATLAB class. ### Are these changes tested? Yes. 1. Added new MATLAB test suite `arrow/matlab/test/arrow/io/ipc/tRecordBatchStreamReader.m`. ### Are there any user-facing changes? Yes. 1. Users can now create `arrow.io.ipc.RecordBatchStreamReader` objects to read `RecordBatch` objects incrementally from an Arrow IPC Stream file. ### Notes 1. Thank you @sgilmore10 for your help with this pull request! * GitHub Issue: apache#44923 Lead-authored-by: Kevin Gurney <[email protected]> Co-authored-by: Kevin Gurney <[email protected]> Co-authored-by: Sarah Gilmore <[email protected]> Signed-off-by: Kevin Gurney <[email protected]>
1 parent 035e331 commit 8b75373

File tree

7 files changed

+622
-0
lines changed

7 files changed

+622
-0
lines changed

matlab/src/cpp/arrow/matlab/error/error.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,5 +249,7 @@ static const char* IPC_RECORD_BATCH_READER_OPEN_FAILED =
249249
"arrow:io:ipc:FailedToOpenRecordBatchReader";
250250
static const char* IPC_RECORD_BATCH_READ_INVALID_INDEX = "arrow:io:ipc:InvalidIndex";
251251
static const char* IPC_RECORD_BATCH_READ_FAILED = "arrow:io:ipc:ReadFailed";
252+
// Error ID reported when reading an Arrow Table from an IPC source fails.
static const char* IPC_TABLE_READ_FAILED = "arrow:io:ipc:TableReadFailed";
253+
// Error ID reported when a record batch is requested but the IPC Stream has
// no more record batches to read.
static const char* IPC_END_OF_STREAM = "arrow:io:ipc:EndOfStream";
252254

253255
} // namespace arrow::matlab::error
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "arrow/matlab/io/ipc/proxy/record_batch_stream_reader.h"
19+
#include "arrow/io/file.h"
20+
#include "arrow/matlab/error/error.h"
21+
#include "arrow/matlab/tabular/proxy/record_batch.h"
22+
#include "arrow/matlab/tabular/proxy/schema.h"
23+
#include "arrow/matlab/tabular/proxy/table.h"
24+
#include "arrow/util/utf8.h"
25+
26+
#include "libmexclass/proxy/ProxyManager.h"
27+
28+
namespace arrow::matlab::io::ipc::proxy {
29+
30+
// Constructs the proxy around an already-opened Arrow IPC stream reader and
// registers the methods that can be invoked on it by name.
//
// The parameter is taken by (non-const) value so that std::move below really
// transfers ownership into the member. With a const parameter, std::move
// yields a const rvalue and the shared_ptr would be copied instead.
RecordBatchStreamReader::RecordBatchStreamReader(
    std::shared_ptr<arrow::ipc::RecordBatchStreamReader> reader)
    : reader{std::move(reader)} {
  REGISTER_METHOD(RecordBatchStreamReader, getSchema);
  REGISTER_METHOD(RecordBatchStreamReader, readRecordBatch);
  REGISTER_METHOD(RecordBatchStreamReader, hasNextRecordBatch);
  REGISTER_METHOD(RecordBatchStreamReader, readTable);
}
38+
39+
// Factory used to create a RecordBatchStreamReader proxy.
//
// The first constructor argument is a struct whose "Filename" field holds the
// path of the Arrow IPC Stream file to read. The file is opened for reading
// and an IPC stream reader is opened on top of it. Any failure (Unicode
// conversion, opening the file, opening the reader) is reported through the
// MATLAB_ASSIGN_OR_ERROR macro with the corresponding error ID.
libmexclass::proxy::MakeResult RecordBatchStreamReader::make(
    const libmexclass::proxy::FunctionArguments& constructor_arguments) {
  namespace mda = ::matlab::data;
  using RecordBatchStreamReaderProxy =
      arrow::matlab::io::ipc::proxy::RecordBatchStreamReader;

  // Pull the filename out of the options struct.
  const mda::StructArray options = constructor_arguments[0];
  const mda::StringArray filename_array = options[0]["Filename"];

  // The filename arrives as UTF-16 and must be converted to UTF-8.
  const auto path_utf16 = std::u16string(filename_array[0]);
  MATLAB_ASSIGN_OR_ERROR(const auto path_utf8,
                         arrow::util::UTF16StringToUTF8(path_utf16),
                         error::UNICODE_CONVERSION_ERROR_ID);

  // Open the file, then open an IPC stream reader on it.
  MATLAB_ASSIGN_OR_ERROR(auto file, arrow::io::ReadableFile::Open(path_utf8),
                         error::FAILED_TO_OPEN_FILE_FOR_READ);
  MATLAB_ASSIGN_OR_ERROR(auto stream_reader,
                         arrow::ipc::RecordBatchStreamReader::Open(file),
                         error::IPC_RECORD_BATCH_READER_OPEN_FAILED);

  return std::make_shared<RecordBatchStreamReaderProxy>(std::move(stream_reader));
}
62+
63+
// Wraps the stream's schema in a Schema proxy, registers that proxy with the
// ProxyManager, and returns the resulting proxy ID as the first output.
void RecordBatchStreamReader::getSchema(libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;
  using SchemaProxy = arrow::matlab::tabular::proxy::Schema;

  auto proxy = std::make_shared<SchemaProxy>(reader->schema());
  const auto proxy_id = libmexclass::proxy::ProxyManager::manageProxy(proxy);

  mda::ArrayFactory factory;
  context.outputs[0] = factory.createScalar(proxy_id);
}
77+
78+
// Reads all record batches that have not been returned yet and combines them
// into a single Table. The Table is wrapped in a Table proxy, registered with
// the ProxyManager, and the proxy ID is returned as the first output.
//
// hasNextRecordBatch may already have pulled one record batch out of the
// stream and cached it in nextRecordBatch. That batch is no longer visible to
// the underlying reader, so it has to be placed in front of the batches that
// are still in the stream; otherwise it would be missing from the Table and
// would be returned by a later readRecordBatch call instead.
//
// On failure, context.error is set with error::IPC_TABLE_READ_FAILED.
void RecordBatchStreamReader::readTable(libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;
  using TableProxy = arrow::matlab::tabular::proxy::Table;

  // Drain the batches that are still in the stream.
  MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto record_batches, reader->ToRecordBatches(),
                                      context, error::IPC_TABLE_READ_FAILED);

  // Consume the pre-cached batch (if any), preserving the stream order.
  if (nextRecordBatch) {
    record_batches.insert(record_batches.begin(), nextRecordBatch);
    nextRecordBatch = nullptr;
  }

  MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(
      auto table, arrow::Table::FromRecordBatches(reader->schema(), record_batches),
      context, error::IPC_TABLE_READ_FAILED);

  auto table_proxy = std::make_shared<TableProxy>(table);
  const auto table_proxy_id = libmexclass::proxy::ProxyManager::manageProxy(table_proxy);

  mda::ArrayFactory factory;
  const auto table_proxy_id_mda = factory.createScalar(table_proxy_id);
  context.outputs[0] = table_proxy_id_mda;
}
91+
92+
// Returns the next record batch of the IPC Stream as a RecordBatch proxy ID.
//
// A batch cached by an earlier hasNextRecordBatch call is returned first;
// otherwise a new batch is read from the stream. A failed read sets
// context.error with error::IPC_RECORD_BATCH_READ_FAILED, and reaching the
// end of the stream sets it with error::IPC_END_OF_STREAM.
void RecordBatchStreamReader::readRecordBatch(
    libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;
  using RecordBatchProxy = arrow::matlab::tabular::proxy::RecordBatch;
  using namespace libmexclass::error;

  // Nothing cached: pull the next batch from the stream.
  if (nextRecordBatch == nullptr) {
    MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(nextRecordBatch, reader->Next(), context,
                                        error::IPC_RECORD_BATCH_READ_FAILED);
  }

  // A successful read that yields no batch means the stream is exhausted.
  if (nextRecordBatch == nullptr) {
    context.error =
        Error{error::IPC_END_OF_STREAM,
              "Reached end of Arrow IPC Stream. No more record batches to read."};
    return;
  }

  auto batch_proxy = std::make_shared<RecordBatchProxy>(nextRecordBatch);
  const auto batch_proxy_id = libmexclass::proxy::ProxyManager::manageProxy(batch_proxy);

  // The batch has been handed out; clear the cache so the next
  // hasNextRecordBatch call goes back to the stream.
  nextRecordBatch = nullptr;

  mda::ArrayFactory factory;
  context.outputs[0] = factory.createScalar(batch_proxy_id);
}
123+
124+
// Reports (as a logical scalar in the first output) whether another record
// batch is available.
//
// If no batch is cached, one is read ahead from the stream and stored in
// nextRecordBatch so that the following readRecordBatch call can return it.
// Both a failed read and an empty result (end of stream) are reported as
// "no next record batch"; this method never sets context.error.
void RecordBatchStreamReader::hasNextRecordBatch(
    libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;

  if (!nextRecordBatch) {
    // Read ahead by one batch. On failure the cache simply stays empty.
    auto read_result = reader->Next();
    if (read_result.ok()) {
      // May still be null, which signals the end of the IPC stream.
      nextRecordBatch = *read_result;
    }
  }

  const bool batch_available = (nextRecordBatch != nullptr);

  mda::ArrayFactory factory;
  context.outputs[0] = factory.createScalar(batch_available);
}
153+
154+
} // namespace arrow::matlab::io::ipc::proxy
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#pragma once
19+
20+
#include "arrow/ipc/reader.h"
21+
#include "libmexclass/proxy/Proxy.h"
22+
23+
namespace arrow::matlab::io::ipc::proxy {
24+
25+
class RecordBatchStreamReader : public libmexclass::proxy::Proxy {
26+
public:
27+
RecordBatchStreamReader(std::shared_ptr<arrow::ipc::RecordBatchStreamReader> reader);
28+
29+
~RecordBatchStreamReader() = default;
30+
31+
static libmexclass::proxy::MakeResult make(
32+
const libmexclass::proxy::FunctionArguments& constructor_arguments);
33+
34+
protected:
35+
std::shared_ptr<arrow::ipc::RecordBatchStreamReader> reader;
36+
std::shared_ptr<arrow::RecordBatch> nextRecordBatch;
37+
38+
void getSchema(libmexclass::proxy::method::Context& context);
39+
void readRecordBatch(libmexclass::proxy::method::Context& context);
40+
void hasNextRecordBatch(libmexclass::proxy::method::Context& context);
41+
void readTable(libmexclass::proxy::method::Context& context);
42+
};
43+
44+
} // namespace arrow::matlab::io::ipc::proxy

matlab/src/cpp/arrow/matlab/proxy/factory.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "arrow/matlab/io/feather/proxy/writer.h"
3737
#include "arrow/matlab/io/ipc/proxy/record_batch_file_reader.h"
3838
#include "arrow/matlab/io/ipc/proxy/record_batch_file_writer.h"
39+
#include "arrow/matlab/io/ipc/proxy/record_batch_stream_reader.h"
3940
#include "arrow/matlab/io/ipc/proxy/record_batch_stream_writer.h"
4041
#include "arrow/matlab/tabular/proxy/record_batch.h"
4142
#include "arrow/matlab/tabular/proxy/schema.h"
@@ -113,6 +114,7 @@ libmexclass::proxy::MakeResult Factory::make_proxy(
113114
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchFileReader , arrow::matlab::io::ipc::proxy::RecordBatchFileReader);
114115
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchFileWriter , arrow::matlab::io::ipc::proxy::RecordBatchFileWriter);
115116
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchStreamWriter , arrow::matlab::io::ipc::proxy::RecordBatchStreamWriter);
117+
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchStreamReader , arrow::matlab::io::ipc::proxy::RecordBatchStreamReader);
116118

117119
// clang-format on
118120

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
%RECORDBATCHSTREAMREADER Class for reading Arrow record batches from the
2+
% Arrow IPC Stream format.
3+
4+
% Licensed to the Apache Software Foundation (ASF) under one or more
5+
% contributor license agreements. See the NOTICE file distributed with
6+
% this work for additional information regarding copyright ownership.
7+
% The ASF licenses this file to you under the Apache License, Version
8+
% 2.0 (the "License"); you may not use this file except in compliance
9+
% with the License. You may obtain a copy of the License at
10+
%
11+
% http://www.apache.org/licenses/LICENSE-2.0
12+
%
13+
% Unless required by applicable law or agreed to in writing, software
14+
% distributed under the License is distributed on an "AS IS" BASIS,
15+
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
16+
% implied. See the License for the specific language governing
17+
% permissions and limitations under the License.
18+
19+
classdef RecordBatchStreamReader < matlab.mixin.Scalar

    properties (Hidden, GetAccess=public, SetAccess=private)
        % Proxy object through which the C++ reader is called.
        Proxy
    end

    properties (Dependent, GetAccess=public, SetAccess=private)
        % arrow.tabular.Schema of the record batches in the stream.
        Schema
    end

    methods
        function obj = RecordBatchStreamReader(filename)
            % Create a reader for the Arrow IPC Stream file FILENAME.
            arguments
                filename(1, 1) string {mustBeNonzeroLengthText}
            end
            proxyArgs = struct(Filename=filename);
            obj.Proxy = arrow.internal.proxy.create( ...
                "arrow.io.ipc.proxy.RecordBatchStreamReader", proxyArgs);
        end

        function schema = get.Schema(obj)
            % Wrap the schema proxy returned by the C++ layer.
            id = obj.Proxy.getSchema();
            schemaProxy = libmexclass.proxy.Proxy(ID=id, ...
                Name="arrow.tabular.proxy.Schema");
            schema = arrow.tabular.Schema(schemaProxy);
        end

        function tf = hasnext(obj)
            % True if another record batch can be read.
            tf = obj.Proxy.hasNextRecordBatch();
        end

        function tf = done(obj)
            % True if there are no more record batches to read
            % (logical negation of hasnext).
            tf = ~obj.Proxy.hasNextRecordBatch();
        end

        function arrowRecordBatch = read(obj)
            % Short alias for readRecordBatch.
            %
            % NOTE: The body is deliberately duplicated rather than
            % forwarding to obj.readRecordBatch. This method may be called
            % in a tight loop, so calling the C++ code directly avoids one
            % extra method dispatch. Code duplication is traded for
            % performance on purpose.
            id = obj.Proxy.readRecordBatch();
            batchProxy = libmexclass.proxy.Proxy(ID=id, ...
                Name="arrow.tabular.proxy.RecordBatch");
            arrowRecordBatch = arrow.tabular.RecordBatch(batchProxy);
        end

        function arrowRecordBatch = readRecordBatch(obj)
            % Read the next record batch from the stream.
            id = obj.Proxy.readRecordBatch();
            batchProxy = libmexclass.proxy.Proxy(ID=id, ...
                Name="arrow.tabular.proxy.RecordBatch");
            arrowRecordBatch = arrow.tabular.RecordBatch(batchProxy);
        end

        function arrowTable = readTable(obj)
            % Read record batches from the stream into one arrow.tabular.Table.
            id = obj.Proxy.readTable();
            tableProxy = libmexclass.proxy.Proxy(ID=id, ...
                Name="arrow.tabular.proxy.Table");
            arrowTable = arrow.tabular.Table(tableProxy);
        end

    end

end

0 commit comments

Comments
 (0)