|
1 | | -# pylint: disable=no-self-use |
| 1 | +# pylint: disable=no-self-use,protected-access |
2 | 2 | import unittest |
3 | 3 | from unittest import mock |
4 | 4 | from contextlib import ExitStack |
@@ -306,6 +306,51 @@ class Dialect(csv.excel): |
306 | 306 | df = owcsvimport.load_csv(io.BytesIO(contents), opts) |
307 | 307 | assert_array_equal(df.values, np.array([[3.21, 3.37], [4.13, 1000.142]])) |
308 | 308 |
|
| 309 | + def test_open_compressed(self): |
| 310 | + content = 'abc' |
| 311 | + for ext in [None, "gz", "bz2", "xz", "zip"]: |
| 312 | + with named_file('', suffix=f".{ext}") as fname: |
| 313 | + with _open_write(fname, "wt", encoding="ascii") as f: |
| 314 | + f.write(content) |
| 315 | + f.close() |
| 316 | + |
| 317 | + with owcsvimport._open(fname, "rt", encoding="ascii") as f: |
| 318 | + self.assertEqual(content, f.read()) |
| 319 | + |
| 320 | + |
| 321 | +def _open_write(path, mode, encoding=None): |
| 322 | + # pylint: disable=import-outside-toplevel |
| 323 | + if mode not in {'w', 'wb', 'wt'}: |
| 324 | + raise ValueError('r') |
| 325 | + _, ext = os.path.splitext(path) |
| 326 | + ext = ext.lower() |
| 327 | + if ext == ".gz": |
| 328 | + import gzip |
| 329 | + return gzip.open(path, mode, encoding=encoding) |
| 330 | + elif ext == ".bz2": |
| 331 | + import bz2 |
| 332 | + return bz2.open(path, mode, encoding=encoding) |
| 333 | + elif ext == ".xz": |
| 334 | + import lzma |
| 335 | + return lzma.open(path, mode, encoding=encoding) |
| 336 | + elif ext == ".zip": |
| 337 | + import zipfile |
| 338 | + arh = zipfile.ZipFile(path, 'w') |
| 339 | + filename, _ = os.path.splitext(os.path.basename(path)) |
| 340 | + f = arh.open(filename, mode="w") |
| 341 | + f_close = f.close |
| 342 | + # patch the f.close to also close the main archive file |
| 343 | + |
| 344 | + def close_(): |
| 345 | + f_close() |
| 346 | + arh.close() |
| 347 | + f.close = close_ |
| 348 | + if 't' in mode: |
| 349 | + f = io.TextIOWrapper(f, encoding=encoding) |
| 350 | + return f |
| 351 | + else: |
| 352 | + return open(path, mode, encoding=encoding) |
| 353 | + |
309 | 354 |
|
310 | 355 | if __name__ == "__main__": |
311 | 356 | unittest.main() |
0 commit comments