1919from dask .base import tokenize
2020from dask .delayed import Delayed , delayed
2121from fsspec .core import url_to_fs
22- from odc .geo .geobox import GeoBox , GeoBoxBase
22+ from odc .geo .geobox import GeoBox , GeoBoxBase , GeoboxTiles
2323from odc .geo .xr import ODCExtensionDa , ODCExtensionDs , xr_coords , xr_reproject
2424
2525from .types import (
@@ -106,9 +106,17 @@ def __init__(
106106 chunks : None | dict [str , int ],
107107 driver : Any | None = None ,
108108 ) -> None :
109+ gbt : GeoboxTiles | None = None
110+ if chunks is not None :
111+ _chunks = tuple (
112+ chunks .get (name , fallback )
113+ for name , fallback in zip (["y" , "x" ], geobox .shape .yx )
114+ )
115+ gbt = GeoboxTiles (geobox , _chunks )
109116 self .geobox = geobox
110117 self .chunks = chunks
111118 self .driver = driver
119+ self .gbt = gbt
112120
113121 def with_env (self , env : dict [str , Any ]) -> "Context" :
114122 assert isinstance (env , dict )
@@ -246,38 +254,41 @@ class XrMemReaderDask:
246254 def __init__ (
247255 self ,
248256 src : xr .DataArray | None = None ,
257+ cfg : RasterLoadParams | None = None ,
249258 layer_name : str = "" ,
250259 ) -> None :
251260 self ._layer_name = layer_name
252261 self ._xx = src
262+ self ._cfg = cfg
253263
254264 def read (
255265 self ,
256- cfg : RasterLoadParams ,
257266 dst_geobox : GeoBox ,
258267 * ,
259268 selection : ReaderSubsetSelection | None = None ,
260269 idx : tuple [int , ...] = (),
261270 ) -> Delayed :
262271 assert self ._xx is not None
272+ assert self ._cfg is not None
263273 assert isinstance (idx , tuple )
274+ xx = self ._xx
275+ assert isinstance (xx .odc , ODCExtensionDa )
276+ assert isinstance (xx .odc .geobox , GeoBox )
264277
265- xx = _select_extra_dims (self ._xx , selection , cfg )
266- assert xx .odc .geobox is not None
278+ yx_roi = xx .odc .geobox .overlap_roi (dst_geobox )
279+ selection = _extra_dims_selector (selection , self ._cfg )
280+ selection .update ({dim : sel for dim , sel in zip (xx .odc .spatial_dims , yx_roi )})
267281
268- yy = xr_reproject (
269- xx ,
270- dst_geobox ,
271- resampling = cfg .resampling ,
272- dst_nodata = cfg .fill_value ,
273- dtype = cfg .dtype ,
274- chunks = dst_geobox .shape .yx ,
275- )
276- return delayed (_with_roi )(yy .data , dask_key_name = (self ._layer_name , * idx ))
282+ xx = self ._xx .isel (selection )
283+ out_key = (self ._layer_name , * idx )
284+ fut = delayed (_with_roi )(xx .data , dask_key_name = out_key )
285+
286+ return fut
277287
278288 def open (
279289 self ,
280290 src : RasterSource ,
291+ cfg : RasterLoadParams ,
281292 ctx : Context ,
282293 * ,
283294 layer_name : str ,
@@ -290,7 +301,19 @@ def open(
290301
291302 assert xx .odc .geobox is not None
292303 assert not any (map (math .isnan , xx .odc .geobox .transform [:6 ]))
293- return XrMemReaderDask (xx , layer_name = layer_name )
304+ assert ctx .gbt is not None
305+ gbt = ctx .gbt
306+
307+ xx_warped = xr_reproject (
308+ xx ,
309+ gbt .base ,
310+ resampling = cfg .resampling ,
311+ dst_nodata = cfg .fill_value ,
312+ dtype = cfg .dtype ,
313+ chunks = gbt .chunk_shape ((0 , 0 )).yx ,
314+ )
315+
316+ return XrMemReaderDask (xx_warped , cfg , layer_name = layer_name )
294317
295318
296319class XrMemReaderDriver :
@@ -483,16 +506,25 @@ def _with_roi(xx: np.ndarray) -> tuple[tuple[slice, slice], np.ndarray]:
483506 return (slice (None ), slice (None )), xx
484507
485508
509+ def _extra_dims_selector (
510+ selection : ReaderSubsetSelection , cfg : RasterLoadParams
511+ ) -> dict [str , Any ]:
512+ if selection is None :
513+ return {}
514+
515+ assert isinstance (selection , (slice , int )) or len (selection ) == 1
516+ assert len (cfg .extra_dims ) == 1
517+ (band_dim ,) = cfg .extra_dims
518+ return {band_dim : selection }
519+
520+
486521def _select_extra_dims (
487522 src : xr .DataArray , selection : ReaderSubsetSelection , cfg : RasterLoadParams
488523) -> xr .DataArray :
489524 if selection is None :
490525 return src
491526
492- assert isinstance (selection , (slice , int )) or len (selection ) == 1
493- assert len (cfg .extra_dims ) == 1
494- (band_dim ,) = cfg .extra_dims
495- return src .isel ({band_dim : selection })
527+ return src .isel (_extra_dims_selector (selection , cfg ))
496528
497529
498530def extract_zarr_spec (src : SomeDoc ) -> ZarrSpecDict | None :
0 commit comments