|
16 | 16 | # under the License.
|
17 | 17 |
|
18 | 18 | import contextlib
|
| 19 | +import functools |
19 | 20 | import os
|
20 | 21 | import subprocess
|
21 | 22 |
|
22 |
| -from .tester import Tester |
| 23 | +from . import cdata |
| 24 | +from .tester import Tester, CDataExporter, CDataImporter |
23 | 25 | from .util import run_cmd, log
|
24 | 26 | from ..utils.source import ARROW_ROOT_DEFAULT
|
25 | 27 |
|
26 | 28 |
|
27 |
| -_EXE_PATH = os.path.join(ARROW_ROOT_DEFAULT, "rust/target/debug") |
# Directory containing the Rust integration executables. The
# ARROW_RUST_EXE_PATH environment variable, when present, overrides the
# default location under ARROW_ROOT_DEFAULT.
_EXE_PATH = os.environ.get(
    "ARROW_RUST_EXE_PATH",
    os.path.join(ARROW_ROOT_DEFAULT, "rust/target/debug"),
)
# Full paths of the individual executables inside _EXE_PATH
_INTEGRATION_EXE = os.path.join(_EXE_PATH, "arrow-json-integration-test")
_STREAM_TO_FILE = os.path.join(_EXE_PATH, "arrow-stream-to-file")
_FILE_TO_STREAM = os.path.join(_EXE_PATH, "arrow-file-to-stream")
|
|
37 | 41 | "localhost",
|
38 | 42 | ]
|
39 | 43 |
|
# Path of the shared library exposing the Rust C Data integration
# entrypoints; the file extension comes from `cdata.dll_suffix`.
_INTEGRATION_DLL = os.path.join(
    _EXE_PATH,
    "libarrow_integration_testing" + cdata.dll_suffix,
)
| 46 | + |
40 | 47 |
|
41 | 48 | class RustTester(Tester):
|
42 | 49 | PRODUCER = True
|
43 | 50 | CONSUMER = True
|
44 | 51 | FLIGHT_SERVER = True
|
45 | 52 | FLIGHT_CLIENT = True
|
| 53 | + C_DATA_SCHEMA_EXPORTER = True |
| 54 | + C_DATA_ARRAY_EXPORTER = True |
| 55 | + C_DATA_SCHEMA_IMPORTER = True |
| 56 | + C_DATA_ARRAY_IMPORTER = True |
46 | 57 |
|
47 | 58 | name = 'Rust'
|
48 | 59 |
|
@@ -117,3 +128,102 @@ def flight_request(self, port, json_path=None, scenario_name=None):
|
117 | 128 | if self.debug:
|
118 | 129 | log(' '.join(cmd))
|
119 | 130 | run_cmd(cmd)
|
| 131 | + |
def make_c_data_exporter(self):
    """
    Return a new `RustCDataExporter` built from this tester's
    `debug` and `args` attributes.
    """
    exporter = RustCDataExporter(self.debug, self.args)
    return exporter
| 134 | + |
def make_c_data_importer(self):
    """
    Return a new `RustCDataImporter` built from this tester's
    `debug` and `args` attributes.
    """
    importer = RustCDataImporter(self.debug, self.args)
    return importer
| 137 | + |
| 138 | + |
# C declarations of the Rust C Data integration entrypoints, passed to
# `ffi.cdef()` by `_load_ffi`. Each `arrow_rs_cdata_integration_*` function
# returns a `const char*`; `arrow_rs_free_error` takes such a pointer back.
_rust_c_data_entrypoints = """
    const char* arrow_rs_cdata_integration_export_schema_from_json(
        const char* json_path, uintptr_t out);
    const char* arrow_rs_cdata_integration_import_schema_and_compare_to_json(
        const char* json_path, uintptr_t c_schema);

    const char* arrow_rs_cdata_integration_export_batch_from_json(
        const char* json_path, int num_batch, uintptr_t out);
    const char* arrow_rs_cdata_integration_import_batch_and_compare_to_json(
        const char* json_path, int num_batch, uintptr_t c_array);

    void arrow_rs_free_error(const char*);
    """
| 152 | + |
| 153 | + |
# maxsize=128 is the default a bare `@functools.lru_cache` would use.
@functools.lru_cache(maxsize=128)
def _load_ffi(ffi, lib_path=_INTEGRATION_DLL):
    """
    Declare the Rust entrypoints on `ffi` and open the integration library.

    The C declarations in `_rust_c_data_entrypoints` are registered with
    `ffi.cdef()`, then the shared library at `lib_path` is opened with
    `ffi.dlopen()` and the resulting library object is returned.

    Results are memoized by `functools.lru_cache` on the call arguments.
    """
    ffi.cdef(_rust_c_data_entrypoints)
    return ffi.dlopen(lib_path)
| 159 | + |
| 160 | + |
class _CDataBase:
    """
    State and helpers shared by the Rust C Data exporter and importer.

    Instances keep the FFI object returned by `cdata.ffi()` and the
    library object returned by `_load_ffi()` for it.
    """

    def __init__(self, debug, args):
        self.debug, self.args = debug, args
        ffi = cdata.ffi()
        self.ffi = ffi
        self.dll = _load_ffi(ffi)

    def _pointer_to_int(self, c_ptr):
        """
        Cast `c_ptr` to `uintptr_t`, the type the Rust entrypoints
        declare for their pointer arguments.
        """
        as_uintptr = self.ffi.cast('uintptr_t', c_ptr)
        return as_uintptr

    def _check_rust_error(self, rs_error):
        """
        Raise if an integration entrypoint reported an error.

        `rs_error` is the `const char*` returned by the entrypoint.
        A null means success, a non-empty string is an error message.
        The string is dynamically allocated on the Rust side, which is
        why it is handed to `arrow_rs_free_error` after being read.

        Raises RuntimeError carrying the decoded message on failure.
        """
        assert self.ffi.typeof(rs_error) is self.ffi.typeof("const char*")
        if rs_error != self.ffi.NULL:
            try:
                raw_message = self.ffi.string(rs_error)
                message = raw_message.decode('utf8', errors='replace')
                raise RuntimeError(
                    f"Rust C Data Integration call failed: {message}")
            finally:
                # Executed on the way out of the raise above, so the
                # string is released even though the call failed.
                self.dll.arrow_rs_free_error(rs_error)
| 188 | + |
| 189 | + |
class RustCDataExporter(CDataExporter, _CDataBase):
    """
    C Data exporter calling the Rust `export_*_from_json` entrypoints.
    """

    def export_schema_from_json(self, json_path, c_schema_ptr):
        """
        Call the Rust schema export entrypoint with `json_path` and the
        address of `c_schema_ptr`; raise RuntimeError if it reports an error.
        """
        func = self.dll.arrow_rs_cdata_integration_export_schema_from_json
        path_bytes = str(json_path).encode()
        schema_addr = self._pointer_to_int(c_schema_ptr)
        outcome = func(path_bytes, schema_addr)
        self._check_rust_error(outcome)

    def export_batch_from_json(self, json_path, num_batch, c_array_ptr):
        """
        Call the Rust batch export entrypoint with `json_path`, `num_batch`
        and the address of `c_array_ptr`; raise RuntimeError if it reports
        an error.
        """
        func = self.dll.arrow_rs_cdata_integration_export_batch_from_json
        path_bytes = str(json_path).encode()
        array_addr = self._pointer_to_int(c_array_ptr)
        outcome = func(path_bytes, num_batch, array_addr)
        self._check_rust_error(outcome)

    @property
    def supports_releasing_memory(self):
        """Always True for this exporter."""
        return True

    def record_allocation_state(self):
        """Return a constant 0; no real measurement is made (see FIXME)."""
        # FIXME is it possible to measure the amount of Rust-allocated memory?
        return 0
| 210 | + |
| 211 | + |
class RustCDataImporter(CDataImporter, _CDataBase):
    """
    C Data importer calling the Rust `import_*_and_compare_to_json`
    entrypoints.
    """

    def import_schema_and_compare_to_json(self, json_path, c_schema_ptr):
        """
        Call the Rust schema import-and-compare entrypoint with `json_path`
        and the address of `c_schema_ptr`; raise RuntimeError if it reports
        an error.
        """
        dll = self.dll
        func = dll.arrow_rs_cdata_integration_import_schema_and_compare_to_json
        path_bytes = str(json_path).encode()
        schema_addr = self._pointer_to_int(c_schema_ptr)
        outcome = func(path_bytes, schema_addr)
        self._check_rust_error(outcome)

    def import_batch_and_compare_to_json(self, json_path, num_batch,
                                         c_array_ptr):
        """
        Call the Rust batch import-and-compare entrypoint with `json_path`,
        `num_batch` and the address of `c_array_ptr`; raise RuntimeError if
        it reports an error.
        """
        dll = self.dll
        func = dll.arrow_rs_cdata_integration_import_batch_and_compare_to_json
        path_bytes = str(json_path).encode()
        array_addr = self._pointer_to_int(c_array_ptr)
        outcome = func(path_bytes, num_batch, array_addr)
        self._check_rust_error(outcome)

    @property
    def supports_releasing_memory(self):
        """Always True for this importer."""
        return True
0 commit comments