import os
import time
import copy
import logging
from abc import ABCMeta

from dask.distributed import Client, wait
from dask.array.core import slices_from_chunks, normalize_chunks
from numcodecs import Zstd

from zarrify.utils.volume import Volume

# Configure root logging once at import time so worker-side log lines carry
# timestamp / level / logger-name context.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
1419
15- class Tiff3D (Volume ):
20+
21+ class Tiff (Volume ):
1622
1723 def __init__ (
1824 self ,
@@ -21,6 +27,7 @@ def __init__(
2127 scale : list [float ],
2228 translation : list [float ],
2329 units : list [str ],
30+ optimize_reads : bool = False ,
2431 ):
2532 """Construct all the necessary attributes for the proper conversion of tiff to OME-NGFF Zarr.
2633
@@ -35,57 +42,78 @@ def __init__(
3542 self .shape = self .zarr_arr .shape
3643 self .dtype = self .zarr_arr .dtype
3744 self .ndim = self .zarr_arr .ndim
38-
45+ self .optimize_reads = optimize_reads
46+
3947 # Scale metadata parameters to match data dimensionality
40- self .metadata ["axes" ] = self .metadata ["axes" ][- self .ndim :]
48+ self .metadata ["axes" ] = list ( self .metadata ["axes" ]) [- self .ndim :]
4149 self .metadata ["scale" ] = self .metadata ["scale" ][- self .ndim :]
4250 self .metadata ["translation" ] = self .metadata ["translation" ][- self .ndim :]
4351 self .metadata ["units" ] = self .metadata ["units" ][- self .ndim :]
4452
def write_to_zarr(self,
                  zarr_array: zarr.Array,
                  client: Client,
                  ):
    """Copy the source TIFF into *zarr_array*, slab by slab, via Dask workers.

    The source volume is partitioned into slabs along the ``z`` axis (or
    axis 0 when no ``z`` axis is present in the metadata). One write task
    per slab is submitted to the Dask scheduler.

    Parameters
    ----------
    zarr_array : zarr.Array
        Pre-created destination array (shape/dtype matching the source).
    client : Client
        Dask distributed client used to fan out the slab writes.

    Returns
    -------
    int
        0 on completion (all futures waited on).

    Raises
    ------
    ValueError
        If a single slab would exceed one Dask worker's memory limit.
    """
    # Find slab axis based on metadata axes - use z axis for slabbing
    axes = self.metadata["axes"]
    slab_axis = axes.index('z') if 'z' in axes else 0

    z_arr = zarr_array

    # Default: one task per output chunk.
    slice_chunks = z_arr.chunks
    if self.optimize_reads:
        logger.info("Optimizing read chunking...")
        # TODO: this works for some cases, doesn't work in others, need to understand why
        # slicing
        # (c, z, y, x) or (z, y, x) - combine (c, z) from zarr_chunks and (y, x) from tiff chunking
        logger.info(f"Output Zarr array chunks: {z_arr.chunks}")
        logger.info(f"Input Tiff array chunks: {self.zarr_arr.chunks}")
        # Keep the output chunking up to and including the slab axis.
        slice_chunks = list(z_arr.chunks[:slab_axis + 1])
        logger.info(f"Slice chunks: {slice_chunks}")

        # cast slab size to write into zarr:
        for zarr_chunkdim, tiff_chunkdim, tiff_dim in zip(
            z_arr.chunks[slab_axis + 1:],
            self.zarr_arr.chunks[slab_axis + 1:],
            self.zarr_arr.shape[slab_axis + 1:],
        ):
            if tiff_chunkdim < zarr_chunkdim:
                # Source chunks smaller than output chunks: align to output.
                slice_chunks.append(zarr_chunkdim)
            elif tiff_chunkdim / tiff_dim < 0.5:
                # Source chunk covers < half the dimension: round it down to
                # a multiple of the output chunk size.
                slice_chunks.append(int(tiff_chunkdim / zarr_chunkdim) * zarr_chunkdim)
            else:
                # Source chunk dominates the dimension: take it whole.
                slice_chunks.append(tiff_dim)

        logger.info(f"Slice chunks extended: {slice_chunks}")

    # compute size of the slab
    slab_size_bytes = np.prod(slice_chunks) * np.dtype(self.dtype).itemsize

    # get dask worker allocated memory size (assumes homogeneous workers;
    # only the first worker's limit is inspected)
    dask_worker_memory_bytes = next(iter(client.scheduler_info()["workers"].values()))["memory_limit"]

    logger.info(f"Slab size: {slab_size_bytes / 1e9} GB")
    logger.info(f"Dask memory limit: {dask_worker_memory_bytes / 1e9} GB")
    if slab_size_bytes > dask_worker_memory_bytes:
        raise ValueError("Tiff segment size exceeds Dask worker memory limit. Please reduce the chunksize of the output array.")

    logger.info(f"Zarr array shape: {self.zarr_arr.shape}")
    normalized_chunks = normalize_chunks(slice_chunks, shape=self.zarr_arr.shape)
    slice_tuples = slices_from_chunks(normalized_chunks)

    # Copy so the closure below does not capture `self`.
    src_path = copy.copy(self.src_path)

    start = time.time()
    fut = client.map(
        lambda v: write_volume_slab_to_zarr(v, z_arr, src_path), slice_tuples
    )
    logger.info(f"Submitted {len(slice_tuples)} tasks to the scheduler in {round(time.time()-start, 4)}s")

    # wait for all the futures to complete
    result = wait(fut)
    logger.info(f"Completed {len(slice_tuples)} tasks in {round(time.time() - start, 2)}s")

    return 0
74114
75115
def write_volume_slab_to_zarr(slice: tuple, zarray: zarr.Array, src_path: str):
    """Copy one slab of the source TIFF into the destination Zarr array.

    Executed on a Dask worker, one call per slab.

    Parameters
    ----------
    slice : tuple
        Tuple of ``slice`` objects (as produced by ``slices_from_chunks``)
        selecting the region to copy.  NOTE(review): the parameter name
        shadows the ``slice`` builtin; kept for caller compatibility.
    zarray : zarr.Array
        Destination Zarr array.
    src_path : str
        Path to the source TIFF file.
    """
    # aszarr=True exposes the TIFF as a zarr store, so only the requested
    # region is read rather than the whole stack.
    tiff_store = imread(src_path, aszarr=True)
    try:
        src_tiff_arr = zarr.open(tiff_store, mode='r')
        zarray[slice] = src_tiff_arr[slice]
    finally:
        # Release the underlying file handle on the worker.
        tiff_store.close()