diff --git a/unyt/dask_array.py b/unyt/dask_array.py index 224b9a12..f4ca6e26 100644 --- a/unyt/dask_array.py +++ b/unyt/dask_array.py @@ -8,20 +8,13 @@ from functools import wraps import numpy as np +import pytest import unyt.array as ua -from unyt._on_demand_imports import _dask as dask -__doctest_requires__ = { - ("unyt_from_dask", "reduce_with_units", "unyt_dask_array.to_dask"): ["dask"], -} - - -if dask.__is_available__: - _dask_Array = dask.array.core.Array - _dask_finalize = dask.array.core.finalize -else: - _dask_Array, _dask_finalize = object, None +pytest.importorskip("dask") +del pytest +from dask.array.core import Array as DaskArray, finalize as dask_finalize # noqa: E402 # the following attributes hang off of dask.array.core.Array and do not modify units _use_unary_decorator = { @@ -191,7 +184,7 @@ def wrapper(self, *args, **kwargs): return wrapper -class unyt_dask_array(_dask_Array): +class unyt_dask_array(DaskArray): """ a dask.array.core.Array subclass that tracks units. This class is only recommended for advanced usage, most cases should use the unyt_from_dask @@ -339,7 +332,7 @@ def to_dask(self): chunksize=(1000, 1000), chunktype=numpy.ndarray> """ (_, args) = super().__reduce__() - return _dask_Array(*args) + return DaskArray(*args) def __reduce__(self): (_, args) = super().__reduce__() @@ -498,7 +491,7 @@ def _finalize_unyt(results, unit_name): # here, we first call the standard finalize function for a dask array # and then return a standard unyt_array from the now in-memory result if # the result is an array, otherwise return a unyt_quantity. - result = _dask_finalize(results) + result = dask_finalize(results) if type(result) == np.ndarray: return ua.unyt_array(result, unit_name) @@ -656,10 +649,11 @@ def reduce_with_units(dask_func, unyt_dask_in, *args, **kwargs): Examples -------- - >>> from unyt import dask_array - >>> a = dask_array.dask.array.ones((10000,), chunks=(100,)) - >>> a = dask_array.unyt_from_dask(a, 'm') - >>> b = dask_array.reduce_with_units(dask_array.dask.array.median, a, axis=0) + >>> import dask.array + >>> from unyt.dask_array import unyt_from_dask, reduce_with_units + >>> a = dask.array.ones((10000,), chunks=(100,)) + >>> a = unyt_from_dask(a, 'm') + >>> b = reduce_with_units(dask.array.median, a, axis=0) >>> b.compute() unyt_quantity(1., 'm')