Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [
"Topic :: Software Development",
]
dependencies = [
"awkward >=2.5.1",
"awkward @ git+https://github.com/scikit-hep/awkward.git",
"dask >=2023.04.0,<2025.4.0",
"cachetools",
"typing_extensions >=4.8.0",
Expand Down Expand Up @@ -86,6 +86,9 @@ awkward = "dask_awkward_sizeof:register"
[tool.hatch.build.targets.wheel]
packages = ["src/dask_awkward", "src/dask_awkward_sizeof"]

[tool.hatch.metadata]
allow-direct-references = true

[project.entry-points."awkward.pickle.reduce"]
dask_awkward = "dask_awkward.pickle:plugin"

Expand Down
1 change: 1 addition & 0 deletions src/dask_awkward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
combinations,
copy,
drop_none,
enforce_type,
fill_none,
firsts,
flatten,
Expand Down
1 change: 1 addition & 0 deletions src/dask_awkward/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
combinations,
copy,
drop_none,
enforce_type,
fill_none,
firsts,
flatten,
Expand Down
45 changes: 23 additions & 22 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"combinations",
"copy",
"drop_none",
"enforce_type",
"fill_none",
"firsts",
"flatten",
Expand Down Expand Up @@ -1349,6 +1350,28 @@ def zip(
)


@borrow_docstring(ak.enforce_type)
def enforce_type(
array: Array,
type: str | dict | Type,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")

return map_partitions(
ak.enforce_type,
array,
label="enforce-type",
type=type,
behavior=behavior,
attrs=attrs,
output_divisions=1,
)


def _repartition_func(*stuff):
import builtins

Expand Down Expand Up @@ -1443,25 +1466,3 @@ def simple_repartition_layer(
else:
raise ValueError
return layer, new_divisions


@borrow_docstring(ak.enforce_type)
def enforce_type(
array: Array,
type: str | dict | Type,
highlevel: bool = True,
behavior: Mapping | None = None,
attrs: Mapping[str, Any] | None = None,
) -> Array:
if not highlevel:
raise ValueError("Only highlevel=True is supported")

return map_partitions(
ak.enforce_type,
array,
label="enforce-type",
type=type,
behavior=behavior,
attrs=attrs,
output_divisions=1,
)
254 changes: 254 additions & 0 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,257 @@ def test_repartition_uneven(daa):
assert daa1.npartitions == 4
out = daa1.compute()
assert out.tolist() == daa.compute()[:12].tolist()


def test_enforce_type_record():
a = [{"x": [1, 2]}, {"x": [3, 4]}]
caa = ak.Array(a)
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "{x: var * float64}"),
ak.enforce_type(caa, "{x: var * float64}"),
)

assert_eq(
dak.enforce_type(daa, "{x: var * int64, y: ?int64}"),
ak.enforce_type(caa, "{x: var * int64, y: ?int64}"),
)

assert_eq(
dak.enforce_type(daa, "{y: ?var * float64}"),
ak.enforce_type(caa, "{y: ?var * float64}"),
)

assert_eq(
dak.enforce_type(daa, "{}"),
ak.enforce_type(caa, "{}"),
)

with pytest.raises(ValueError, match=r"converted between records and tuples"):
dak.enforce_type(daa, "(var * float64)").compute()

with pytest.raises(
TypeError, match=r"can only add new fields to a record if they are option types"
):
dak.enforce_type(daa, "{y: var * float64}").compute()


def test_enforce_type_tuple():
a = [([1, 2],), ([3, 4],)]
caa = ak.Array(a)
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "(var * bool)"),
ak.enforce_type(caa, "(var * bool)"),
)

assert_eq(
dak.enforce_type(daa, "(var * int64, ?float32)"),
ak.enforce_type(caa, "(var * int64, ?float32)"),
)

assert_eq(
dak.enforce_type(daa, "()"),
ak.enforce_type(caa, "()"),
)

with pytest.raises(ValueError, match=r"converted between records and tuples"):
dak.enforce_type(daa, "{x: var * float64}").compute()

with pytest.raises(
TypeError, match=r"can only add new slots to a tuple if they are option types"
):
dak.enforce_type(daa, "(var * int64, float32)").compute()


def test_enforce_type_list():
a = [[1, 2, 3], [4, 5]]
caa = ak.Array(a)
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "var * float64"),
ak.enforce_type(caa, "var * float64"),
)

a_regular = [[1, 2, 3], [4, 5, 6]]
caa_regular = ak.Array(a_regular)
daa_regular = dak.from_awkward(caa_regular, npartitions=1)
assert_eq(
dak.enforce_type(daa_regular, "3 * int64"),
ak.enforce_type(caa_regular, "3 * int64"),
)

with pytest.raises(ValueError, match=r"different size"):
dak.enforce_type(daa_regular, "4 * int64").compute()

caa_reg = ak.to_regular(ak.Array(a_regular), axis=-1)
daa_reg = dak.from_awkward(caa_reg, npartitions=1)
assert_eq(
dak.enforce_type(daa_reg, "var * int64"),
ak.enforce_type(caa_reg, "var * int64"),
)


def test_enforce_type_option():
a = [1, None, 2, 3]
caa = ak.Array(a)
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "?float64"),
ak.enforce_type(caa, "?float64"),
)

caa_no_none = ak.Array([1, None, 2, 3])[:1]
daa_no_none = dak.from_awkward(caa_no_none, npartitions=1)
assert_eq(
dak.enforce_type(daa_no_none, "int64"),
ak.enforce_type(caa_no_none, "int64"),
)

with pytest.raises(ValueError, match=r"if there are no missing values"):
dak.enforce_type(daa, "int64").compute()

a_no_option = [1, 2, 3, 4]
caa_no_option = ak.Array(a_no_option)
daa_no_option = dak.from_awkward(caa_no_option, npartitions=1)
assert_eq(
dak.enforce_type(daa_no_option, "?int64"),
ak.enforce_type(caa_no_option, "?int64"),
)

assert_eq(
dak.enforce_type(daa, "?unknown"),
ak.enforce_type(caa, "?unknown"),
)


def test_enforce_type_numpy():
a = [1, 2, 3, 4]
caa = ak.Array(a)
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "float32"),
ak.enforce_type(caa, "float32"),
)

with pytest.raises(TypeError):
dak.enforce_type(daa, "string").compute()

with pytest.raises(TypeError):
dak.enforce_type(daa, "var * int64").compute()

with pytest.raises(TypeError):
dak.enforce_type(daa, "2 * float32").compute()

a_2d = np.zeros((2, 3))
caa_2d = ak.from_numpy(a_2d, regulararray=False)
daa_2d = dak.from_awkward(caa_2d, npartitions=1)
assert_eq(
dak.enforce_type(daa_2d, "var * int64"),
ak.enforce_type(caa_2d, "var * int64"),
)

with pytest.raises(TypeError):
dak.enforce_type(daa_2d, "int64").compute()

with pytest.raises(TypeError):
dak.enforce_type(daa_2d, "float32").compute()


def test_enforce_type_union():
a = [1, 2, 3, 4]
caa = ak.Array(a)
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "union[int64, string]"),
ak.enforce_type(caa, "union[int64, string]"),
)

a_union = [1, "hi", "bye"]
caa_union = ak.Array(a_union)[1:2]
daa_union = dak.from_awkward(caa_union, npartitions=1)
assert_eq(
dak.enforce_type(daa_union, "string"),
ak.enforce_type(caa_union, "string"),
)

caa_union_full = ak.Array(a_union)
daa_union_full = dak.from_awkward(caa_union_full, npartitions=1)
assert_eq(
dak.enforce_type(daa_union_full, "union[int64, string]"),
ak.enforce_type(caa_union_full, "union[int64, string]"),
)

assert_eq(
dak.enforce_type(daa_union_full, "union[int64, string, datetime64]"),
ak.enforce_type(caa_union_full, "union[int64, string, datetime64]"),
)

assert_eq(
dak.enforce_type(daa_union_full, "union[float32, string]"),
ak.enforce_type(caa_union_full, "union[float32, string]"),
)

caa_union_3 = ak.Array([1, "hi", [1j, 2j]])[:2]
daa_union_3 = dak.from_awkward(caa_union_3, npartitions=1)
assert_eq(
dak.enforce_type(daa_union_3, "union[int64, string]"),
ak.enforce_type(caa_union_3, "union[int64, string]"),
)

with pytest.raises(TypeError):
dak.enforce_type(daa_union_full, "union[int64, bool]").compute()

with pytest.raises(ValueError):
dak.enforce_type(daa_union_full, "string").compute()

caa_union_bool = ak.Array([1, "hi", False])
daa_union_bool = dak.from_awkward(caa_union_bool, npartitions=1)
with pytest.raises(TypeError):
dak.enforce_type(daa_union_bool, "union[datetime64, string, float32]").compute()

caa_int = ak.Array([1, 2])
daa_int = dak.from_awkward(caa_int, npartitions=1)
with pytest.raises(TypeError):
dak.enforce_type(daa_int, "union[var * int64, string]").compute()


def test_enforce_type_string():
a = ["hello world", "foo"]
caa = ak.Array(a)
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "bytes"),
ak.enforce_type(caa, "bytes"),
)

a_bytes = [b"hello world", b"foo"]
caa_bytes = ak.Array(a_bytes)
daa_bytes = dak.from_awkward(caa_bytes, npartitions=1)
assert_eq(
dak.enforce_type(daa_bytes, "string"),
ak.enforce_type(caa_bytes, "string"),
)

assert_eq(
dak.enforce_type(daa_bytes, "var * int64"),
ak.enforce_type(caa_bytes, "var * int64"),
)


def test_enforce_type_nested():
a = [{"x": [1, 2, None]}, None, {"x": [3, 4, None]}]
caa = ak.Array(a)[[0, 2], :2]
daa = dak.from_awkward(caa, npartitions=1)
assert_eq(
dak.enforce_type(daa, "?{x: var * ?float32}"),
ak.enforce_type(caa, "?{x: var * ?float32}"),
)

a_union = [[1, "hi", "bye"], None]
caa_union = ak.Array(a_union)[[0, 1], 1:2]
daa_union = dak.from_awkward(caa_union, npartitions=1)
assert_eq(
dak.enforce_type(daa_union, "?var * string"),
ak.enforce_type(caa_union, "?var * string"),
)
Loading