|
| 1 | +import itertools |
| 2 | +from collections.abc import Hashable, Mapping |
| 3 | +from functools import lru_cache |
| 4 | +from numbers import Number |
| 5 | +from typing import TYPE_CHECKING, Any |
| 6 | + |
| 7 | +from xarray.core import utils |
| 8 | +from xarray.core.utils import emit_user_level_warning |
| 9 | +from xarray.core.variable import IndexVariable, Variable |
| 10 | +from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, guess_chunkmanager |
| 11 | + |
| 12 | +if TYPE_CHECKING: |
| 13 | + from xarray.core.types import T_ChunkDim |
| 14 | + |
| 15 | + |
| 16 | +@lru_cache(maxsize=512) |
| 17 | +def _get_breaks_cached( |
| 18 | + *, |
| 19 | + size: int, |
| 20 | + chunk_sizes: tuple[int, ...], |
| 21 | + preferred_chunk_sizes: int | tuple[int, ...], |
| 22 | +) -> int | None: |
| 23 | + if isinstance(preferred_chunk_sizes, int) and preferred_chunk_sizes == 1: |
| 24 | + # short-circuit for the trivial case |
| 25 | + return None |
| 26 | + # Determine the stop indices of the preferred chunks, but omit the last stop |
| 27 | + # (equal to the dim size). In particular, assume that when a sequence |
| 28 | + # expresses the preferred chunks, the sequence sums to the size. |
| 29 | + preferred_stops = ( |
| 30 | + range(preferred_chunk_sizes, size, preferred_chunk_sizes) |
| 31 | + if isinstance(preferred_chunk_sizes, int) |
| 32 | + else set(itertools.accumulate(preferred_chunk_sizes[:-1])) |
| 33 | + ) |
| 34 | + |
| 35 | + # Gather any stop indices of the specified chunks that are not a stop index |
| 36 | + # of a preferred chunk. Again, omit the last stop, assuming that it equals |
| 37 | + # the dim size. |
| 38 | + actual_stops = itertools.accumulate(chunk_sizes[:-1]) |
| 39 | + # This copy is required for parallel iteration |
| 40 | + actual_stops_2 = itertools.accumulate(chunk_sizes[:-1]) |
| 41 | + |
| 42 | + disagrees = itertools.compress( |
| 43 | + actual_stops_2, (a not in preferred_stops for a in actual_stops) |
| 44 | + ) |
| 45 | + try: |
| 46 | + return next(disagrees) |
| 47 | + except StopIteration: |
| 48 | + return None |
| 49 | + |
| 50 | + |
def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
    """
    Return map from each dim to chunk sizes, accounting for backend's preferred chunks.
    """
    # Index variables are never chunked.
    if isinstance(var, IndexVariable):
        return {}

    dims, shape = var.dims, var.shape

    # Backend-preferred chunking per dimension, falling back to the full
    # extent for dimensions with no stated preference.
    preferred_chunks = var.encoding.get("preferred_chunks", {})
    preferred_chunk_shape = tuple(
        preferred_chunks.get(dim, extent)
        for dim, extent in zip(dims, shape, strict=True)
    )

    # A scalar or "auto" request applies uniformly to every dimension.
    if isinstance(chunks, Number) or chunks == "auto":
        chunks = dict.fromkeys(dims, chunks)

    # An explicit (truthy) request wins; otherwise use the backend preference.
    chunk_shape = tuple(
        chunks.get(dim, None) or preferred
        for dim, preferred in zip(dims, preferred_chunk_shape, strict=True)
    )

    chunk_shape = chunkmanager.normalize_chunks(
        chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape
    )

    # Warn where requested chunks break preferred chunks, provided that the
    # variable contains data.
    if var.size:
        for dim, extent, sizes in zip(dims, shape, chunk_shape, strict=True):
            if dim not in preferred_chunks:
                continue
            break_index = _get_breaks_cached(
                size=extent,
                chunk_sizes=sizes,
                preferred_chunk_sizes=preferred_chunks[dim],
            )
            if break_index:
                emit_user_level_warning(
                    "The specified chunks separate the stored chunks along "
                    f'dimension "{dim}" starting at index {break_index}. This could '
                    "degrade performance. Instead, consider rechunking after loading.",
                )

    return dict(zip(dims, chunk_shape, strict=True))
| 97 | + |
| 98 | + |
def _maybe_chunk(
    name: Hashable,
    var: Variable,
    chunks: Mapping[Any, "T_ChunkDim"] | None,
    token=None,
    lock=None,
    name_prefix: str = "xarray-",
    overwrite_encoded_chunks: bool = False,
    inline_array: bool = False,
    chunked_array_type: str | ChunkManagerEntrypoint | None = None,
    from_array_kwargs=None,
) -> Variable:
    """Chunk *var* if it has any dimensions; return 0-d variables unchanged."""
    from xarray.namedarray.daskmanager import DaskManager

    if chunks is not None:
        # Keep only the entries for dimensions this variable actually has.
        chunks = {dim: chunks[dim] for dim in var.dims if dim in chunks}

    # Scalar (0-d) variables cannot be chunked.
    if not var.ndim:
        return var

    # coerce string to ChunkManagerEntrypoint type
    chunked_array_type = guess_chunkmanager(chunked_array_type)

    if isinstance(chunked_array_type, DaskManager):
        from dask.base import tokenize

        # when rechunking by different amounts, make sure dask names change
        # by providing chunks as an input to tokenize.
        # subtle bugs result otherwise. see GH3350
        # we use str() for speed, and use the name for the final array name on the next line
        token2 = tokenize(token if token else var._data, str(chunks))
        name2 = f"{name_prefix}{name}-{token2}"

        from_array_kwargs = utils.consolidate_dask_from_array_kwargs(
            from_array_kwargs,
            name=name2,
            lock=lock,
            inline_array=inline_array,
        )

    var = var.chunk(
        chunks,
        chunked_array_type=chunked_array_type,
        from_array_kwargs=from_array_kwargs,
    )

    if overwrite_encoded_chunks and var.chunks is not None:
        # Record the (uniform) first chunk size per axis back into encoding.
        var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
    return var
0 commit comments