Skip to content

Commit

Permalink
Merge pull request #335 from neutrinoceros/manual_bps_2.9.3
Browse files Browse the repository at this point in the history
REL: manual backports for release 2.9.3
  • Loading branch information
neutrinoceros authored Dec 7, 2022
2 parents dcb321c + b8a380a commit 954ab58
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 14 deletions.
9 changes: 9 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
History
=======

2.9.3 (2022-12-07)
------------------

* Fix a future incompatibility with numpy 1.25 (unreleased) where comparing
``unyt_array`` objects to non-numeric objects (e.g. strings) would cause a
crash. See `PR #333 <https://github.com/yt-project/unyt/pull/333>`_. Thank you
to Clément Robert (@neutrinoceros on GitHub) and Nathan Goldbaum (@ngoldbaum
on GitHub) the contribution.

2.9.2 (2022-07-20)
------------------

Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
[build-system]
requires = [
"setuptools==61.2",
"setuptools_scm[toml]==6.2",
]

build-backend = "setuptools.build_meta"

[tool.black]
target-version = ['py36']
line-length = 88
Expand Down
34 changes: 27 additions & 7 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@

NULL_UNIT = Unit()
POWER_MAPPING = {multiply: lambda x: x, divide: lambda x: 2 - x}
DISALLOWED_DTYPES = ("S", "U", "a", "O", "M", "m", "b")

__doctest_requires__ = {
("unyt_array.from_pint", "unyt_array.to_pint"): ["pint"],
Expand Down Expand Up @@ -253,21 +254,28 @@ def _bitop_units(unit1, unit2):

def _coerce_iterable_units(input_object, registry=None):
if isinstance(input_object, np.ndarray):
return input_object
if _iterable(input_object):
if any([isinstance(o, unyt_array) for o in input_object]):
ret = input_object
elif _iterable(input_object):
if any(isinstance(o, unyt_array) for o in input_object):
ff = getattr(input_object[0], "units", NULL_UNIT)
if any([ff != getattr(_, "units", NULL_UNIT) for _ in input_object]):
if any(ff != getattr(_, "units", NULL_UNIT) for _ in input_object):
ret = []
for datum in input_object:
try:
ret.append(datum.in_units(ff.units))
except UnitConversionError:
raise IterableUnitCoercionError(str(input_object))
return unyt_array(np.array(ret), ff, registry=registry)
ret = unyt_array(np.array(ret), ff, registry=registry)
# This will create a copy of the data in the iterable.
return unyt_array(np.array(input_object), ff, registry=registry)
return np.asarray(input_object)
else:
ret = unyt_array(np.array(input_object), ff, registry=registry)
else:
ret = np.asarray(input_object)
else:
ret = np.asarray(input_object)
if ret.dtype.char in DISALLOWED_DTYPES:
raise IterableUnitCoercionError(str(input_object))
return ret


def _sanitize_units_convert(possible_units, registry):
Expand Down Expand Up @@ -1717,6 +1725,18 @@ def __pow__(self, p, mod=None, /):
else:
return super().__pow__(p, mod)

def __eq__(self, other):
try:
return super().__eq__(other)
except (IterableUnitCoercionError, UnitOperationError):
return np.zeros(self.shape, dtype="bool")

def __ne__(self, other):
try:
return super().__ne__(other)
except (IterableUnitCoercionError, UnitOperationError):
return np.ones(self.shape, dtype="bool")

#
# Start operation methods
#
Expand Down
13 changes: 7 additions & 6 deletions unyt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,19 @@ class IterableUnitCoercionError(Exception):
# doctest: +IGNORE_EXCEPTION_DETAIL +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
unyt.exceptions.IterableUnitCoercionError: Received a list or
tuple of quantities with nonuniform units:
unyt.exceptions.IterableUnitCoercionError: Received an input
or operand that cannot be converted to a unyt_array with uniform
units:
[unyt_quantity(2., 'cm'), unyt_quantity(3., 'g')]
"""

def __init__(self, quantity_list):
self.quantity_list = quantity_list
def __init__(self, op):
self.op = op

def __str__(self):
err = (
"Received a list or tuple of quantities with nonuniform units: "
"%s" % self.quantity_list
"Received an input or operand that cannot be converted "
f"to a unyt_array with uniform units: {self.op}"
)
return err

Expand Down
33 changes: 32 additions & 1 deletion unyt/tests/test_unyt_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,8 +1483,31 @@ def op_comparison(op, inst1, inst2, compare_class):
assert_isinstance(a.copy(), unyt_a_subclass)
assert_isinstance(copy.deepcopy(a), unyt_a_subclass)

with pytest.raises(RuntimeError):

def test_string_operations_raise_errors():
a = unyt_array([1, 2, 3], "g")
with pytest.raises(IterableUnitCoercionError):
a + "hello"
with pytest.raises(IterableUnitCoercionError):
a * "hello"
with pytest.raises(IterableUnitCoercionError):
a ** "hello"
if Version(np.__version__) < Version("1.24"):
with pytest.warns(FutureWarning):
assert a != "hello"
else:
assert (a != "hello").all()


def test_string_operations_raise_errors_quantity():
q = 2 * g
with pytest.raises(IterableUnitCoercionError):
q + "hello"
with pytest.raises(IterableUnitCoercionError):
q * "hello"
with pytest.raises(IterableUnitCoercionError):
q ** "hello"
assert q != "hello"


def test_h5_io():
Expand Down Expand Up @@ -2681,3 +2704,11 @@ def test_reshape_quantity_via_shape_tuple():
b = a.reshape(-1, 1)
assert b.shape == (1, 1)
assert type(b) is unyt_array


def test_string_comparison():
# exercise comparison between a unyt_quantity object and a string
# see regression https://github.com/numpy/numpy/issues/22744
a = 1 * cm
assert not (a == "hello")
assert a != "hello"

0 comments on commit 954ab58

Please sign in to comment.