|
5 | 5 | import json
|
6 | 6 | from collections.abc import Iterable, Sized
|
7 | 7 | from functools import lru_cache
|
8 |
| -from types import ModuleType |
9 | 8 |
|
10 |
| -from packaging.version import parse as parse_version |
| 9 | +import pyarrow |
11 | 10 |
|
12 | 11 | import awkward as ak
|
13 | 12 | from awkward._backends.numpy import NumpyBackend
|
14 | 13 | from awkward._nplikes.numpy import Numpy
|
15 | 14 | from awkward._nplikes.numpy_like import NumpyMetadata
|
16 | 15 | from awkward._parameters import parameters_union
|
17 | 16 |
|
| 17 | +from .extn_types import AwkwardArrowType, to_awkwardarrow_storage_types |
| 18 | + |
18 | 19 | np = NumpyMetadata.instance()
|
19 | 20 | numpy = Numpy.instance()
|
20 | 21 |
|
21 |
| -try: |
22 |
| - import pyarrow |
23 |
| - |
24 |
| - error_message = None |
25 |
| - |
26 |
| -except ModuleNotFoundError: |
27 |
| - pyarrow = None |
28 |
| - error_message = """to use {0}, you must install pyarrow: |
29 |
| -
|
30 |
| - pip install pyarrow |
31 |
| -
|
32 |
| -or |
33 |
| -
|
34 |
| - conda install -c conda-forge pyarrow |
35 |
| -""" |
36 |
| - |
37 |
| -else: |
38 |
| - if parse_version(pyarrow.__version__) < parse_version("7.0.0"): |
39 |
| - pyarrow = None |
40 |
| - error_message = "pyarrow 7.0.0 or later required for {0}" |
41 |
| - |
42 |
| - |
43 |
| -def import_pyarrow(name: str) -> ModuleType: |
44 |
| - if pyarrow is None: |
45 |
| - raise ImportError(error_message.format(name)) |
46 |
| - return pyarrow |
47 |
| - |
48 |
| - |
49 |
| -def import_pyarrow_parquet(name: str) -> ModuleType: |
50 |
| - if pyarrow is None: |
51 |
| - raise ImportError(error_message.format(name)) |
52 |
| - |
53 |
| - import pyarrow.parquet as out |
54 |
| - |
55 |
| - return out |
56 |
| - |
57 |
| - |
58 |
| -def import_pyarrow_compute(name: str) -> ModuleType: |
59 |
| - if pyarrow is None: |
60 |
| - raise ImportError(error_message.format(name)) |
61 |
| - |
62 |
| - import pyarrow.compute as out |
63 |
| - |
64 |
| - return out |
65 |
| - |
66 |
| - |
67 |
| -if pyarrow is not None: |
68 |
| - |
69 |
| - class AwkwardArrowArray(pyarrow.ExtensionArray): |
70 |
| - def to_pylist(self): |
71 |
| - out = super().to_pylist() |
72 |
| - if ( |
73 |
| - isinstance(self.type, AwkwardArrowType) |
74 |
| - and self.type.node_type == "RecordArray" |
75 |
| - and self.type.record_is_tuple is True |
76 |
| - ): |
77 |
| - for i, x in enumerate(out): |
78 |
| - if x is not None: |
79 |
| - out[i] = tuple(x[str(j)] for j in range(len(x))) |
80 |
| - return out |
81 |
| - |
82 |
| - class AwkwardArrowType(pyarrow.ExtensionType): |
83 |
| - def __init__( |
84 |
| - self, |
85 |
| - storage_type, |
86 |
| - mask_type, |
87 |
| - node_type, |
88 |
| - mask_parameters, |
89 |
| - node_parameters, |
90 |
| - record_is_tuple, |
91 |
| - record_is_scalar, |
92 |
| - is_nonnullable_nulltype=False, |
93 |
| - ): |
94 |
| - self._mask_type = mask_type |
95 |
| - self._node_type = node_type |
96 |
| - self._mask_parameters = mask_parameters |
97 |
| - self._node_parameters = node_parameters |
98 |
| - self._record_is_tuple = record_is_tuple |
99 |
| - self._record_is_scalar = record_is_scalar |
100 |
| - self._is_nonnullable_nulltype = is_nonnullable_nulltype |
101 |
| - super().__init__(storage_type, "awkward") |
102 |
| - |
103 |
| - def __str__(self): |
104 |
| - return "ak:" + str(self.storage_type) |
105 |
| - |
106 |
| - def __repr__(self): |
107 |
| - return f"awkward<{self.storage_type!r}>" |
108 |
| - |
109 |
| - @property |
110 |
| - def mask_type(self): |
111 |
| - return self._mask_type |
112 |
| - |
113 |
| - @property |
114 |
| - def node_type(self): |
115 |
| - return self._node_type |
116 |
| - |
117 |
| - @property |
118 |
| - def mask_parameters(self): |
119 |
| - return self._mask_parameters |
120 |
| - |
121 |
| - @property |
122 |
| - def node_parameters(self): |
123 |
| - return self._node_parameters |
124 |
| - |
125 |
| - @property |
126 |
| - def record_is_tuple(self): |
127 |
| - return self._record_is_tuple |
128 |
| - |
129 |
| - @property |
130 |
| - def record_is_scalar(self): |
131 |
| - return self._record_is_scalar |
132 |
| - |
133 |
| - def __arrow_ext_class__(self): |
134 |
| - return AwkwardArrowArray |
135 |
| - |
136 |
| - def __arrow_ext_serialize__(self): |
137 |
| - return json.dumps( |
138 |
| - { |
139 |
| - "mask_type": self._mask_type, |
140 |
| - "node_type": self._node_type, |
141 |
| - "mask_parameters": self._mask_parameters, |
142 |
| - "node_parameters": self._node_parameters, |
143 |
| - "record_is_tuple": self._record_is_tuple, |
144 |
| - "record_is_scalar": self._record_is_scalar, |
145 |
| - "is_nonnullable_nulltype": self._is_nonnullable_nulltype, |
146 |
| - } |
147 |
| - ).encode(errors="surrogatescape") |
148 |
| - |
149 |
| - @classmethod |
150 |
| - def __arrow_ext_deserialize__(cls, storage_type, serialized): |
151 |
| - metadata = json.loads(serialized.decode(errors="surrogatescape")) |
152 |
| - return cls( |
153 |
| - storage_type, |
154 |
| - metadata["mask_type"], |
155 |
| - metadata["node_type"], |
156 |
| - metadata["mask_parameters"], |
157 |
| - metadata["node_parameters"], |
158 |
| - metadata["record_is_tuple"], |
159 |
| - metadata["record_is_scalar"], |
160 |
| - is_nonnullable_nulltype=metadata.get("is_nonnullable_nulltype", False), |
161 |
| - ) |
162 |
| - |
163 |
| - @property |
164 |
| - def num_buffers(self): |
165 |
| - return self.storage_type.num_buffers |
166 |
| - |
167 |
| - @property |
168 |
| - def num_fields(self): |
169 |
| - return self.storage_type.num_fields |
170 |
| - |
171 |
| - pyarrow.register_extension_type( |
172 |
| - AwkwardArrowType(pyarrow.null(), None, None, None, None, None, None) |
173 |
| - ) |
174 |
| - |
175 |
| - # order is important; _string_like[:2] vs _string_like[::2] |
176 |
| - _string_like = ( |
177 |
| - pyarrow.string(), |
178 |
| - pyarrow.large_string(), |
179 |
| - pyarrow.binary(), |
180 |
| - pyarrow.large_binary(), |
181 |
| - ) |
182 |
| - |
183 |
| - _pyarrow_to_numpy_dtype = { |
184 |
| - pyarrow.date32(): (True, np.dtype("M8[D]")), |
185 |
| - pyarrow.date64(): (False, np.dtype("M8[ms]")), |
186 |
| - pyarrow.time32("s"): (True, np.dtype("M8[s]")), |
187 |
| - pyarrow.time32("ms"): (True, np.dtype("M8[ms]")), |
188 |
| - pyarrow.time64("us"): (False, np.dtype("M8[us]")), |
189 |
| - pyarrow.time64("ns"): (False, np.dtype("M8[ns]")), |
190 |
| - pyarrow.timestamp("s"): (False, np.dtype("M8[s]")), |
191 |
| - pyarrow.timestamp("ms"): (False, np.dtype("M8[ms]")), |
192 |
| - pyarrow.timestamp("us"): (False, np.dtype("M8[us]")), |
193 |
| - pyarrow.timestamp("ns"): (False, np.dtype("M8[ns]")), |
194 |
| - pyarrow.duration("s"): (False, np.dtype("m8[s]")), |
195 |
| - pyarrow.duration("ms"): (False, np.dtype("m8[ms]")), |
196 |
| - pyarrow.duration("us"): (False, np.dtype("m8[us]")), |
197 |
| - pyarrow.duration("ns"): (False, np.dtype("m8[ns]")), |
198 |
| - } |
199 |
| - |
200 | 22 |
|
201 | 23 | def and_validbytes(validbytes1, validbytes2):
|
202 | 24 | if validbytes1 is None:
|
@@ -230,13 +52,6 @@ def to_null_count(validbytes, count_nulls):
|
230 | 52 | return len(validbytes) - numpy.count_nonzero(validbytes)
|
231 | 53 |
|
232 | 54 |
|
233 |
| -def to_awkwardarrow_storage_types(arrowtype): |
234 |
| - if isinstance(arrowtype, AwkwardArrowType): |
235 |
| - return arrowtype, arrowtype.storage_type |
236 |
| - else: |
237 |
| - return None, arrowtype |
238 |
| - |
239 |
| - |
240 | 55 | def node_parameters(awkwardarrow_type):
|
241 | 56 | if isinstance(awkwardarrow_type, AwkwardArrowType):
|
242 | 57 | return awkwardarrow_type.node_parameters
|
@@ -1161,3 +976,29 @@ def convert_to_array(layout, type=None):
|
1161 | 976 | return out
|
1162 | 977 | else:
|
1163 | 978 | return pyarrow.array(out, type=type)
|
| 979 | + |
| 980 | + |
| 981 | +# order is important; _string_like[:2] vs _string_like[::2] |
| 982 | +_string_like = ( |
| 983 | + pyarrow.string(), |
| 984 | + pyarrow.large_string(), |
| 985 | + pyarrow.binary(), |
| 986 | + pyarrow.large_binary(), |
| 987 | +) |
| 988 | + |
| 989 | +_pyarrow_to_numpy_dtype = { |
| 990 | + pyarrow.date32(): (True, np.dtype("M8[D]")), |
| 991 | + pyarrow.date64(): (False, np.dtype("M8[ms]")), |
| 992 | + pyarrow.time32("s"): (True, np.dtype("M8[s]")), |
| 993 | + pyarrow.time32("ms"): (True, np.dtype("M8[ms]")), |
| 994 | + pyarrow.time64("us"): (False, np.dtype("M8[us]")), |
| 995 | + pyarrow.time64("ns"): (False, np.dtype("M8[ns]")), |
| 996 | + pyarrow.timestamp("s"): (False, np.dtype("M8[s]")), |
| 997 | + pyarrow.timestamp("ms"): (False, np.dtype("M8[ms]")), |
| 998 | + pyarrow.timestamp("us"): (False, np.dtype("M8[us]")), |
| 999 | + pyarrow.timestamp("ns"): (False, np.dtype("M8[ns]")), |
| 1000 | + pyarrow.duration("s"): (False, np.dtype("m8[s]")), |
| 1001 | + pyarrow.duration("ms"): (False, np.dtype("m8[ms]")), |
| 1002 | + pyarrow.duration("us"): (False, np.dtype("m8[us]")), |
| 1003 | + pyarrow.duration("ns"): (False, np.dtype("m8[ns]")), |
| 1004 | +} |
0 commit comments