@@ -120,7 +120,7 @@ def _write_obj(
120
120
click .echo (adata , err = show_obj == 'stderr' )
121
121
return 0
122
122
123
- def write_mtx (adata , fname_prefix = '' , var = None , obs = None , use_raw = False , use_layer = None ):
123
+ def write_mtx (adata , fname_prefix = '' , var = None , obs = None , use_raw = False , use_layer = None , compression = None ):
124
124
"""Export AnnData object to mtx format
125
125
* Parameters
126
126
+ adata : AnnData
@@ -133,6 +133,32 @@ def write_mtx(adata, fname_prefix='', var=None, obs=None, use_raw=False, use_lay
133
133
A list of column names to be exported to gene table
134
134
+ obs : list
135
135
A list of column names to be exported to barcode/cell table
136
+ + use_raw : bool
137
+ Take the data matrix from .raw.X?
138
+ + use_layer: str
139
+ Specify a layer to use instead of .X (non-raw only)
140
+ + compression: None, str or dict
141
+ Compression parameter for Pandas' to_csv(). For compression, a dict
142
+ with a 'method' key, e.g. {'method': 'gzip', 'compresslevel': 1,
143
+ 'mtime': 1}
144
+
145
+ >>> import os
146
+ >>> from pathlib import Path
147
+ >>> adata = sc.datasets.pbmc3k()
148
+ >>> # Test uncompressed write
149
+ >>> Path("uncompressed").mkdir(parents=True, exist_ok=True)
150
+ >>> write_mtx(adata, fname_prefix = 'uncompressed/', use_raw = False, use_layer = None, var = ['gene_name'])
151
+ >>> sorted(os.listdir('uncompressed'))
152
+ ['barcodes.tsv', 'genes.tsv', 'matrix.mtx']
153
+ >>> # Test that the matrix is the same when we read it back
154
+ >>> test_readable = sc.read_10x_mtx('uncompressed')
155
+ >>> if any(test_readable.obs_names != adata.obs_names) or any(test_readable.var_names != adata.var_names) or (test_readable.X[1].sum() - adata.X[1].sum()) > 1e-5:
156
+ ... print("Re-read matrix is different to the one we stored, something is wrong with the writing")
157
+ >>> # Test compressed write
158
+ >>> Path("compressed").mkdir(parents=True, exist_ok=True)
159
+ >>> write_mtx(adata, fname_prefix = 'compressed/', use_raw = False, use_layer = None, var = ['gene_name'], compression = {'method': 'gzip'})
160
+ >>> sorted(os.listdir('compressed'))
161
+ ['barcodes.tsv.gz', 'genes.tsv.gz', 'matrix.mtx.gz']
136
162
"""
137
163
if fname_prefix and not (fname_prefix .endswith ('/' ) or fname_prefix .endswith ('_' )):
138
164
fname_prefix = fname_prefix + '_'
@@ -157,22 +183,46 @@ def write_mtx(adata, fname_prefix='', var=None, obs=None, use_raw=False, use_lay
157
183
158
184
n_obs , n_var = mat .shape
159
185
n_entry = len (mat .data )
160
- header = '%%MatrixMarket matrix coordinate real general\n %\n {} {} {}\n ' .format (
161
- n_var , n_obs , n_entry )
186
+
187
+ # Define the header lines as a Pandas DataFrame so we can use the same compression
188
+ header = pd .DataFrame (['%%MatrixMarket matrix coordinate real general' , f"{ n_var } { n_obs } { n_entry } " ])
162
189
df = pd .DataFrame ({'col' : mat .col + 1 , 'row' : mat .row + 1 , 'data' : mat .data })
190
+
191
+ # Define outputs
163
192
mtx_fname = fname_prefix + 'matrix.mtx'
164
193
gene_fname = fname_prefix + 'genes.tsv'
165
194
barcode_fname = fname_prefix + 'barcodes.tsv'
166
- with open (mtx_fname , 'a' ) as fh :
167
- fh .write (header )
168
- df .to_csv (fh , sep = ' ' , header = False , index = False )
169
195
196
+ # Write matrix with Pandas CSV and use its compression where requested
197
+ if compression is not None and type (compression ) is dict and 'method' in compression :
198
+ compressed_exts = {
199
+ 'zip' : 'zip' ,
200
+ 'gzip' : 'gz' ,
201
+ 'bz2' : 'bz2' ,
202
+ 'zstd' : 'zst'
203
+ }
204
+ ext = compressed_exts .get (compression ['method' ], 'None' )
205
+
206
+ if ext is None :
207
+ errmsg = "Invalid compression method"
208
+ raise Exception (errmsg )
209
+
210
+ mtx_fname += f".{ ext } "
211
+ gene_fname += f".{ ext } "
212
+ barcode_fname += f".{ ext } "
213
+ else :
214
+ compression = None
215
+
216
+ header .to_csv (mtx_fname , header = False , index = False , compression = compression )
217
+ df .to_csv (mtx_fname , sep = ' ' , header = False , index = False , compression = compression , mode = 'a' )
218
+
219
+ # Now write the obs and var, also with compression if appropriate
170
220
obs_df = adata .obs [obs ].reset_index (level = 0 )
171
- obs_df .to_csv (barcode_fname , sep = '\t ' , header = False , index = False )
221
+ obs_df .to_csv (barcode_fname , sep = '\t ' , header = False , index = False , compression = compression )
172
222
var_df = var_source [var ].reset_index (level = 0 )
173
223
if not var :
174
224
var_df ['gene' ] = var_df ['index' ]
175
- var_df .to_csv (gene_fname , sep = '\t ' , header = False , index = False )
225
+ var_df .to_csv (gene_fname , sep = '\t ' , header = False , index = False , compression = compression )
176
226
177
227
178
228
def make_plot_function (func_name , kind = None ):
0 commit comments