Skip to content

Commit

Permalink
BUG: fix raising a unyt array to an array power in sensible cases
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Oct 8, 2024
1 parent 6c6650b commit 1fedbac
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,8 +1764,8 @@ def __pow__(self, p, mod=None, /):
Power function
"""
# see https://github.com/yt-project/unyt/issues/203
if p == 0.0:
ret = self.ua
if np.isscalar(p) and p == 0.0:
ret = self.unit_array
ret.units = Unit("dimensionless")
return ret
else:
Expand Down Expand Up @@ -1854,17 +1854,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
u1 = Unit(registry=getattr(u0, "registry", None))
elif ufunc is power:
u1 = inp1
if inp0.shape != () and inp1.shape != ():
raise UnitOperationError(ufunc, u0, u1)
if isinstance(u1, unyt_array):
if u1.units.is_dimensionless:
pass
else:
if inp0.shape == () or inp1.shape == ():
if isinstance(u1, unyt_array) and not u1.units.is_dimensionless:
raise UnitOperationError(ufunc, u0, u1.units)
if u1.shape == ():
u1 = float(u1)
if u1.shape == ():
u1 = float(u1)
else:
u1 = 1.0
elif inp0.shape == inp1.shape:
if (
isinstance(u1, unyt_array) and not u1.units.is_dimensionless
) or np.ptp(u1) != 0:
raise UnitOperationError(ufunc, u0, getattr(u1, "units", None))
first_element_slice = (0,) * u1.ndim
u1 = float(u1[first_element_slice])
else:
u1 = 1.0
raise UnitOperationError(ufunc, u0, u1)
unit_operator = self._ufunc_registry[ufunc]

if (
Expand Down

0 comments on commit 1fedbac

Please sign in to comment.