Skip to content

Commit 954ab58

Browse files
Merge pull request #335 from neutrinoceros/manual_bps_2.9.3
REL: manual backports for release 2.9.3
2 parents dcb321c + b8a380a commit 954ab58

File tree

5 files changed

+83
-14
lines changed

5 files changed

+83
-14
lines changed

HISTORY.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
History
33
=======
44

5+
2.9.3 (2022-12-07)
6+
------------------
7+
8+
* Fix a future incompatibility with numpy 1.25 (unreleased) where comparing
9+
``unyt_array`` objects to non-numeric objects (e.g. strings) would cause a
10+
crash. See `PR #333 <https://github.com/yt-project/unyt/pull/333>`_. Thank you
11+
to Clément Robert (@neutrinoceros on GitHub) and Nathan Goldbaum (@ngoldbaum
12+
on GitHub) the contribution.
13+
514
2.9.2 (2022-07-20)
615
------------------
716

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
[build-system]
2+
requires = [
3+
"setuptools==61.2",
4+
"setuptools_scm[toml]==6.2",
5+
]
6+
7+
build-backend = "setuptools.build_meta"
8+
19
[tool.black]
210
target-version = ['py36']
311
line-length = 88

unyt/array.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141

142142
NULL_UNIT = Unit()
143143
POWER_MAPPING = {multiply: lambda x: x, divide: lambda x: 2 - x}
144+
DISALLOWED_DTYPES = ("S", "U", "a", "O", "M", "m", "b")
144145

145146
__doctest_requires__ = {
146147
("unyt_array.from_pint", "unyt_array.to_pint"): ["pint"],
@@ -253,21 +254,28 @@ def _bitop_units(unit1, unit2):
253254

254255
def _coerce_iterable_units(input_object, registry=None):
255256
if isinstance(input_object, np.ndarray):
256-
return input_object
257-
if _iterable(input_object):
258-
if any([isinstance(o, unyt_array) for o in input_object]):
257+
ret = input_object
258+
elif _iterable(input_object):
259+
if any(isinstance(o, unyt_array) for o in input_object):
259260
ff = getattr(input_object[0], "units", NULL_UNIT)
260-
if any([ff != getattr(_, "units", NULL_UNIT) for _ in input_object]):
261+
if any(ff != getattr(_, "units", NULL_UNIT) for _ in input_object):
261262
ret = []
262263
for datum in input_object:
263264
try:
264265
ret.append(datum.in_units(ff.units))
265266
except UnitConversionError:
266267
raise IterableUnitCoercionError(str(input_object))
267-
return unyt_array(np.array(ret), ff, registry=registry)
268+
ret = unyt_array(np.array(ret), ff, registry=registry)
268269
# This will create a copy of the data in the iterable.
269-
return unyt_array(np.array(input_object), ff, registry=registry)
270-
return np.asarray(input_object)
270+
else:
271+
ret = unyt_array(np.array(input_object), ff, registry=registry)
272+
else:
273+
ret = np.asarray(input_object)
274+
else:
275+
ret = np.asarray(input_object)
276+
if ret.dtype.char in DISALLOWED_DTYPES:
277+
raise IterableUnitCoercionError(str(input_object))
278+
return ret
271279

272280

273281
def _sanitize_units_convert(possible_units, registry):
@@ -1717,6 +1725,18 @@ def __pow__(self, p, mod=None, /):
17171725
else:
17181726
return super().__pow__(p, mod)
17191727

1728+
def __eq__(self, other):
1729+
try:
1730+
return super().__eq__(other)
1731+
except (IterableUnitCoercionError, UnitOperationError):
1732+
return np.zeros(self.shape, dtype="bool")
1733+
1734+
def __ne__(self, other):
1735+
try:
1736+
return super().__ne__(other)
1737+
except (IterableUnitCoercionError, UnitOperationError):
1738+
return np.ones(self.shape, dtype="bool")
1739+
17201740
#
17211741
# Start operation methods
17221742
#

unyt/exceptions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,18 +167,19 @@ class IterableUnitCoercionError(Exception):
167167
# doctest: +IGNORE_EXCEPTION_DETAIL +NORMALIZE_WHITESPACE
168168
Traceback (most recent call last):
169169
...
170-
unyt.exceptions.IterableUnitCoercionError: Received a list or
171-
tuple of quantities with nonuniform units:
170+
unyt.exceptions.IterableUnitCoercionError: Received an input
171+
or operand that cannot be converted to a unyt_array with uniform
172+
units:
172173
[unyt_quantity(2., 'cm'), unyt_quantity(3., 'g')]
173174
"""
174175

175-
def __init__(self, quantity_list):
176-
self.quantity_list = quantity_list
176+
def __init__(self, op):
177+
self.op = op
177178

178179
def __str__(self):
179180
err = (
180-
"Received a list or tuple of quantities with nonuniform units: "
181-
"%s" % self.quantity_list
181+
"Received an input or operand that cannot be converted "
182+
f"to a unyt_array with uniform units: {self.op}"
182183
)
183184
return err
184185

unyt/tests/test_unyt_array.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1483,8 +1483,31 @@ def op_comparison(op, inst1, inst2, compare_class):
14831483
assert_isinstance(a.copy(), unyt_a_subclass)
14841484
assert_isinstance(copy.deepcopy(a), unyt_a_subclass)
14851485

1486-
with pytest.raises(RuntimeError):
1486+
1487+
def test_string_operations_raise_errors():
1488+
a = unyt_array([1, 2, 3], "g")
1489+
with pytest.raises(IterableUnitCoercionError):
14871490
a + "hello"
1491+
with pytest.raises(IterableUnitCoercionError):
1492+
a * "hello"
1493+
with pytest.raises(IterableUnitCoercionError):
1494+
a ** "hello"
1495+
if Version(np.__version__) < Version("1.24"):
1496+
with pytest.warns(FutureWarning):
1497+
assert a != "hello"
1498+
else:
1499+
assert (a != "hello").all()
1500+
1501+
1502+
def test_string_operations_raise_errors_quantity():
1503+
q = 2 * g
1504+
with pytest.raises(IterableUnitCoercionError):
1505+
q + "hello"
1506+
with pytest.raises(IterableUnitCoercionError):
1507+
q * "hello"
1508+
with pytest.raises(IterableUnitCoercionError):
1509+
q ** "hello"
1510+
assert q != "hello"
14881511

14891512

14901513
def test_h5_io():
@@ -2681,3 +2704,11 @@ def test_reshape_quantity_via_shape_tuple():
26812704
b = a.reshape(-1, 1)
26822705
assert b.shape == (1, 1)
26832706
assert type(b) is unyt_array
2707+
2708+
2709+
def test_string_comparison():
2710+
# exercise comparison between a unyt_quantity object and a string
2711+
# see regression https://github.com/numpy/numpy/issues/22744
2712+
a = 1 * cm
2713+
assert not (a == "hello")
2714+
assert a != "hello"

0 commit comments

Comments
 (0)