Skip to content

Commit 916d082

Browse files
authored
Merge pull request #11 from desultory/dev
Add compression tests, level setting
2 parents e4a0383 + 533ce9e commit 916d082

File tree

4 files changed

+112
-11
lines changed

4 files changed

+112
-11
lines changed

completion/_pycpio

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ args=("-i[input archive]:archive:_files"
2424
"--set-group[GID]:GID:_groups"
2525
"-m[file mode]:mode"
2626
"--set-mode[file modes]:mode"
27-
"-z[compression]:compression:(xz)"
28-
"--compress[compression]:compression:(xz)"
27+
"-z[compression]:compression:(xz zstd)"
28+
"--compress[compression]:compression:(xz zstd)"
2929
"-o[output archive]:archive:_files"
3030
"--output[output archive]:archive:_files"
3131
"--log-file[specify the log file]:log file:_files"

readme.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
A library for creating CPIO files in Python.
88

9-
Currently, the library only supports the New ASCII format, and xz compression
9+
Currently, the library only supports the New ASCII format.
10+
11+
xz compression and zstd compression types are suppored.
1012

1113
This library is primary designed for use in [ugrd](https://github.com/desultory/ugrd) to create CPIO archives for use in initramfs.
1214

src/pycpio/writer/writer.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
output_file: Path,
2121
structure=None,
2222
compression=False,
23+
compression_level=10,
2324
xz_crc=CHECK_CRC32,
2425
*args,
2526
**kwargs,
@@ -30,6 +31,7 @@ def __init__(
3031
self.structure = structure if structure is not None else HEADER_NEW
3132

3233
self.compression = compression or False
34+
self.compression_level = compression_level or 10
3335
if isinstance(compression, str):
3436
compression = compression.lower()
3537
if compression == "true":
@@ -50,18 +52,36 @@ def __bytes__(self):
5052

5153
def compress(self, data):
5254
"""Attempts to compress the data using the specified compression type."""
55+
compression_kwargs = {}
56+
compression_args = ()
5357
if self.compression == "xz" or self.compression is True:
54-
import lzma
55-
56-
self.logger.info("XZ compressing the CPIO data, original size: %.2f MiB" % (len(data) / (2**20)))
57-
data = lzma.compress(data, check=self.xz_crc)
58+
compression_module = "lzma.compress"
59+
compression_kwargs["check"] = self.xz_crc
5860
elif self.compression == "zstd":
59-
import zstd
60-
61-
self.logger.info("ZSTD compressing the CPIO data, original size: %.2f MiB" % (len(data) / (2**20)))
62-
data = zstd.compress(data, 10)
61+
compression_module = "zstd.compress"
62+
compression_args = (self.compression_level,)
6363
elif self.compression is not False:
6464
raise NotImplementedError("Compression type not supported: %s" % self.compression)
65+
else:
66+
self.logger.info("No compression specified, writing uncompressed data.")
67+
return data
68+
69+
try:
70+
if "." in compression_module:
71+
module, func = compression_module.rsplit(".", 1)
72+
else:
73+
module, func = compression_module, "compress"
74+
75+
compressor = getattr(__import__(module), func)
76+
self.logger.debug("Compressing data with: %s" % compression_module)
77+
except ImportError as e:
78+
raise ImportError("Failed to import compression module: %s" % compression_module) from e
79+
80+
self.logger.info(
81+
"[%s] Compressing the CPIO data, original size: %.2f MiB" % (self.compression.upper(), len(data) / (2**20))
82+
)
83+
data = compressor(data, *compression_args, **compression_kwargs)
84+
6585
return data
6686

6787
def write(self, safe_write=True):

tests/test_compression.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from tempfile import NamedTemporaryFile, TemporaryDirectory
2+
from unittest import TestCase, main
3+
from uuid import uuid4
4+
5+
from pycpio import PyCPIO
6+
from zenlib.logging import loggify
7+
8+
9+
@loggify
10+
class TestCpio(TestCase):
11+
def setUp(self):
12+
self.cpio = PyCPIO(logger=self.logger)
13+
self.make_workdir()
14+
15+
def tearDown(self):
16+
for file in self.test_files:
17+
file.close()
18+
for directory in self.test_dirs:
19+
directory.cleanup()
20+
self.workdir.cleanup()
21+
del self.cpio
22+
23+
def make_workdir(self):
24+
"""
25+
Create a temporary directory for testing.
26+
sets self.workdir to the Path object of the directory
27+
initializes self.test_files as an empty list
28+
"""
29+
self.workdir = TemporaryDirectory(prefix="pycpio-test-")
30+
self.test_files = []
31+
self.test_dirs = []
32+
33+
def make_test_file(self, subdir=None, data=None):
34+
"""Creates a test file in the workdir"""
35+
base_dir = self.workdir.name
36+
if subdir is True:
37+
d = TemporaryDirectory(dir=base_dir)
38+
self.test_dirs.append(d)
39+
base_dir = d.name
40+
elif subdir is not None and subdir in self.test_dirs:
41+
base_dir = subdir.name
42+
43+
file = NamedTemporaryFile(dir=base_dir)
44+
file_data = data.encode() if data is not None else bytes(str(uuid4()), "utf-8")
45+
file.file.write(file_data)
46+
file.file.flush()
47+
48+
self.test_files.append(file)
49+
return file
50+
51+
def make_test_files(self, count, subdir=None, data=None):
52+
"""Creates count test files in the workdir"""
53+
for _ in range(count):
54+
self.make_test_file(subdir=subdir, data=data)
55+
56+
def test_write_no_compress(self):
57+
self.make_test_files(100)
58+
self.cpio.append_cpio(self.workdir.name)
59+
out_file = NamedTemporaryFile() # Out file for the cpio
60+
self.cpio.write_cpio_file(out_file.file.name)
61+
out_file.file.flush()
62+
63+
def test_write_xz_compress(self):
64+
self.make_test_files(100)
65+
self.cpio.append_cpio(self.workdir.name)
66+
out_file = NamedTemporaryFile()
67+
self.cpio.write_cpio_file(out_file.file.name, compress="xz")
68+
out_file.file.flush()
69+
70+
def test_write_zstd_compress(self):
71+
self.make_test_files(100)
72+
self.cpio.append_cpio(self.workdir.name)
73+
out_file = NamedTemporaryFile()
74+
self.cpio.write_cpio_file(out_file.file.name, compress="zstd")
75+
out_file.file.flush()
76+
77+
78+
if __name__ == "__main__":
79+
main()

0 commit comments

Comments
 (0)