diff --git a/docs/usage.rst b/docs/usage.rst index b7cb53c8..aabb3e72 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -959,14 +959,14 @@ bit floating point array. >>> data = np.array([1, 2, 3], dtype='int32')*km >>> data.in_units('mile') - unyt_array([0.62137121, 1.24274242, 1.86411357], dtype=float32, units='mile') + unyt_array([0.6213712, 1.2427424, 1.8641136], dtype=float32, units='mile') In-place operations will also mutate the dtype from float to integer in these cases, again in a way that will preserve the byte size of the data. >>> data.convert_to_units('mile') >>> data - unyt_array([0.62137121, 1.24274242, 1.86411357], dtype=float32, units='mile') + unyt_array([0.6213712, 1.2427424, 1.8641136], dtype=float32, units='mile') It is possible that arrays containing large integers (16777217 for 32 bit and 9007199254740993 for 64 bit) will lose precision when converting data to a diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 57a22e62..ccb87bf1 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -995,3 +995,14 @@ def interp(x, xp, fp, *args, **kwargs): np.interp(np.asarray(x), np.asarray(xp), np.asarray(fp), *args, **kwargs) * ret_units ) + + +@implements(np.array_repr) +def array_repr(arr, *args, **kwargs): + rep = np.array_repr._implementation(arr.view(np.ndarray), *args, **kwargs) + rep = rep.replace("array", arr.__class__.__name__) + units_repr = arr.units.__repr__() + if "=" in rep: + return rep[:-1] + ", units='" + units_repr + "')" + else: + return rep[:-1] + ", '" + units_repr + "')" diff --git a/unyt/array.py b/unyt/array.py index ef06fb41..367ed145 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -636,12 +636,7 @@ def __new__( return obj def __repr__(self): - rep = super().__repr__() - units_repr = self.units.__repr__() - if "=" in rep: - return rep[:-1] + ", units='" + units_repr + "')" - else: - return rep[:-1] + ", '" + units_repr + "')" + return np.array_repr(self) def __str__(self): return str(self.view(np.ndarray)) + " " + str(self.units) diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index d1d8cfeb..6d9ba584 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -35,7 +35,6 @@ np.argpartition, # returns pure numbers np.argsort, # returns pure numbers np.argwhere, # returns pure numbers - np.array_repr, # hooks into __repr__ np.array_str, # hooks into __str__ np.atleast_1d, # works out of the box (tested) np.atleast_2d, # works out of the box (tested) @@ -256,7 +255,7 @@ def test_wrapping_completeness(): def test_array_repr(): arr = [1, 2, 3] * cm - assert np.array_repr(arr) == "unyt_array([1, 2, 3], units='cm')" + assert np.array_repr(arr) == "unyt_array([1, 2, 3], 'cm')" def test_dot_vectors():