Skip to content

Commit 03f6169

Browse files
tcawlfield and ianna authored
feat: 3154 from_parquet should be able to read partial columns (#3156)
* feat: 3154 - Adding unit test and fix
* Taking this opportunity to fix an issue with the 2772 test. This had been bugging me, and with a better understanding I can now improve the unit tests, mostly for future readability and correctness.
* Removing commented-out test return value

---------

Co-authored-by: Ianna Osborne <[email protected]>
1 parent 6f688e9 commit 03f6169

File tree

3 files changed

+86
-4
lines changed

3 files changed

+86
-4
lines changed

src/awkward/_connect/pyarrow/table_conv.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
13
from __future__ import annotations
24

35
import json
@@ -128,11 +130,34 @@ def native_arrow_field_to_akarraytype(
128130
fields = _fields_of_strg_type(storage_type)
129131
if len(fields) > 0:
130132
# We need to replace storage_type with one that contains AwkwardArrowTypes.
131-
awkwardized_fields = [
132-
native_arrow_field_to_akarraytype(field, meta) # Recurse
133-
for field, meta in zip(fields, metadata["subfield_metadata"])
134-
]
133+
sub_meta = metadata["subfield_metadata"]
134+
awkwardized_fields = None # Temporary
135+
if len(sub_meta) == len(fields):
136+
awkwardized_fields = [
137+
native_arrow_field_to_akarraytype(field, meta) # Recurse
138+
for field, meta in zip(fields, metadata["subfield_metadata"])
139+
]
140+
elif len(fields) < len(sub_meta):
141+
# If a user has read a partial column, we can have fewer Arrow fields than the original.
142+
sub_meta_dict = {sm["field_name"]: sm for sm in sub_meta}
143+
awkwardized_fields = []
144+
for field in fields:
145+
if field.name in sub_meta_dict:
146+
awkwardized_fields.append(
147+
native_arrow_field_to_akarraytype(
148+
field, sub_meta_dict[field.name]
149+
)
150+
)
151+
else:
152+
raise ValueError(
153+
f"Cannot find Awkward metadata for sub-field {field.name}"
154+
)
155+
else:
156+
raise ValueError(
157+
f"Not enough fields in Awkward metadata. Have {len(sub_meta)} need at least {len(fields)}."
158+
)
135159
storage_type = _make_pyarrow_type_like(storage_type, awkwardized_fields)
160+
136161
ak_type = AwkwardArrowType._from_metadata_object(storage_type, metadata)
137162
return pyarrow.field(ntv_field.name, type=ak_type, nullable=ntv_field.nullable)
138163

tests/test_2772_parquet_extn_array_metadata.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import io
67
import os
78

89
import numpy as np
@@ -137,6 +138,22 @@ def test_array_conversions(akarray, as_dict):
137138
rt_array = ak.from_arrow(as_extn, highlevel=True)
138139
assert to_list(rt_array) == to_list(akarray)
139140

141+
# Deeper test of types
142+
akarray_high = ak.Array(akarray)
143+
if akarray_high.type.content.parameters.get("__categorical__", False) == as_dict:
144+
# as_dict is supposed to go hand-in-hand with __categorical__: True, and if it
145+
# does not, we do not round-trip perfectly. So only test when this is set correctly.
146+
assert rt_array.type == akarray_high.type
147+
148+
ak_type_str_orig = io.StringIO()
149+
ak_type_str_rtrp = io.StringIO()
150+
akarray_high.type.show(stream=ak_type_str_orig)
151+
rt_array.type.show(stream=ak_type_str_rtrp)
152+
if ak_type_str_orig.getvalue() != ak_type_str_rtrp.getvalue():
153+
print(" Original type:", ak_type_str_orig.getvalue())
154+
print(" Rnd-trip type:", ak_type_str_rtrp.getvalue())
155+
assert ak_type_str_orig.getvalue() == ak_type_str_rtrp.getvalue()
156+
140157

141158
def test_table_conversion():
142159
ak_tbl_like = ak.Array(
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
# ruff: noqa: E402
3+
4+
from __future__ import annotations
5+
6+
import os
7+
8+
import pytest
9+
10+
import awkward as ak
11+
12+
pa = pytest.importorskip("pyarrow")
13+
pq = pytest.importorskip("pyarrow.parquet")
14+
15+
16+
def test_parquet_subcolumn_select(tmp_path):
    """Read back only part of a record column from a Parquet file.

    Writes an array with a record column ``a`` (fields ``lbl``, ``idx``,
    ``ids``) and a nested-list column ``b``, then reads it back selecting
    only ``a.ids`` and ``b``.  The selected record column must contain just
    the ``ids`` field, and ``b`` must round-trip unchanged.
    """
    ak_tbl = ak.Array(
        {
            "a": [
                {"lbl": "item 1", "idx": 11, "ids": [1, 2, 3]},
                {"lbl": "item 2", "idx": 12, "ids": [51, 52]},
                {"lbl": "item 3", "idx": 13, "ids": [61, 62, 63, 64]},
            ],
            "b": [
                [[111, 112], [121, 122]],
                [[211, 212], [221, 222]],
                [[311, 312], [321, 322]],
            ],
        }
    )
    # Scratch file lives in pytest's per-test temporary directory; it is named
    # after the issue this test covers (3154).
    parquet_file = os.path.join(tmp_path, "test_3154.parquet")
    ak.to_parquet(ak_tbl, parquet_file)

    # Select a single sub-field of "a" (dotted path) plus the whole of "b".
    selection = ak.from_parquet(parquet_file, columns=["a.ids", "b"])

    # Only the requested sub-field of "a" should be present in each record.
    assert selection["a"].to_list() == [
        {"ids": [1, 2, 3]},
        {"ids": [51, 52]},
        {"ids": [61, 62, 63, 64]},
    ]
    # A fully selected column must come back exactly as written.
    assert selection["b"].to_list() == ak_tbl["b"].to_list()

0 commit comments

Comments
 (0)