Skip to content

Commit ef2e08f

Browse files
tcawlfieldjpivarskipre-commit-ci[bot]
authored
feat: 2772 pyarrow doesn't permit selective reading with extensionarray (#3127)
With this commit, when writing Parquet files, we automatically convert Awkward Arrow extension types in pyarrow Tables to native ones, while adding the metadata needed to reconstruct the extension types to the table (schema) metadata. This new behavior is exercised when passing `extensionarray=True` to methods `awkward.to_parquet()` and similar. When reading Parquet files, if our Table metadata exists, we use it to convert from native pyarrow types back to AwkwardExtension after reading. Squashed commits: * POC for 2772, outbound table conversion * Changing table-wide metadata Each object now includes "field_name". Each composite object, including the table itself, is represented as a list of objects. This feels like it has more consistency and robustness. * Adding convert_native_arrow_table_to_awkward This does not work yet. Table.cast() seems to fail me here. But the schema generation appears to work so far. * Various fixes to table conversions * Replacing Table.cast() with custom function :replace_schema() * Replacing dictionary of lambdas with function having elif isinstance chain * This is currently very poorly tested * Ruff formatting * Improvements to pyarrow_table_conv but still issues Handling all the pyarrow types that we use now, but there are still errors converting native DictionaryType arrays to awkward-extension-types. * Fixing bug in array_with_replacement_type Turns out you need AwkwardArrowArray.from_storage to create extension-type dictionary arrays. Strange but I'm evidently not the first poor soul to bump into this. * Adding unit testing, fixing a couple bugs The unit tests do not yet cover Parquet file operations, which will likely be in the next commit. * Adding hooks to parquet read & write Also expanded test_2772 a bit, trying to reproduce errors from test_1440. But instead of reproducing these, I found new errors. Ugh. Checking this in because it's just where I'm at right now. 
* Ruff-fmt fixes * Making progress Added new test for actually doing selective read. This required changing the top-level metadata from a list to a JSON object. Also fixed a bug: when converting a table to native, keep any existing table metadata entries. * Fixing another bug: convert each row group when writing Fixes failures in test 2968 * pyarrow_table_conv -- change our new table metadata key name Co-authored-by: Jim Pivarski <[email protected]> * Some stylistic improvements * Commented-out a messy assertion in test_2772 * style: pre-commit fixes * Moving awkward._connect.pyarrow into a package * Restructuring ._connect.pyarrow package This makes things more natural to use from outside, with or without pyarrow installed. * Fixing unused imports and other Ruffage * Fixing Ruffage, this time for sure * Fixes for old versions of pyarrow * Small fixes Moving to_awkwardarrow_storage_types from .conversions to .extn_types. * Adding BSD licenses, moving a commented test The failing test is being moved to a new file, to be added later. --------- Co-authored-by: Jim Pivarski <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f16c0e7 commit ef2e08f

File tree

7 files changed

+768
-193
lines changed

7 files changed

+768
-193
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
from types import ModuleType
6+
7+
from packaging.version import parse as parse_version
8+
9+
__all__ = [
10+
"import_pyarrow",
11+
"import_pyarrow_parquet",
12+
"import_pyarrow_compute",
13+
"AwkwardArrowArray",
14+
"AwkwardArrowType",
15+
"and_validbytes",
16+
"convert_to_array",
17+
"direct_Content_subclass",
18+
"direct_Content_subclass_name",
19+
"form_handle_arrow",
20+
"handle_arrow",
21+
"popbuffers",
22+
"remove_optiontype",
23+
"to_awkwardarrow_storage_types",
24+
"to_awkwardarrow_type",
25+
"to_length",
26+
"to_null_count",
27+
"to_validbits",
28+
"convert_awkward_arrow_table_to_native",
29+
"convert_native_arrow_table_to_awkward",
30+
]
31+
32+
try:
33+
import pyarrow
34+
35+
error_message = None
36+
37+
except ModuleNotFoundError:
38+
pyarrow = None
39+
error_message = """to use {0}, you must install pyarrow:
40+
41+
pip install pyarrow
42+
43+
or
44+
45+
conda install -c conda-forge pyarrow
46+
"""
47+
48+
else:
49+
if parse_version(pyarrow.__version__) < parse_version("7.0.0"):
50+
pyarrow = None
51+
error_message = "pyarrow 7.0.0 or later required for {0}"
52+
53+
if error_message is None:
54+
from .conversions import (
55+
and_validbytes,
56+
convert_to_array,
57+
direct_Content_subclass,
58+
direct_Content_subclass_name,
59+
form_handle_arrow,
60+
handle_arrow,
61+
popbuffers,
62+
remove_optiontype,
63+
to_awkwardarrow_type,
64+
to_length,
65+
to_null_count,
66+
to_validbits,
67+
)
68+
from .extn_types import (
69+
AwkwardArrowArray,
70+
AwkwardArrowType,
71+
to_awkwardarrow_storage_types,
72+
)
73+
from .table_conv import (
74+
convert_awkward_arrow_table_to_native,
75+
convert_native_arrow_table_to_awkward,
76+
)
77+
else:
78+
AwkwardArrowArray = None
79+
AwkwardArrowType = None
80+
81+
def nothing_without_pyarrow(*args, **kwargs):
    """Placeholder bound to the pyarrow-dependent names when pyarrow is unusable.

    Accepts and ignores any positional or keyword arguments so that it can
    stand in for any of the conversion helpers, then unconditionally raises.

    Raises:
        NotImplementedError: always.
    """
    # This fallback is installed whenever ``error_message`` is set, which
    # happens both when pyarrow is missing and when the installed pyarrow is
    # older than 7.0.0, so the message must not claim it is "not installed".
    raise NotImplementedError(
        "This function requires pyarrow 7.0.0 or later, which is not available."
    )
85+
86+
convert_awkward_arrow_table_to_native = nothing_without_pyarrow
87+
convert_native_arrow_table_to_awkward = nothing_without_pyarrow
88+
and_validbytes = nothing_without_pyarrow
89+
to_validbits = nothing_without_pyarrow
90+
to_length = nothing_without_pyarrow
91+
to_null_count = nothing_without_pyarrow
92+
to_awkwardarrow_storage_types = nothing_without_pyarrow
93+
popbuffers = nothing_without_pyarrow
94+
handle_arrow = nothing_without_pyarrow
95+
convert_to_array = nothing_without_pyarrow
96+
to_awkwardarrow_type = nothing_without_pyarrow
97+
direct_Content_subclass = nothing_without_pyarrow
98+
direct_Content_subclass_name = nothing_without_pyarrow
99+
remove_optiontype = nothing_without_pyarrow
100+
form_handle_arrow = nothing_without_pyarrow
101+
102+
103+
def import_pyarrow(name: str) -> ModuleType:
    """Return the module-level ``pyarrow`` object on behalf of feature *name*.

    Raises:
        ImportError: if the module-level ``pyarrow`` is ``None``; the text is
            the module-level ``error_message`` formatted with *name*.
    """
    if pyarrow is not None:
        return pyarrow
    raise ImportError(error_message.format(name))
107+
108+
109+
def import_pyarrow_parquet(name: str) -> ModuleType:
    """Return the ``pyarrow.parquet`` submodule on behalf of feature *name*.

    Raises:
        ImportError: if the module-level ``pyarrow`` is ``None``; the text is
            the module-level ``error_message`` formatted with *name*.
    """
    if pyarrow is not None:
        # Imported at call time, not at module import; the ``as`` form keeps
        # the name ``pyarrow`` from becoming a local of this function.
        import pyarrow.parquet as parquet_module

        return parquet_module
    raise ImportError(error_message.format(name))
116+
117+
118+
def import_pyarrow_compute(name: str) -> ModuleType:
    """Return the ``pyarrow.compute`` submodule on behalf of feature *name*.

    Raises:
        ImportError: if the module-level ``pyarrow`` is ``None``; the text is
            the module-level ``error_message`` formatted with *name*.
    """
    if pyarrow is not None:
        # Imported at call time, not at module import; the ``as`` form keeps
        # the name ``pyarrow`` from becoming a local of this function.
        import pyarrow.compute as compute_module

        return compute_module
    raise ImportError(error_message.format(name))

src/awkward/_connect/pyarrow.py renamed to src/awkward/_connect/pyarrow/conversions.py

Lines changed: 29 additions & 188 deletions
Original file line numberDiff line numberDiff line change
@@ -5,198 +5,20 @@
55
import json
66
from collections.abc import Iterable, Sized
77
from functools import lru_cache
8-
from types import ModuleType
98

10-
from packaging.version import parse as parse_version
9+
import pyarrow
1110

1211
import awkward as ak
1312
from awkward._backends.numpy import NumpyBackend
1413
from awkward._nplikes.numpy import Numpy
1514
from awkward._nplikes.numpy_like import NumpyMetadata
1615
from awkward._parameters import parameters_union
1716

17+
from .extn_types import AwkwardArrowType, to_awkwardarrow_storage_types
18+
1819
np = NumpyMetadata.instance()
1920
numpy = Numpy.instance()
2021

21-
try:
22-
import pyarrow
23-
24-
error_message = None
25-
26-
except ModuleNotFoundError:
27-
pyarrow = None
28-
error_message = """to use {0}, you must install pyarrow:
29-
30-
pip install pyarrow
31-
32-
or
33-
34-
conda install -c conda-forge pyarrow
35-
"""
36-
37-
else:
38-
if parse_version(pyarrow.__version__) < parse_version("7.0.0"):
39-
pyarrow = None
40-
error_message = "pyarrow 7.0.0 or later required for {0}"
41-
42-
43-
def import_pyarrow(name: str) -> ModuleType:
44-
if pyarrow is None:
45-
raise ImportError(error_message.format(name))
46-
return pyarrow
47-
48-
49-
def import_pyarrow_parquet(name: str) -> ModuleType:
50-
if pyarrow is None:
51-
raise ImportError(error_message.format(name))
52-
53-
import pyarrow.parquet as out
54-
55-
return out
56-
57-
58-
def import_pyarrow_compute(name: str) -> ModuleType:
59-
if pyarrow is None:
60-
raise ImportError(error_message.format(name))
61-
62-
import pyarrow.compute as out
63-
64-
return out
65-
66-
67-
if pyarrow is not None:
68-
69-
class AwkwardArrowArray(pyarrow.ExtensionArray):
70-
def to_pylist(self):
71-
out = super().to_pylist()
72-
if (
73-
isinstance(self.type, AwkwardArrowType)
74-
and self.type.node_type == "RecordArray"
75-
and self.type.record_is_tuple is True
76-
):
77-
for i, x in enumerate(out):
78-
if x is not None:
79-
out[i] = tuple(x[str(j)] for j in range(len(x)))
80-
return out
81-
82-
class AwkwardArrowType(pyarrow.ExtensionType):
83-
def __init__(
84-
self,
85-
storage_type,
86-
mask_type,
87-
node_type,
88-
mask_parameters,
89-
node_parameters,
90-
record_is_tuple,
91-
record_is_scalar,
92-
is_nonnullable_nulltype=False,
93-
):
94-
self._mask_type = mask_type
95-
self._node_type = node_type
96-
self._mask_parameters = mask_parameters
97-
self._node_parameters = node_parameters
98-
self._record_is_tuple = record_is_tuple
99-
self._record_is_scalar = record_is_scalar
100-
self._is_nonnullable_nulltype = is_nonnullable_nulltype
101-
super().__init__(storage_type, "awkward")
102-
103-
def __str__(self):
104-
return "ak:" + str(self.storage_type)
105-
106-
def __repr__(self):
107-
return f"awkward<{self.storage_type!r}>"
108-
109-
@property
110-
def mask_type(self):
111-
return self._mask_type
112-
113-
@property
114-
def node_type(self):
115-
return self._node_type
116-
117-
@property
118-
def mask_parameters(self):
119-
return self._mask_parameters
120-
121-
@property
122-
def node_parameters(self):
123-
return self._node_parameters
124-
125-
@property
126-
def record_is_tuple(self):
127-
return self._record_is_tuple
128-
129-
@property
130-
def record_is_scalar(self):
131-
return self._record_is_scalar
132-
133-
def __arrow_ext_class__(self):
134-
return AwkwardArrowArray
135-
136-
def __arrow_ext_serialize__(self):
137-
return json.dumps(
138-
{
139-
"mask_type": self._mask_type,
140-
"node_type": self._node_type,
141-
"mask_parameters": self._mask_parameters,
142-
"node_parameters": self._node_parameters,
143-
"record_is_tuple": self._record_is_tuple,
144-
"record_is_scalar": self._record_is_scalar,
145-
"is_nonnullable_nulltype": self._is_nonnullable_nulltype,
146-
}
147-
).encode(errors="surrogatescape")
148-
149-
@classmethod
150-
def __arrow_ext_deserialize__(cls, storage_type, serialized):
151-
metadata = json.loads(serialized.decode(errors="surrogatescape"))
152-
return cls(
153-
storage_type,
154-
metadata["mask_type"],
155-
metadata["node_type"],
156-
metadata["mask_parameters"],
157-
metadata["node_parameters"],
158-
metadata["record_is_tuple"],
159-
metadata["record_is_scalar"],
160-
is_nonnullable_nulltype=metadata.get("is_nonnullable_nulltype", False),
161-
)
162-
163-
@property
164-
def num_buffers(self):
165-
return self.storage_type.num_buffers
166-
167-
@property
168-
def num_fields(self):
169-
return self.storage_type.num_fields
170-
171-
pyarrow.register_extension_type(
172-
AwkwardArrowType(pyarrow.null(), None, None, None, None, None, None)
173-
)
174-
175-
# order is important; _string_like[:2] vs _string_like[::2]
176-
_string_like = (
177-
pyarrow.string(),
178-
pyarrow.large_string(),
179-
pyarrow.binary(),
180-
pyarrow.large_binary(),
181-
)
182-
183-
_pyarrow_to_numpy_dtype = {
184-
pyarrow.date32(): (True, np.dtype("M8[D]")),
185-
pyarrow.date64(): (False, np.dtype("M8[ms]")),
186-
pyarrow.time32("s"): (True, np.dtype("M8[s]")),
187-
pyarrow.time32("ms"): (True, np.dtype("M8[ms]")),
188-
pyarrow.time64("us"): (False, np.dtype("M8[us]")),
189-
pyarrow.time64("ns"): (False, np.dtype("M8[ns]")),
190-
pyarrow.timestamp("s"): (False, np.dtype("M8[s]")),
191-
pyarrow.timestamp("ms"): (False, np.dtype("M8[ms]")),
192-
pyarrow.timestamp("us"): (False, np.dtype("M8[us]")),
193-
pyarrow.timestamp("ns"): (False, np.dtype("M8[ns]")),
194-
pyarrow.duration("s"): (False, np.dtype("m8[s]")),
195-
pyarrow.duration("ms"): (False, np.dtype("m8[ms]")),
196-
pyarrow.duration("us"): (False, np.dtype("m8[us]")),
197-
pyarrow.duration("ns"): (False, np.dtype("m8[ns]")),
198-
}
199-
20022

20123
def and_validbytes(validbytes1, validbytes2):
20224
if validbytes1 is None:
@@ -230,13 +52,6 @@ def to_null_count(validbytes, count_nulls):
23052
return len(validbytes) - numpy.count_nonzero(validbytes)
23153

23254

233-
def to_awkwardarrow_storage_types(arrowtype):
234-
if isinstance(arrowtype, AwkwardArrowType):
235-
return arrowtype, arrowtype.storage_type
236-
else:
237-
return None, arrowtype
238-
239-
24055
def node_parameters(awkwardarrow_type):
24156
if isinstance(awkwardarrow_type, AwkwardArrowType):
24257
return awkwardarrow_type.node_parameters
@@ -1161,3 +976,29 @@ def convert_to_array(layout, type=None):
1161976
return out
1162977
else:
1163978
return pyarrow.array(out, type=type)
979+
980+
981+
# order is important; _string_like[:2] vs _string_like[::2]
982+
_string_like = (
983+
pyarrow.string(),
984+
pyarrow.large_string(),
985+
pyarrow.binary(),
986+
pyarrow.large_binary(),
987+
)
988+
989+
_pyarrow_to_numpy_dtype = {
990+
pyarrow.date32(): (True, np.dtype("M8[D]")),
991+
pyarrow.date64(): (False, np.dtype("M8[ms]")),
992+
pyarrow.time32("s"): (True, np.dtype("M8[s]")),
993+
pyarrow.time32("ms"): (True, np.dtype("M8[ms]")),
994+
pyarrow.time64("us"): (False, np.dtype("M8[us]")),
995+
pyarrow.time64("ns"): (False, np.dtype("M8[ns]")),
996+
pyarrow.timestamp("s"): (False, np.dtype("M8[s]")),
997+
pyarrow.timestamp("ms"): (False, np.dtype("M8[ms]")),
998+
pyarrow.timestamp("us"): (False, np.dtype("M8[us]")),
999+
pyarrow.timestamp("ns"): (False, np.dtype("M8[ns]")),
1000+
pyarrow.duration("s"): (False, np.dtype("m8[s]")),
1001+
pyarrow.duration("ms"): (False, np.dtype("m8[ms]")),
1002+
pyarrow.duration("us"): (False, np.dtype("m8[us]")),
1003+
pyarrow.duration("ns"): (False, np.dtype("m8[ns]")),
1004+
}

0 commit comments

Comments
 (0)