Skip to content

Commit db6cece

Browse files
tcawlfieldianna
andauthored
fix: ak.from_numpy should fail on zero-dimensional arrays. (#3161)
* Fixing policy issue 1057 At this point it's a matter of consistency. Now *both* `ak.Array()` and `ak.from_numpy()` will throw a TypeError when passed either a python scalar, a numpy numeric value like `np.float64(5.2)`, **or** a numpy zero-dimensional-array like `np.array(5.2)`. But in the latter case there is an option to allow this, which is necessary for lots of internal functions, by providing a value for `primitive_policy` other than "error." Prior to this patch the first test, ak.Array(), was already passing without modifications. But the second test was not -- ak.from_numpy(). Fix is in _layout.py. No policy options are passed to _layout.from_arraylib(). * Adding/passing a primitive_policy kwarg From from_numpy, from_cupy, from_jax, and from_dlpack To from_arraylib --------- Co-authored-by: Ianna Osborne <[email protected]>
1 parent 03f6169 commit db6cece

File tree

7 files changed

+74
-10
lines changed

7 files changed

+74
-10
lines changed

src/awkward/_layout.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,12 @@ def maybe_highlevel_to_lowlevel(obj):
250250
return obj
251251

252252

253-
def from_arraylib(array, regulararray, recordarray):
253+
def from_arraylib(
254+
array,
255+
regulararray,
256+
recordarray,
257+
primitive_policy: Literal["error", "promote", "pass-through"] = "promote",
258+
):
254259
from awkward.contents import (
255260
ByteMaskedArray,
256261
ListArray,
@@ -341,6 +346,11 @@ def attach(x):
341346
if array.dtype == np.dtype("O"):
342347
raise TypeError("Awkward Array does not support arrays with object dtypes.")
343348

349+
if primitive_policy == "error" and array.ndim == 0:
350+
raise TypeError(
351+
f"Encountered a scalar ({type(array).__name__}), but scalar conversion/promotion is disabled"
352+
)
353+
344354
if isinstance(array, numpy.ma.MaskedArray):
345355
mask = numpy.ma.getmask(array)
346356
array = numpy.ma.getdata(array)

src/awkward/operations/ak_from_cupy.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@
99

1010

1111
@high_level_function()
12-
def from_cupy(array, *, regulararray=False, highlevel=True, behavior=None, attrs=None):
12+
def from_cupy(
13+
array,
14+
*,
15+
regulararray=False,
16+
highlevel=True,
17+
behavior=None,
18+
primitive_policy="error",
19+
attrs=None,
20+
):
1321
"""
1422
Args:
1523
array (cp.ndarray): The CuPy array to convert into an Awkward Array.
@@ -36,7 +44,7 @@ def from_cupy(array, *, regulararray=False, highlevel=True, behavior=None, attrs
3644
See also #ak.to_cupy, #ak.from_numpy and #ak.from_jax.
3745
"""
3846
return wrap_layout(
39-
from_arraylib(array, regulararray, False),
47+
from_arraylib(array, regulararray, False, primitive_policy=primitive_policy),
4048
highlevel=highlevel,
4149
behavior=behavior,
4250
)

src/awkward/operations/ak_from_dlpack.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def from_dlpack(
2020
regulararray=False,
2121
highlevel=True,
2222
behavior=None,
23+
primitive_policy="error",
2324
attrs=None,
2425
):
2526
"""
@@ -77,7 +78,7 @@ def from_dlpack(
7778

7879
array = nplike.from_dlpack(array)
7980
return wrap_layout(
80-
from_arraylib(array, regulararray, False),
81+
from_arraylib(array, regulararray, False, primitive_policy=primitive_policy),
8182
highlevel=highlevel,
8283
behavior=behavior,
8384
)

src/awkward/operations/ak_from_jax.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010

1111

1212
@high_level_function()
13-
def from_jax(array, *, regulararray=False, highlevel=True, behavior=None, attrs=None):
13+
def from_jax(
14+
array,
15+
*,
16+
regulararray=False,
17+
highlevel=True,
18+
behavior=None,
19+
attrs=None,
20+
primitive_policy="error",
21+
):
1422
"""
1523
Args:
1624
array (jax.numpy.DeviceArray): The JAX DeviceArray to convert into an Awkward Array.
@@ -38,7 +46,7 @@ def from_jax(array, *, regulararray=False, highlevel=True, behavior=None, attrs=
3846
"""
3947
jax.assert_registered()
4048
return wrap_layout(
41-
from_arraylib(array, regulararray, False),
49+
from_arraylib(array, regulararray, False, primitive_policy=primitive_policy),
4250
highlevel=highlevel,
4351
behavior=behavior,
4452
)

src/awkward/operations/ak_from_numpy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def from_numpy(
1616
recordarray=True,
1717
highlevel=True,
1818
behavior=None,
19+
primitive_policy="error",
1920
attrs=None,
2021
):
2122
"""
@@ -52,7 +53,9 @@ def from_numpy(
5253
See also #ak.to_numpy and #ak.from_cupy.
5354
"""
5455
return wrap_layout(
55-
from_arraylib(array, regulararray, recordarray),
56+
from_arraylib(
57+
array, regulararray, recordarray, primitive_policy=primitive_policy
58+
),
5659
highlevel=highlevel,
5760
behavior=behavior,
5861
)

src/awkward/operations/ak_to_layout.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,20 +179,27 @@ def _impl(
179179
regulararray=regulararray,
180180
recordarray=True,
181181
highlevel=False,
182+
primitive_policy=primitive_policy,
182183
)
183184
return _handle_array_like(
184185
obj, promoted_layout, primitive_policy=primitive_policy
185186
)
186187
elif Cupy.is_own_array(obj):
187188
promoted_layout = ak.operations.from_cupy(
188-
obj, regulararray=regulararray, highlevel=False
189+
obj,
190+
regulararray=regulararray,
191+
highlevel=False,
192+
primitive_policy=primitive_policy,
189193
)
190194
return _handle_array_like(
191195
obj, promoted_layout, primitive_policy=primitive_policy
192196
)
193197
elif Jax.is_own_array(obj):
194198
promoted_layout = ak.operations.from_jax(
195-
obj, regulararray=regulararray, highlevel=False
199+
obj,
200+
regulararray=regulararray,
201+
highlevel=False,
202+
primitive_policy=primitive_policy,
196203
)
197204
return _handle_array_like(
198205
obj, promoted_layout, primitive_policy=primitive_policy
@@ -215,14 +222,17 @@ def _impl(
215222
elif ak._util.in_module(obj, "pyarrow"):
216223
return ak.operations.from_arrow(obj, highlevel=False)
217224
elif hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"):
218-
return ak.operations.from_dlpack(obj, highlevel=False)
225+
return ak.operations.from_dlpack(
226+
obj, highlevel=False, primitive_policy=primitive_policy
227+
)
219228
# Typed scalars
220229
elif isinstance(obj, np.generic):
221230
promoted_layout = ak.operations.from_numpy(
222231
numpy.asarray(obj),
223232
regulararray=regulararray,
224233
recordarray=True,
225234
highlevel=False,
235+
primitive_policy=primitive_policy,
226236
)
227237
return _handle_array_like(
228238
obj, promoted_layout, primitive_policy=primitive_policy
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
import numpy as np
6+
import pytest
7+
8+
import awkward as ak
9+
10+
11+
def test_akarray_from_zero_dim_nparray():
12+
np_scalar = np.array(2.7) # A kind of scalar in numpy.
13+
assert np_scalar.ndim == 0 and np_scalar.shape == ()
14+
with pytest.raises(TypeError):
15+
# Conversion to ak.Array ought to throw here:
16+
b = ak.Array(np_scalar) # (bugged) value: <Array [2.7] type='1 * int64'>
17+
# Now we're failing. Here's why.
18+
c = ak.to_numpy(b) # value: array([2.7])
19+
assert np_scalar.shape == c.shape # this fails
20+
21+
with pytest.raises(TypeError):
22+
b = ak.from_numpy(np_scalar)
23+
c = ak.to_numpy(b)
24+
assert np_scalar.shape == c.shape

0 commit comments

Comments
 (0)