Skip to content

Commit 5ae49f0

Browse files
committed
fix: In dask builder re-use readers
- Make sure to call dask_read.open() only once per RasterSource object. - use actual src_idx, not index within a tile - use tokenize(RasterSource) as cache key
1 parent aed76a4 commit 5ae49f0

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

odc/loader/_builder.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,15 @@ def __bool__(self) -> bool:
121121

122122
def resolve_sources(
123123
self, srcs: Sequence[MultiBandRasterSource]
124-
) -> List[List[RasterSource]]:
125-
out: List[List[RasterSource]] = []
124+
) -> List[List[tuple[int, RasterSource]]]:
125+
out: List[List[tuple[int, RasterSource]]] = []
126126

127127
for layer in self.srcs:
128128
_srcs: List[RasterSource] = []
129129
for idx in layer:
130130
src = srcs[idx].get(self.band, None)
131131
if src is not None:
132-
_srcs.append(src)
132+
_srcs.append((idx, src))
133133
out.append(_srcs)
134134
return out
135135

@@ -263,6 +263,7 @@ def _task_futures(
263263
dask_reader: DaskRasterReader,
264264
layer_name: str,
265265
dsk: dict[Key, Any],
266+
rdr_cache: dict[str, DaskRasterReader],
266267
) -> list[list[Key]]:
267268
# pylint: disable=too-many-locals
268269
srcs = task.resolve_sources(self.srcs)
@@ -273,9 +274,15 @@ def _task_futures(
273274

274275
for i_time, layer in enumerate(srcs, start=task.idx[0]):
275276
keys_out: list[Key] = []
276-
for i_src, src in enumerate(layer):
277-
idx = (i_time, *task.idx[1:], i_src)
278-
rdr = dask_reader.open(src, ctx, layer_name=layer_name, idx=i_src)
277+
for i_src, src in layer:
278+
idx = (i_src, i_time, *task.idx[1:])
279+
280+
src_hash = tokenize(src)
281+
rdr = rdr_cache.get(src_hash, None)
282+
if rdr is None:
283+
rdr = dask_reader.open(src, ctx, layer_name=layer_name, idx=i_src)
284+
rdr_cache[src_hash] = rdr
285+
279286
fut = rdr.read(cfg, dst_gbox, selection=task.selection, idx=idx)
280287
keys_out.append(fut.key)
281288
dsk.update(fut.dask)
@@ -340,6 +347,7 @@ def __call__(
340347
cfg,
341348
resolve_src_nodata(cfg.fill_value, cfg),
342349
)
350+
rdr_cache: dict[str, DaskRasterReader] = {}
343351

344352
for task in self.load_tasks(name, shape[0]):
345353
task_key: Key = (band_layer, *task.idx)
@@ -359,7 +367,11 @@ def __call__(
359367
)
360368
else:
361369
srcs_futures = self._task_futures(
362-
task, dask_reader, open_layer, layers[open_layer]
370+
task,
371+
dask_reader,
372+
open_layer,
373+
layers[open_layer],
374+
rdr_cache=rdr_cache,
363375
)
364376

365377
dsk[task_key] = (
@@ -812,7 +824,7 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:
812824

813825
with rdr.restore_env(env, load_state) as ctx:
814826
for t_idx, layer in enumerate(layers):
815-
loaders = [rdr.open(src, ctx) for src in layer]
827+
loaders = [rdr.open(src, ctx) for _, src in layer]
816828
_ = _fill_nd_slice(
817829
loaders,
818830
task.dst_gbox,

odc/loader/_zarr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,9 @@ def open(
283283
layer_name: str,
284284
idx: int,
285285
) -> DaskRasterReader:
286+
assert idx >= 0
286287
base, *_ = layer_name.rsplit("-", 1)
287-
_tk = tokenize(layer_name, idx)
288+
_tk = tokenize(src, ctx)
288289
xx = from_raster_source(src, ctx, name=f"{base}-zarr-{_tk}")
289290

290291
assert xx.odc.geobox is not None

0 commit comments

Comments
 (0)