diff --git a/completion/_pycpio b/completion/_pycpio index 2349f92..549c94c 100644 --- a/completion/_pycpio +++ b/completion/_pycpio @@ -24,8 +24,8 @@ args=("-i[input archive]:archive:_files" "--set-group[GID]:GID:_groups" "-m[file mode]:mode" "--set-mode[file modes]:mode" - "-z[compression]:compression:(xz)" - "--compress[compression]:compression:(xz)" + "-z[compression]:compression:(xz zstd)" + "--compress[compression]:compression:(xz zstd)" "-o[output archive]:archive:_files" "--output[output archive]:archive:_files" "--log-file[specify the log file]:log file:_files" diff --git a/readme.md b/readme.md index eb1649f..70276a6 100644 --- a/readme.md +++ b/readme.md @@ -6,7 +6,9 @@ A library for creating CPIO files in Python. -Currently, the library only supports the New ASCII format, and xz compression +Currently, the library only supports the New ASCII format. + +xz compression and zstd compression types are suppored. This library is primary designed for use in [ugrd](https://github.com/desultory/ugrd) to create CPIO archives for use in initramfs. diff --git a/src/pycpio/writer/writer.py b/src/pycpio/writer/writer.py index 0417971..869aee2 100644 --- a/src/pycpio/writer/writer.py +++ b/src/pycpio/writer/writer.py @@ -20,6 +20,7 @@ def __init__( output_file: Path, structure=None, compression=False, + compression_level=10, xz_crc=CHECK_CRC32, *args, **kwargs, @@ -30,6 +31,7 @@ def __init__( self.structure = structure if structure is not None else HEADER_NEW self.compression = compression or False + self.compression_level = compression_level or 10 if isinstance(compression, str): compression = compression.lower() if compression == "true": @@ -50,18 +52,36 @@ def __bytes__(self): def compress(self, data): """Attempts to compress the data using the specified compression type.""" + compression_kwargs = {} + compression_args = () if self.compression == "xz" or self.compression is True: - import lzma - - self.logger.info("XZ compressing the CPIO data, original size: %.2f MiB" % (len(data) / (2**20))) - data = lzma.compress(data, check=self.xz_crc) + compression_module = "lzma.compress" + compression_kwargs["check"] = self.xz_crc elif self.compression == "zstd": - import zstd - - self.logger.info("ZSTD compressing the CPIO data, original size: %.2f MiB" % (len(data) / (2**20))) - data = zstd.compress(data, 10) + compression_module = "zstd.compress" + compression_args = (self.compression_level,) elif self.compression is not False: raise NotImplementedError("Compression type not supported: %s" % self.compression) + else: + self.logger.info("No compression specified, writing uncompressed data.") + return data + + try: + if "." in compression_module: + module, func = compression_module.rsplit(".", 1) + else: + module, func = compression_module, "compress" + + compressor = getattr(__import__(module), func) + self.logger.debug("Compressing data with: %s" % compression_module) + except ImportError as e: + raise ImportError("Failed to import compression module: %s" % compression_module) from e + + self.logger.info( + "[%s] Compressing the CPIO data, original size: %.2f MiB" % (self.compression.upper(), len(data) / (2**20)) + ) + data = compressor(data, *compression_args, **compression_kwargs) + return data def write(self, safe_write=True): diff --git a/tests/test_compression.py b/tests/test_compression.py new file mode 100644 index 0000000..5e29626 --- /dev/null +++ b/tests/test_compression.py @@ -0,0 +1,79 @@ +from tempfile import NamedTemporaryFile, TemporaryDirectory +from unittest import TestCase, main +from uuid import uuid4 + +from pycpio import PyCPIO +from zenlib.logging import loggify + + +@loggify +class TestCpio(TestCase): + def setUp(self): + self.cpio = PyCPIO(logger=self.logger) + self.make_workdir() + + def tearDown(self): + for file in self.test_files: + file.close() + for directory in self.test_dirs: + directory.cleanup() + self.workdir.cleanup() + del self.cpio + + def make_workdir(self): + """ + Create a temporary directory for testing. + sets self.workdir to the Path object of the directory + initializes self.test_files as an empty list + """ + self.workdir = TemporaryDirectory(prefix="pycpio-test-") + self.test_files = [] + self.test_dirs = [] + + def make_test_file(self, subdir=None, data=None): + """Creates a test file in the workdir""" + base_dir = self.workdir.name + if subdir is True: + d = TemporaryDirectory(dir=base_dir) + self.test_dirs.append(d) + base_dir = d.name + elif subdir is not None and subdir in self.test_dirs: + base_dir = subdir.name + + file = NamedTemporaryFile(dir=base_dir) + file_data = data.encode() if data is not None else bytes(str(uuid4()), "utf-8") + file.file.write(file_data) + file.file.flush() + + self.test_files.append(file) + return file + + def make_test_files(self, count, subdir=None, data=None): + """Creates count test files in the workdir""" + for _ in range(count): + self.make_test_file(subdir=subdir, data=data) + + def test_write_no_compress(self): + self.make_test_files(100) + self.cpio.append_cpio(self.workdir.name) + out_file = NamedTemporaryFile() # Out file for the cpio + self.cpio.write_cpio_file(out_file.file.name) + out_file.file.flush() + + def test_write_xz_compress(self): + self.make_test_files(100) + self.cpio.append_cpio(self.workdir.name) + out_file = NamedTemporaryFile() + self.cpio.write_cpio_file(out_file.file.name, compress="xz") + out_file.file.flush() + + def test_write_zstd_compress(self): + self.make_test_files(100) + self.cpio.append_cpio(self.workdir.name) + out_file = NamedTemporaryFile() + self.cpio.write_cpio_file(out_file.file.name, compress="zstd") + out_file.file.flush() + + +if __name__ == "__main__": + main()