
Commit 395a316

Add arrow adapter to stream record batches zero-copy into csp
Signed-off-by: Arham Chopra <[email protected]>
1 parent d22137c commit 395a316


5 files changed: +682 -2 lines changed


cpp/csp/python/ArrowInputAdapter.h

Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
#ifndef _IN_CSP_ENGINE_ARROWINPUTADAPTER_H
#define _IN_CSP_ENGINE_ARROWINPUTADAPTER_H

#include <csp/engine/PullInputAdapter.h>
#include <csp/python/PyObjectPtr.h>
#include <csp/python/Conversions.h>
#include <Python.h>

#include <arrow/array.h>
#include <arrow/c/abi.h>
#include <arrow/c/bridge.h>
#include <arrow/type.h>
#include <arrow/table.h>

#include <memory>
#include <string>

namespace csp::python::arrow
{

class RecordBatchIterator
{
public:
    RecordBatchIterator( PyObject * iter, PyObject * py_schema ): m_iter( PyObjectPtr::incref( iter ) )
    {
        // Extract the arrow schema
        struct ArrowSchema * c_schema = reinterpret_cast<struct ArrowSchema*>( PyCapsule_GetPointer( py_schema, "arrow_schema" ) );
        auto result = ::arrow::ImportSchema( c_schema );
        if( !result.ok() )
            CSP_THROW( ValueError, "Failed to load schema for record batches through the PyCapsule C Data interface: " << result.status().ToString() );
        m_schema = std::move(result).ValueUnsafe();
    }

    std::shared_ptr<::arrow::RecordBatch> next()
    {
        auto py_tuple = csp::python::PyObjectPtr::own( PyIter_Next( m_iter.get() ) );
        if( py_tuple.get() == NULL )
        {
            // No more data in the input stream
            return nullptr;
        }
        else
        {
            // Extract the record batch
            PyObject * py_array = PyTuple_GET_ITEM( py_tuple.get(), 1 );
            struct ArrowArray * c_array = reinterpret_cast<struct ArrowArray*>( PyCapsule_GetPointer( py_array, "arrow_array" ) );
            auto result = ::arrow::ImportRecordBatch( c_array, m_schema );
            if( !result.ok() )
                CSP_THROW( ValueError, "Failed to load record batches through PyCapsule C Data interface: " << result.status().ToString() );
            return std::move(result).ValueUnsafe();
        }
    }

private:
    PyObjectPtr m_iter;
    std::shared_ptr<::arrow::Schema> m_schema;
};

void ReleaseArrowSchemaPyCapsule( PyObject * capsule )
{
    struct ArrowSchema * schema = reinterpret_cast<struct ArrowSchema*>( PyCapsule_GetPointer( capsule, "arrow_schema" ) );
    if( schema -> release != NULL )
    {
        schema -> release( schema );
    }
    free( schema );
}

void ReleaseArrowArrayPyCapsule( PyObject * capsule )
{
    struct ArrowArray * array = reinterpret_cast<struct ArrowArray*>( PyCapsule_GetPointer( capsule, "arrow_array" ) );
    if( array -> release != NULL )
    {
        array -> release( array );
    }
    free( array );
}

class RecordBatchInputAdapter: public PullInputAdapter<std::vector<DialectGenericType>>
{
public:
    RecordBatchInputAdapter( Engine * engine, CspTypePtr & type, std::string tsColName, RecordBatchIterator source, int expectSmallBatches )
        : PullInputAdapter<std::vector<DialectGenericType>>( engine, type, PushMode::LAST_VALUE ),
          m_tsColName( tsColName ),
          m_source( source ),
          m_expectSmallBatches( expectSmallBatches != 0 ),
          m_finished( false )
    {
    }

    long long findFirstMatchingIndex( DateTime time )
    {
        // Find the first index with time equal or greater than `time`
        auto m_numRows = m_tsArray -> length();
        auto start_time = ( time.asNanoseconds() % m_multiplier == 0 ) ? time.asNanoseconds()/m_multiplier : time.asNanoseconds()/m_multiplier + 1;

        auto first_time = m_tsArray -> Value( 0 );
        if( first_time >= start_time )
        {
            return 0;
        }

        auto last_time = m_tsArray -> Value( m_numRows - 1 );
        if( last_time < start_time )
        {
            return -1;
        }

        auto l = 0;
        auto r = m_numRows-1;
        auto mid = 0;
        while( l <= r )
        {
            mid = (l + r) / 2;
            auto mid_time = m_tsArray -> Value( mid );
            if( mid_time < start_time )
            {
                auto mid_next_time = m_tsArray -> Value( mid + 1 );
                if( mid_next_time >= start_time )
                {
                    break;
                }
                else
                {
                    l = mid+1;
                }
            }
            else if( mid_time > start_time )
            {
                r = mid - 1;
            }
        }
        return mid+1;
    }


    long long findNextLargerTimestampIndex( long int start_idx )
    {
        // Find the first index with time just greater than the time at start_idx
        long long res = 0;
        auto cur_time = m_tsArray -> Value( start_idx );
        if( m_expectSmallBatches )
        {
            auto idx = start_idx + 1;
            while( idx < m_numRows && m_tsArray -> Value( idx ) == cur_time )
            {
                idx++;
            }
            res = idx;
        }
        else
        {
            auto last_time = m_tsArray -> Value( m_numRows - 1 );
            if( last_time == cur_time )
            {
                return m_numRows;
            }

            auto l = start_idx;
            auto r = m_numRows-1;
            auto mid = 0;
            while( l <= r )
            {
                mid = (l + r) / 2;
                auto mid_time = m_tsArray -> Value( mid );
                if( mid_time == cur_time )
                {
                    auto mid_next_time = m_tsArray -> Value( mid + 1 );
                    if( mid_next_time > cur_time )
                    {
                        break;
                    }
                    else
                    {
                        l = mid+1;
                    }
                }
                else if( mid_time > cur_time )
                {
                    r = mid - 1;
                }
            }
            res = mid+1;
        }
        return res;
    }

    void start( DateTime start, DateTime end ) override
    {
        // Find the starting index where time >= start
        m_endTime = end.asNanoseconds();
        bool reachedStartTime = false;
        while( !reachedStartTime and !m_finished )
        {
            m_curRecordBatch = getNonEmptyRecordBatchFromSource();
            if( !m_curRecordBatch )
            {
                m_finished = true;
                continue;
            }
            auto schema = m_curRecordBatch -> schema();
            auto tsField = schema -> GetFieldByName( m_tsColName );
            auto timestampType = std::static_pointer_cast<::arrow::TimestampType>( tsField -> type() );
            auto array = m_curRecordBatch -> GetColumnByName( m_tsColName );
            if( !array )
            {
                m_finished = true;
                continue;
            }

            m_tsArray = std::static_pointer_cast<::arrow::TimestampArray>( array );
            m_numRows = m_tsArray -> length();

            switch( timestampType -> unit() )
            {
                case ::arrow::TimeUnit::SECOND:
                {
                    m_multiplier = 1000000000;
                    break;
                }
                case ::arrow::TimeUnit::MILLI:
                {
                    m_multiplier = 1000000;
                    break;
                }
                case ::arrow::TimeUnit::MICRO:
                {
                    m_multiplier = 1000;
                    break;
                }
                case ::arrow::TimeUnit::NANO:
                {
                    m_multiplier = 1;
                    break;
                }
                default:
                {
                    CSP_THROW( ValueError, "Unsupported unit type for arrow timestamp column" );
                }
            }
            m_curBatchIdx = findFirstMatchingIndex( start );
            if( m_curBatchIdx >= 0 )
            {
                break;
            }
        }
        PullInputAdapter<std::vector<DialectGenericType>>::start( start, end );
    }

    std::shared_ptr<::arrow::RecordBatch> getNonEmptyRecordBatchFromSource()
    {
        std::shared_ptr<::arrow::RecordBatch> rb;
        while( ( rb = m_source.next() ) && ( rb -> num_rows() == 0 ) ) { continue; }
        return rb;
    }

    DialectGenericType convertRecordBatchToPython( std::shared_ptr<::arrow::RecordBatch> rb )
    {
        struct ArrowSchema* rb_schema = ( struct ArrowSchema* )malloc( sizeof( struct ArrowSchema ) );
        struct ArrowArray* rb_array = ( struct ArrowArray* )malloc( sizeof( struct ArrowArray ) );
        ::arrow::Status st = ::arrow::ExportRecordBatch( *rb, rb_array, rb_schema );
        auto py_schema = csp::python::PyObjectPtr::own( PyCapsule_New( rb_schema, "arrow_schema", ReleaseArrowSchemaPyCapsule ) );
        auto py_array = csp::python::PyObjectPtr::own( PyCapsule_New( rb_array, "arrow_array", ReleaseArrowArrayPyCapsule ) );
        auto py_tuple = csp::python::PyObjectPtr::own( PyTuple_Pack( 2, py_schema.get(), py_array.get() ) );
        return csp::python::fromPython<DialectGenericType>( py_tuple.get() );
    }

    bool next( DateTime & t, std::vector<DialectGenericType> & value ) override
    {
        m_curResult.clear();
        bool newRecordBatch = false;
        while( !m_finished )
        {
            // Slice current record batch
            auto new_ts = m_tsArray -> Value( m_curBatchIdx );
            if( new_ts * m_multiplier > m_endTime )
            {
                // Past the end time
                m_finished = true;
                break;
            }
            if( newRecordBatch && new_ts != m_curTs )
            {
                // Next timestamp encountered, return the current list of record batches
                value = m_curResult;
                m_time = csp::DateTime::fromNanoseconds( m_curTs * m_multiplier );
                t = m_time;
                return true;
            }
            m_curTs = new_ts;
            auto next_idx = findNextLargerTimestampIndex( m_curBatchIdx );
            auto slice = m_curRecordBatch -> Slice( m_curBatchIdx, next_idx - m_curBatchIdx );
            m_curResult.emplace_back( convertRecordBatchToPython( slice ) );
            m_curBatchIdx = next_idx;
            if( m_curBatchIdx != m_numRows )
            {
                // All rows for current timestamp have been found
                value = m_curResult;
                m_time = csp::DateTime::fromNanoseconds( m_curTs * m_multiplier );
                t = m_time;
                return true;
            }
            // Get the next record batch
            m_curRecordBatch = getNonEmptyRecordBatchFromSource();
            if( !m_curRecordBatch )
            {
                m_finished = true;
                break;
            }
            auto array = m_curRecordBatch -> GetColumnByName( m_tsColName );
            m_tsArray = std::static_pointer_cast<::arrow::TimestampArray>( array );
            m_numRows = m_tsArray -> length();
            m_curBatchIdx = 0;
            newRecordBatch = true;
        }
        if( !m_curResult.empty() )
        {
            value = m_curResult;
            m_time = csp::DateTime::fromNanoseconds( m_curTs * m_multiplier );
            t = m_time;
            return true;
        }
        return false;
    }

private:
    std::string m_tsColName;
    RecordBatchIterator m_source;
    int m_expectSmallBatches;
    bool m_finished;
    std::shared_ptr<::arrow::RecordBatch> m_curRecordBatch;
    std::shared_ptr<::arrow::TimestampArray> m_tsArray;
    long int m_multiplier, m_numRows, m_curTs, m_endTime, m_curBatchIdx;
    std::vector<DialectGenericType> m_curResult;
    DateTime m_time;
};

};

#endif
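
For reference, the RecordBatchIterator above consumes a schema PyCapsule named "arrow_schema" together with a Python iterator whose items are tuples carrying an "arrow_array" capsule at index 1, which matches the Arrow PyCapsule protocol implemented by pyarrow. Below is a minimal sketch of producing such inputs, assuming pyarrow >= 14; the variable names are illustrative only, and the csp Python-level wrapper that actually hands these objects to the adapter is not part of this diff.

import datetime
import pyarrow as pa

# Illustrative data: a timestamp column (the adapter's tsColName) plus one value column.
ts = pa.array(
    [datetime.datetime(2024, 1, 1), datetime.datetime(2024, 1, 1), datetime.datetime(2024, 1, 2)],
    type=pa.timestamp("ns"),
)
batches = [pa.RecordBatch.from_pydict({"ts": ts, "price": pa.array([10.0, 10.5, 11.0])})]

# Schema capsule named "arrow_schema", matching PyCapsule_GetPointer(py_schema, "arrow_schema").
schema_capsule = batches[0].schema.__arrow_c_schema__()

# Iterator of (schema_capsule, array_capsule) tuples; the C++ side reads item 1,
# the "arrow_array" capsule, which is what RecordBatch.__arrow_c_array__() returns.
capsule_iter = iter(batch.__arrow_c_array__() for batch in batches)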

cpp/csp/python/CMakeLists.txt

Lines changed: 28 additions & 2 deletions
@@ -24,8 +24,34 @@ target_compile_definitions(csptypesimpl PUBLIC RAPIDJSON_HAS_STDSTRING=1)
target_link_libraries(csptypesimpl csp_core csp_types)
target_compile_definitions(csptypesimpl PRIVATE CSPTYPESIMPL_EXPORTS=1)

+find_package(Arrow REQUIRED)
+find_package(Parquet REQUIRED)
+
+if(WIN32)
+    if(CSP_USE_VCPKG)
+        set(ARROW_PACKAGES_TO_LINK Arrow::arrow_static Parquet::parquet_static )
+        target_compile_definitions(csp_parquet_adapter PUBLIC ARROW_STATIC)
+        target_compile_definitions(csp_parquet_adapter PUBLIC PARQUET_STATIC)
+    else()
+        # use dynamic variants
+        # Until we manage to get the fix for ws3_32.dll in arrow-16 into conda, manually fix the error here
+        get_target_property(LINK_LIBS Arrow::arrow_shared INTERFACE_LINK_LIBRARIES)
+        string(REPLACE "ws2_32.dll" "ws2_32" FIXED_LINK_LIBS "${LINK_LIBS}")
+        set_target_properties(Arrow::arrow_shared PROPERTIES INTERFACE_LINK_LIBRARIES "${FIXED_LINK_LIBS}")
+        set(ARROW_PACKAGES_TO_LINK parquet_shared arrow_shared)
+    endif()
+else()
+    if(CSP_USE_VCPKG)
+        # use static variants
+        set(ARROW_PACKAGES_TO_LINK parquet_static arrow_static)
+    else()
+        # use dynamic variants
+        set(ARROW_PACKAGES_TO_LINK parquet arrow)
+    endif()
+endif()

set(CSPIMPL_PUBLIC_HEADERS
+    ArrowInputAdapter.h
    Common.h
    Conversions.h
    Exception.h

@@ -57,6 +83,7 @@ add_library(cspimpl SHARED
    NumpyConversions.cpp
    PyAdapterManager.cpp
    PyAdapterManagerWrapper.cpp
+    PyArrowInputAdapter.cpp
    PyConstAdapter.cpp
    PyCppNode.cpp
    PyEngine.cpp

@@ -84,12 +111,11 @@ add_library(cspimpl SHARED

set_target_properties(cspimpl PROPERTIES PUBLIC_HEADER "${CSPIMPL_PUBLIC_HEADERS}")

-target_link_libraries(cspimpl csptypesimpl csp_core csp_engine )
+target_link_libraries(cspimpl csptypesimpl csp_core csp_engine ${ARROW_PACKAGES_TO_LINK} )

target_compile_definitions(cspimpl PUBLIC NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION)
target_compile_definitions(cspimpl PRIVATE CSPIMPL_EXPORTS=1)

-
## Baselib c++ module
add_library(cspbaselibimpl SHARED cspbaselibimpl.cpp)
target_link_libraries(cspbaselibimpl cspimpl baselibimpl)