Skip to content

Commit 2e0166d

Browse files
committed
refactor: change Dask reader API again
Move `cfg=` into `.open` rather `.read`
1 parent 5ae49f0 commit 2e0166d

File tree

6 files changed

+75
-32
lines changed

6 files changed

+75
-32
lines changed

odc/loader/_builder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def resolve_sources(
125125
out: List[List[tuple[int, RasterSource]]] = []
126126

127127
for layer in self.srcs:
128-
_srcs: List[RasterSource] = []
128+
_srcs: List[tuple[int, RasterSource]] = []
129129
for idx in layer:
130130
src = srcs[idx].get(self.band, None)
131131
if src is not None:
@@ -280,10 +280,12 @@ def _task_futures(
280280
src_hash = tokenize(src)
281281
rdr = rdr_cache.get(src_hash, None)
282282
if rdr is None:
283-
rdr = dask_reader.open(src, ctx, layer_name=layer_name, idx=i_src)
283+
rdr = dask_reader.open(
284+
src, cfg, ctx, layer_name=layer_name, idx=i_src
285+
)
284286
rdr_cache[src_hash] = rdr
285287

286-
fut = rdr.read(cfg, dst_gbox, selection=task.selection, idx=idx)
288+
fut = rdr.read(dst_gbox, selection=task.selection, idx=idx)
287289
keys_out.append(fut.key)
288290
dsk.update(fut.dask)
289291

odc/loader/_reader.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
env: dict[str, Any] | None = None,
5454
ctx: Any | None = None,
5555
src: RasterSource | None = None,
56+
cfg: RasterLoadParams | None = None,
5657
layer_name: str = "",
5758
idx: int = -1,
5859
) -> None:
@@ -63,26 +64,28 @@ def __init__(
6364
self._env = env
6465
self._ctx = ctx
6566
self._src = src
67+
self._cfg = cfg
6668
self._layer_name = layer_name
6769
self._src_idx = idx
6870

6971
def read(
7072
self,
71-
cfg: RasterLoadParams,
7273
dst_geobox: GeoBox,
7374
*,
7475
selection: Optional[ReaderSubsetSelection] = None,
7576
idx: tuple[int, ...],
7677
) -> Any:
7778
assert self._src is not None
7879
assert self._ctx is not None
80+
assert self._cfg is not None
81+
7982
read_op = delayed(_dask_read_adaptor, name=self._layer_name)
8083

8184
# TODO: supply `dask_key_name=` that makes sense
8285
return read_op(
8386
self._src,
8487
self._ctx,
85-
cfg,
88+
self._cfg,
8689
dst_geobox,
8790
self._driver,
8891
self._env,
@@ -91,13 +94,19 @@ def read(
9194
)
9295

9396
def open(
94-
self, src: RasterSource, ctx: Any, layer_name: str, idx: int
97+
self,
98+
src: RasterSource,
99+
cfg: RasterLoadParams,
100+
ctx: Any,
101+
layer_name: str,
102+
idx: int,
95103
) -> "ReaderDaskAdaptor":
96104
return ReaderDaskAdaptor(
97105
self._driver,
98106
self._env,
99107
ctx,
100108
src,
109+
cfg,
101110
layer_name=layer_name,
102111
idx=idx,
103112
)

odc/loader/_zarr.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from dask.base import tokenize
2020
from dask.delayed import Delayed, delayed
2121
from fsspec.core import url_to_fs
22-
from odc.geo.geobox import GeoBox, GeoBoxBase
22+
from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles
2323
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject
2424

2525
from .types import (
@@ -106,9 +106,17 @@ def __init__(
106106
chunks: None | dict[str, int],
107107
driver: Any | None = None,
108108
) -> None:
109+
gbt: GeoboxTiles | None = None
110+
if chunks is not None:
111+
_chunks = tuple(
112+
chunks.get(name, fallback)
113+
for name, fallback in zip(["y", "x"], geobox.shape.yx)
114+
)
115+
gbt = GeoboxTiles(geobox, _chunks)
109116
self.geobox = geobox
110117
self.chunks = chunks
111118
self.driver = driver
119+
self.gbt = gbt
112120

113121
def with_env(self, env: dict[str, Any]) -> "Context":
114122
assert isinstance(env, dict)
@@ -246,38 +254,41 @@ class XrMemReaderDask:
246254
def __init__(
247255
self,
248256
src: xr.DataArray | None = None,
257+
cfg: RasterLoadParams | None = None,
249258
layer_name: str = "",
250259
) -> None:
251260
self._layer_name = layer_name
252261
self._xx = src
262+
self._cfg = cfg
253263

254264
def read(
255265
self,
256-
cfg: RasterLoadParams,
257266
dst_geobox: GeoBox,
258267
*,
259268
selection: ReaderSubsetSelection | None = None,
260269
idx: tuple[int, ...] = (),
261270
) -> Delayed:
262271
assert self._xx is not None
272+
assert self._cfg is not None
263273
assert isinstance(idx, tuple)
274+
xx = self._xx
275+
assert isinstance(xx.odc, ODCExtensionDa)
276+
assert isinstance(xx.odc.geobox, GeoBox)
264277

265-
xx = _select_extra_dims(self._xx, selection, cfg)
266-
assert xx.odc.geobox is not None
278+
yx_roi = xx.odc.geobox.overlap_roi(dst_geobox)
279+
selection = _extra_dims_selector(selection, self._cfg)
280+
selection.update({dim: sel for dim, sel in zip(xx.odc.spatial_dims, yx_roi)})
267281

268-
yy = xr_reproject(
269-
xx,
270-
dst_geobox,
271-
resampling=cfg.resampling,
272-
dst_nodata=cfg.fill_value,
273-
dtype=cfg.dtype,
274-
chunks=dst_geobox.shape.yx,
275-
)
276-
return delayed(_with_roi)(yy.data, dask_key_name=(self._layer_name, *idx))
282+
xx = self._xx.isel(selection)
283+
out_key = (self._layer_name, *idx)
284+
fut = delayed(_with_roi)(xx.data, dask_key_name=out_key)
285+
286+
return fut
277287

278288
def open(
279289
self,
280290
src: RasterSource,
291+
cfg: RasterLoadParams,
281292
ctx: Context,
282293
*,
283294
layer_name: str,
@@ -290,7 +301,19 @@ def open(
290301

291302
assert xx.odc.geobox is not None
292303
assert not any(map(math.isnan, xx.odc.geobox.transform[:6]))
293-
return XrMemReaderDask(xx, layer_name=layer_name)
304+
assert ctx.gbt is not None
305+
gbt = ctx.gbt
306+
307+
xx_warped = xr_reproject(
308+
xx,
309+
gbt.base,
310+
resampling=cfg.resampling,
311+
dst_nodata=cfg.fill_value,
312+
dtype=cfg.dtype,
313+
chunks=gbt.chunk_shape((0, 0)).yx,
314+
)
315+
316+
return XrMemReaderDask(xx_warped, cfg, layer_name=layer_name)
294317

295318

296319
class XrMemReaderDriver:
@@ -483,16 +506,25 @@ def _with_roi(xx: np.ndarray) -> tuple[tuple[slice, slice], np.ndarray]:
483506
return (slice(None), slice(None)), xx
484507

485508

509+
def _extra_dims_selector(
510+
selection: ReaderSubsetSelection, cfg: RasterLoadParams
511+
) -> dict[str, Any]:
512+
if selection is None:
513+
return {}
514+
515+
assert isinstance(selection, (slice, int)) or len(selection) == 1
516+
assert len(cfg.extra_dims) == 1
517+
(band_dim,) = cfg.extra_dims
518+
return {band_dim: selection}
519+
520+
486521
def _select_extra_dims(
487522
src: xr.DataArray, selection: ReaderSubsetSelection, cfg: RasterLoadParams
488523
) -> xr.DataArray:
489524
if selection is None:
490525
return src
491526

492-
assert isinstance(selection, (slice, int)) or len(selection) == 1
493-
assert len(cfg.extra_dims) == 1
494-
(band_dim,) = cfg.extra_dims
495-
return src.isel({band_dim: selection})
527+
return src.isel(_extra_dims_selector(selection, cfg))
496528

497529

498530
def extract_zarr_spec(src: SomeDoc) -> ZarrSpecDict | None:

odc/loader/test_memreader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,12 @@ def test_memreader_zarr(sample_ds: xr.Dataset):
222222
ctx = driver.new_load(gbox, chunks={})
223223
assert isinstance(ctx, Context)
224224

225-
rdr = driver.dask_reader.open(src, ctx, layer_name=f"xx-{tk}", idx=0)
225+
rdr = driver.dask_reader.open(src, cfg, ctx, layer_name=f"xx-{tk}", idx=0)
226226
assert isinstance(rdr, XrMemReaderDask)
227227
assert rdr._xx is not None
228228
assert is_dask_collection(rdr._xx)
229229

230-
fut = rdr.read(cfg, gbox)
230+
fut = rdr.read(gbox)
231231
assert is_dask_collection(fut)
232232

233233
roi, xx = fut.compute(scheduler="synchronous")

odc/loader/test_reader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,16 +389,16 @@ def test_dask_reader_adaptor(dtype: str):
389389
ctx = base_driver.new_load(gbox, chunks={"x": 64, "y": 64})
390390

391391
src = RasterSource("mem://", meta=meta)
392-
rdr = driver.open(src, ctx, layer_name="aa", idx=0)
392+
cfg = RasterLoadParams.same_as(src)
393+
rdr = driver.open(src, cfg, ctx, layer_name="aa", idx=0)
393394

394395
assert isinstance(rdr, ReaderDaskAdaptor)
395396

396-
cfg = RasterLoadParams.same_as(src)
397-
xx = rdr.read(cfg, gbox, idx=(0,))
397+
xx = rdr.read(gbox, idx=(0,))
398398
assert is_dask_collection(xx)
399399
assert xx.key == ("aa", 0)
400-
assert rdr.read(cfg, gbox, idx=(1,)).key == ("aa", 1)
401-
assert rdr.read(cfg, gbox, idx=(1, 2, 3)).key == ("aa", 1, 2, 3)
400+
assert rdr.read(gbox, idx=(1,)).key == ("aa", 1)
401+
assert rdr.read(gbox, idx=(1, 2, 3)).key == ("aa", 1, 2, 3)
402402

403403
yy = xx.compute(scheduler="synchronous")
404404
assert isinstance(yy, tuple)

odc/loader/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,6 @@ class DaskRasterReader(Protocol):
471471

472472
def read(
473473
self,
474-
cfg: RasterLoadParams,
475474
dst_geobox: GeoBox,
476475
*,
477476
selection: Optional[ReaderSubsetSelection] = None,
@@ -481,6 +480,7 @@ def read(
481480
def open(
482481
self,
483482
src: RasterSource,
483+
cfg: RasterLoadParams,
484484
ctx: Any,
485485
*,
486486
layer_name: str,

0 commit comments

Comments
 (0)