@@ -121,15 +121,15 @@ def __bool__(self) -> bool:
121121
122122 def resolve_sources (
123123 self , srcs : Sequence [MultiBandRasterSource ]
124- ) -> List [List [RasterSource ]]:
125- out : List [List [RasterSource ]] = []
124+ ) -> List [List [tuple [ int , RasterSource ] ]]:
125+ out : List [List [tuple [ int , RasterSource ] ]] = []
126126
127127 for layer in self .srcs :
128128 _srcs : List [RasterSource ] = []
129129 for idx in layer :
130130 src = srcs [idx ].get (self .band , None )
131131 if src is not None :
132- _srcs .append (src )
132+ _srcs .append (( idx , src ) )
133133 out .append (_srcs )
134134 return out
135135
@@ -263,6 +263,7 @@ def _task_futures(
263263 dask_reader : DaskRasterReader ,
264264 layer_name : str ,
265265 dsk : dict [Key , Any ],
266+ rdr_cache : dict [str , DaskRasterReader ],
266267 ) -> list [list [Key ]]:
267268 # pylint: disable=too-many-locals
268269 srcs = task .resolve_sources (self .srcs )
@@ -273,9 +274,15 @@ def _task_futures(
273274
274275 for i_time , layer in enumerate (srcs , start = task .idx [0 ]):
275276 keys_out : list [Key ] = []
276- for i_src , src in enumerate (layer ):
277- idx = (i_time , * task .idx [1 :], i_src )
278- rdr = dask_reader .open (src , ctx , layer_name = layer_name , idx = i_src )
277+ for i_src , src in layer :
278+ idx = (i_src , i_time , * task .idx [1 :])
279+
280+ src_hash = tokenize (src )
281+ rdr = rdr_cache .get (src_hash , None )
282+ if rdr is None :
283+ rdr = dask_reader .open (src , ctx , layer_name = layer_name , idx = i_src )
284+ rdr_cache [src_hash ] = rdr
285+
279286 fut = rdr .read (cfg , dst_gbox , selection = task .selection , idx = idx )
280287 keys_out .append (fut .key )
281288 dsk .update (fut .dask )
@@ -340,6 +347,7 @@ def __call__(
340347 cfg ,
341348 resolve_src_nodata (cfg .fill_value , cfg ),
342349 )
350+ rdr_cache : dict [str , DaskRasterReader ] = {}
343351
344352 for task in self .load_tasks (name , shape [0 ]):
345353 task_key : Key = (band_layer , * task .idx )
@@ -359,7 +367,11 @@ def __call__(
359367 )
360368 else :
361369 srcs_futures = self ._task_futures (
362- task , dask_reader , open_layer , layers [open_layer ]
370+ task ,
371+ dask_reader ,
372+ open_layer ,
373+ layers [open_layer ],
374+ rdr_cache = rdr_cache ,
363375 )
364376
365377 dsk [task_key ] = (
@@ -812,7 +824,7 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:
812824
813825 with rdr .restore_env (env , load_state ) as ctx :
814826 for t_idx , layer in enumerate (layers ):
815- loaders = [rdr .open (src , ctx ) for src in layer ]
827+ loaders = [rdr .open (src , ctx ) for _ , src in layer ]
816828 _ = _fill_nd_slice (
817829 loaders ,
818830 task .dst_gbox ,
0 commit comments