Skip to content

Commit 1fedbac

Browse files
committed
BUG: fix raising a unyt array to an array power in sensible cases
1 parent 6c6650b commit 1fedbac

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

unyt/array.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,8 +1764,8 @@ def __pow__(self, p, mod=None, /):
17641764
Power function
17651765
"""
17661766
# see https://github.com/yt-project/unyt/issues/203
1767-
if p == 0.0:
1768-
ret = self.ua
1767+
if np.isscalar(p) and p == 0.0:
1768+
ret = self.unit_array
17691769
ret.units = Unit("dimensionless")
17701770
return ret
17711771
else:
@@ -1854,17 +1854,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
18541854
u1 = Unit(registry=getattr(u0, "registry", None))
18551855
elif ufunc is power:
18561856
u1 = inp1
1857-
if inp0.shape != () and inp1.shape != ():
1858-
raise UnitOperationError(ufunc, u0, u1)
1859-
if isinstance(u1, unyt_array):
1860-
if u1.units.is_dimensionless:
1861-
pass
1862-
else:
1857+
if inp0.shape == () or inp1.shape == ():
1858+
if isinstance(u1, unyt_array) and not u1.units.is_dimensionless:
18631859
raise UnitOperationError(ufunc, u0, u1.units)
1864-
if u1.shape == ():
1865-
u1 = float(u1)
1860+
if u1.shape == ():
1861+
u1 = float(u1)
1862+
else:
1863+
u1 = 1.0
1864+
elif inp0.shape == inp1.shape:
1865+
if (
1866+
isinstance(u1, unyt_array) and not u1.units.is_dimensionless
1867+
) or np.ptp(u1) != 0:
1868+
raise UnitOperationError(ufunc, u0, getattr(u1, "units", None))
1869+
first_element_slice = (0,) * u1.ndim
1870+
u1 = float(u1[first_element_slice])
18661871
else:
1867-
u1 = 1.0
1872+
raise UnitOperationError(ufunc, u0, u1)
18681873
unit_operator = self._ufunc_registry[ufunc]
18691874

18701875
if (

0 commit comments

Comments
 (0)