@@ -1764,8 +1764,8 @@ def __pow__(self, p, mod=None, /):
1764
1764
Power function
1765
1765
"""
1766
1766
# see https://github.com/yt-project/unyt/issues/203
1767
- if p == 0.0 :
1768
- ret = self .ua
1767
+ if np . isscalar ( p ) and p == 0.0 :
1768
+ ret = self .unit_array
1769
1769
ret .units = Unit ("dimensionless" )
1770
1770
return ret
1771
1771
else :
@@ -1854,17 +1854,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
1854
1854
u1 = Unit (registry = getattr (u0 , "registry" , None ))
1855
1855
elif ufunc is power :
1856
1856
u1 = inp1
1857
- if inp0 .shape != () and inp1 .shape != ():
1858
- raise UnitOperationError (ufunc , u0 , u1 )
1859
- if isinstance (u1 , unyt_array ):
1860
- if u1 .units .is_dimensionless :
1861
- pass
1862
- else :
1857
+ if inp0 .shape == () or inp1 .shape == ():
1858
+ if isinstance (u1 , unyt_array ) and not u1 .units .is_dimensionless :
1863
1859
raise UnitOperationError (ufunc , u0 , u1 .units )
1864
- if u1 .shape == ():
1865
- u1 = float (u1 )
1860
+ if u1 .shape == ():
1861
+ u1 = float (u1 )
1862
+ else :
1863
+ u1 = 1.0
1864
+ elif inp0 .shape == inp1 .shape :
1865
+ if (
1866
+ isinstance (u1 , unyt_array ) and not u1 .units .is_dimensionless
1867
+ ) or np .ptp (u1 ) != 0 :
1868
+ raise UnitOperationError (ufunc , u0 , getattr (u1 , "units" , None ))
1869
+ first_element_slice = (0 ,) * u1 .ndim
1870
+ u1 = float (u1 [first_element_slice ])
1866
1871
else :
1867
- u1 = 1.0
1872
+ raise UnitOperationError ( ufunc , u0 , u1 )
1868
1873
unit_operator = self ._ufunc_registry [ufunc ]
1869
1874
1870
1875
if (
0 commit comments