Skip to content

Commit 4d610f6

Browse files
authored
Feature/gzip mtx output (#112)
* Use pandas to compress mtx output * Add test for mtx export * Blackify * Revert "Blackify" This reverts commit 772f5a2. * Fix mtx export test * Run unit tests before integration tests * Fix mtx export test directory sort * add a round-trip check for matrix writing * poke ci
1 parent ac6a7e0 commit 4d610f6

File tree

2 files changed

+61
-12
lines changed

2 files changed

+61
-12
lines changed

.github/workflows/python-package.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,9 @@ jobs:
4242
pip install $(pwd)/scanpy-scripts
4343
python -m pip install $(pwd)/scanpy --no-deps --ignore-installed -vv
4444
45+
- name: Run unit tests
46+
run: pytest --doctest-modules -v ./scanpy-scripts
47+
4548
- name: Test with bats
4649
run: |
4750
./scanpy-scripts/scanpy-scripts-tests.bats
48-
49-
- name: Run unit tests
50-
run: pytest --doctest-modules -v ./scanpy-scripts
51-

scanpy_scripts/cmd_utils.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _write_obj(
120120
click.echo(adata, err=show_obj == 'stderr')
121121
return 0
122122

123-
def write_mtx(adata, fname_prefix='', var=None, obs=None, use_raw=False, use_layer=None):
123+
def write_mtx(adata, fname_prefix='', var=None, obs=None, use_raw=False, use_layer=None, compression = None):
124124
"""Export AnnData object to mtx formt
125125
* Parameters
126126
+ adata : AnnData
@@ -133,6 +133,32 @@ def write_mtx(adata, fname_prefix='', var=None, obs=None, use_raw=False, use_lay
133133
A list of column names to be exported to gene table
134134
+ obs : list
135135
A list of column names to be exported to barcode/cell table
136+
+ use_raw : bool
137+
Take data the matrix from .raw.X?
138+
+ use_layer: str
139+
Specify a layer to use instead of .X (non-raw only)
140+
+ compression: None, str or dict
141+
Compression parameter for Pandas' to_csv(). For compression, a dict
142+
with a 'method' key, e.g. {'method': 'gzip', 'compresslevel': 1,
143+
'mtime': 1}
144+
145+
>>> import os
146+
>>> from pathlib import Path
147+
>>> adata = sc.datasets.pbmc3k()
148+
>>> # Test uncompressed write
149+
>>> Path("uncompressed").mkdir(parents=True, exist_ok=True)
150+
>>> write_mtx(adata, fname_prefix = 'uncompressed/', use_raw = False, use_layer = None, var = ['gene_name'])
151+
>>> sorted(os.listdir('uncompressed'))
152+
['barcodes.tsv', 'genes.tsv', 'matrix.mtx']
153+
>>> # Test that the matrix is the same when we read it back
154+
>>> test_readable = sc.read_10x_mtx('uncompressed')
155+
>>> if any(test_readable.obs_names != adata.obs_names) or any(test_readable.var_names != adata.var_names) or (test_readable.X[1].sum() - adata.X[1].sum()) > 1e-5:
156+
... print("Re-read matrix is different to the one we stored, something is wrong with the writing")
157+
>>> # Test compressed write
158+
>>> Path("compressed").mkdir(parents=True, exist_ok=True)
159+
>>> write_mtx(adata, fname_prefix = 'compressed/', use_raw = False, use_layer = None, var = ['gene_name'], compression = {'method': 'gzip'})
160+
>>> sorted(os.listdir('compressed'))
161+
['barcodes.tsv.gz', 'genes.tsv.gz', 'matrix.mtx.gz']
136162
"""
137163
if fname_prefix and not (fname_prefix.endswith('/') or fname_prefix.endswith('_')):
138164
fname_prefix = fname_prefix + '_'
@@ -157,22 +183,46 @@ def write_mtx(adata, fname_prefix='', var=None, obs=None, use_raw=False, use_lay
157183

158184
n_obs, n_var = mat.shape
159185
n_entry = len(mat.data)
160-
header = '%%MatrixMarket matrix coordinate real general\n%\n{} {} {}\n'.format(
161-
n_var, n_obs, n_entry)
186+
187+
# Define the header lines as a Pandas DataFrame so we can use the same compression
188+
header = pd.DataFrame(['%%MatrixMarket matrix coordinate real general', f"{n_var} {n_obs} {n_entry}"])
162189
df = pd.DataFrame({'col': mat.col + 1, 'row': mat.row + 1, 'data': mat.data})
190+
191+
# Define outputs
163192
mtx_fname = fname_prefix + 'matrix.mtx'
164193
gene_fname = fname_prefix + 'genes.tsv'
165194
barcode_fname = fname_prefix + 'barcodes.tsv'
166-
with open(mtx_fname, 'a') as fh:
167-
fh.write(header)
168-
df.to_csv(fh, sep=' ', header=False, index=False)
169195

196+
# Write matrix with Pandas CSV and use its compression where requested
197+
if compression is not None and type(compression) is dict and 'method' in compression:
198+
compressed_exts = {
199+
'zip': 'zip',
200+
'gzip': 'gz',
201+
'bz2': 'bz2',
202+
'zstd': 'zst'
203+
}
204+
ext = compressed_exts.get(compression['method'], 'None')
205+
206+
if ext is None:
207+
errmsg = "Invalid compression method"
208+
raise Exception(errmsg)
209+
210+
mtx_fname += f".{ext}"
211+
gene_fname += f".{ext}"
212+
barcode_fname += f".{ext}"
213+
else:
214+
compression = None
215+
216+
header.to_csv(mtx_fname, header = False, index = False, compression = compression)
217+
df.to_csv(mtx_fname, sep=' ', header=False, index=False, compression = compression, mode = 'a')
218+
219+
# Now write the obs and var, also with compression if appropriate
170220
obs_df = adata.obs[obs].reset_index(level=0)
171-
obs_df.to_csv(barcode_fname, sep='\t', header=False, index=False)
221+
obs_df.to_csv(barcode_fname, sep='\t', header=False, index=False, compression = compression)
172222
var_df = var_source[var].reset_index(level=0)
173223
if not var:
174224
var_df['gene'] = var_df['index']
175-
var_df.to_csv(gene_fname, sep='\t', header=False, index=False)
225+
var_df.to_csv(gene_fname, sep='\t', header=False, index=False, compression = compression)
176226

177227

178228
def make_plot_function(func_name, kind=None):

0 commit comments

Comments
 (0)