diff --git a/dissect/cstruct/cstruct.py b/dissect/cstruct/cstruct.py index 81369ab..23d8f1a 100644 --- a/dissect/cstruct/cstruct.py +++ b/dissect/cstruct/cstruct.py @@ -321,7 +321,7 @@ def _make_type( size: int | None, *, alignment: int | None = None, - attrs: dict[str, Any] = None, + attrs: dict[str, Any] | None = None, ) -> type[BaseType]: """Create a new type class bound to this cstruct instance. diff --git a/dissect/cstruct/types/base.py b/dissect/cstruct/types/base.py index 5ea6fdb..53c57b8 100644 --- a/dissect/cstruct/types/base.py +++ b/dissect/cstruct/types/base.py @@ -36,14 +36,14 @@ def __call__(cls, *args, **kwargs) -> MetaType | BaseType: if len(args) == 1 and not isinstance(args[0], cls): stream = args[0] - if hasattr(stream, "read"): + if _is_readable_type(stream): return cls._read(stream) if issubclass(cls, bytes) and isinstance(stream, bytes) and len(stream) == cls.size: # Shortcut for char/bytes type return type.__call__(cls, *args, **kwargs) - if isinstance(stream, (bytes, memoryview, bytearray)): + if _is_buffer_type(stream): return cls.reads(stream) return type.__call__(cls, *args, **kwargs) @@ -83,7 +83,7 @@ def read(cls, obj: BinaryIO | bytes) -> BaseType: Returns: The parsed value of this type. """ - if isinstance(obj, (bytes, memoryview, bytearray)): + if _is_buffer_type(obj): return cls.reads(obj) return cls._read(obj) @@ -113,7 +113,7 @@ def dumps(cls, value: Any) -> bytes: cls._write(out, value) return out.getvalue() - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> BaseType: """Internal function for reading value. Must be implemented per type. @@ -124,7 +124,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType: """ raise NotImplementedError() - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[BaseType]: + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[BaseType]: """Internal function for reading array values. Allows type implementations to do optimized reading for their type. @@ -145,7 +145,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non return [cls._read(stream, context) for _ in range(count)] - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[BaseType]: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]: """Internal function for reading null-terminated data. "Null" is type specific, so must be implemented per type. @@ -179,7 +179,7 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: stream: The stream to read from. array: The array to write. """ - return cls._write_array(stream, array + [cls()]) + return cls._write_array(stream, array + [cls.default()]) class _overload: @@ -225,7 +225,10 @@ class ArrayMetaType(MetaType): num_entries: int | Expression | None null_terminated: bool - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array: + def default(cls) -> BaseType: + return type.__call__(cls, [cls.type.default()] * (cls.num_entries if isinstance(cls.num_entries, int) else 0)) + + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array: if cls.null_terminated: return cls.type._read_0(stream, context) @@ -243,11 +246,6 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array: return cls.type._read_array(stream, num, context) - def default(cls) -> BaseType: - return type.__call__( - cls, [cls.type.default() for _ in range(0 if cls.dynamic or cls.null_terminated else cls.num_entries)] - ) - class Array(list, BaseType, metaclass=ArrayMetaType): """Implements a fixed or dynamically sized array type. @@ -261,7 +259,7 @@ class Array(list, BaseType, metaclass=ArrayMetaType): """ @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array: return cls(ArrayMetaType._read(cls, stream, context)) @classmethod @@ -275,5 +273,13 @@ def _write(cls, stream: BinaryIO, data: list[Any]) -> int: return cls.type._write_array(stream, data) +def _is_readable_type(value: Any) -> bool: + return hasattr(value, "read") + + +def _is_buffer_type(value: Any) -> bool: + return isinstance(value, (bytes, memoryview, bytearray)) + + # As mentioned in the BaseType class, we correctly set the type here MetaType.ArrayType = Array diff --git a/dissect/cstruct/types/char.py b/dissect/cstruct/types/char.py index cec72c1..268d690 100644 --- a/dissect/cstruct/types/char.py +++ b/dissect/cstruct/types/char.py @@ -9,7 +9,7 @@ class CharArray(bytes, BaseType, metaclass=ArrayMetaType): """Character array type for reading and writing byte strings.""" @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> CharArray: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> CharArray: return type.__call__(cls, ArrayMetaType._read(cls, stream, context)) @classmethod @@ -35,11 +35,11 @@ class Char(bytes, BaseType): ArrayType = CharArray @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char: return cls._read_array(stream, 1, context) @classmethod - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Char: + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Char: if count == 0: return type.__call__(cls, b"") @@ -50,7 +50,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non return type.__call__(cls, data) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char: buf = [] while True: byte = stream.read(1) diff --git a/dissect/cstruct/types/enum.py b/dissect/cstruct/types/enum.py index cae53c3..eea2b7a 100644 --- a/dissect/cstruct/types/enum.py +++ b/dissect/cstruct/types/enum.py @@ -27,7 +27,7 @@ def __call__( ) -> EnumMetaType: if name is None: if value is None: - value = cls.type() + value = cls.type.default() if not isinstance(value, int): # value is a parsable value @@ -64,13 +64,13 @@ def __contains__(cls, value: Any) -> bool: return True return value in cls._value2member_map_ - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Enum: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Enum: return cls(cls.type._read(stream, context)) - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Enum]: + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Enum]: return list(map(cls, cls.type._read_array(stream, count, context))) - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[Enum]: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Enum]: return list(map(cls, cls.type._read_0(stream, context))) def _write(cls, stream: BinaryIO, data: Enum) -> int: @@ -82,7 +82,7 @@ def _write_array(cls, stream: BinaryIO, array: list[Enum]) -> int: def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: data = [entry.value if isinstance(entry, Enum) else entry for entry in array] - return cls._write_array(stream, data + [cls.type()]) + return cls._write_array(stream, data + [cls.type.default()]) def _fix_alias_members(cls: type[Enum]) -> None: diff --git a/dissect/cstruct/types/int.py b/dissect/cstruct/types/int.py index b1bc29c..f994a4d 100644 --- a/dissect/cstruct/types/int.py +++ b/dissect/cstruct/types/int.py @@ -12,7 +12,7 @@ class Int(int, BaseType): signed: bool @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int: data = stream.read(cls.size) if len(data) != cls.size: @@ -21,7 +21,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int: return cls.from_bytes(data, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int: result = [] while True: diff --git a/dissect/cstruct/types/leb128.py b/dissect/cstruct/types/leb128.py index 9f0a398..45b4786 100644 --- a/dissect/cstruct/types/leb128.py +++ b/dissect/cstruct/types/leb128.py @@ -14,7 +14,7 @@ class LEB128(int, BaseType): signed: bool @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> LEB128: result = 0 shift = 0 while True: @@ -35,7 +35,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: return cls.__new__(cls, result) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> LEB128: result = [] while True: diff --git a/dissect/cstruct/types/packed.py b/dissect/cstruct/types/packed.py index ec42c23..493e85c 100644 --- a/dissect/cstruct/types/packed.py +++ b/dissect/cstruct/types/packed.py @@ -18,11 +18,11 @@ class Packed(BaseType): packchar: str @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed: return cls._read_array(stream, 1, context)[0] @classmethod - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Packed]: + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Packed]: if count == EOF: data = stream.read() length = len(data) @@ -39,7 +39,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non return [cls.__new__(cls, value) for value in fmt.unpack(data)] @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed: result = [] fmt = _struct(cls.cs.endian, cls.packchar) diff --git a/dissect/cstruct/types/pointer.py b/dissect/cstruct/types/pointer.py index f79d86d..e398ff6 100644 --- a/dissect/cstruct/types/pointer.py +++ b/dissect/cstruct/types/pointer.py @@ -12,11 +12,11 @@ class Pointer(int, BaseType): """Pointer to some other type.""" type: MetaType - _stream: BinaryIO - _context: dict[str, Any] + _stream: BinaryIO | None + _context: dict[str, Any] | None _value: BaseType - def __new__(cls, value: int, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer: + def __new__(cls, value: int, stream: BinaryIO | None, context: dict[str, Any] | None = None) -> Pointer: obj = super().__new__(cls, value) obj._stream = stream obj._context = context @@ -66,7 +66,11 @@ def __or__(self, other: int) -> Pointer: return type.__call__(self.__class__, int.__or__(self, other), self._stream, self._context) @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer: + def default(cls) -> Pointer: + return cls.__new__(cls, cls.cs.pointer.default(), None, None) + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Pointer: return cls.__new__(cls, cls.cs.pointer._read(stream, context), stream, context) @classmethod @@ -74,7 +78,7 @@ def _write(cls, stream: BinaryIO, data: int) -> int: return cls.cs.pointer._write(stream, data) def dereference(self) -> Any: - if self == 0: + if self == 0 or self._stream is None: raise NullPointerDereference() if self._value is None and not issubclass(self.type, Void): diff --git a/dissect/cstruct/types/structure.py b/dissect/cstruct/types/structure.py index bc52136..6dbc5cf 100644 --- a/dissect/cstruct/types/structure.py +++ b/dissect/cstruct/types/structure.py @@ -4,13 +4,19 @@ from contextlib import contextmanager from enum import Enum from functools import lru_cache +from itertools import chain from operator import attrgetter from textwrap import dedent from types import FunctionType -from typing import Any, BinaryIO, Callable, ContextManager +from typing import Any, BinaryIO, Callable, Iterator from dissect.cstruct.bitbuffer import BitBuffer -from dissect.cstruct.types.base import BaseType, MetaType +from dissect.cstruct.types.base import ( + BaseType, + MetaType, + _is_buffer_type, + _is_readable_type, +) from dissect.cstruct.types.enum import EnumMetaType from dissect.cstruct.types.pointer import Pointer @@ -65,7 +71,7 @@ def __call__(cls, *args, **kwargs) -> Structure: # Shortcut for single char/bytes type return type.__call__(cls, *args, **kwargs) elif not args and not kwargs: - obj = cls(**{field.name: field.type.default() for field in cls.__fields__}) + obj = type.__call__(cls) object.__setattr__(obj, "_values", {}) object.__setattr__(obj, "_sizes", {}) return obj @@ -77,7 +83,6 @@ def _update_fields(cls, fields: list[Field], align: bool = False, classdict: dic lookup = {} raw_lookup = {} - init_names = [] field_names = [] for field in fields: if field.name in lookup and field.name != "_": @@ -94,25 +99,21 @@ def _update_fields(cls, fields: list[Field], align: bool = False, classdict: dic raw_lookup[field.name] = field - num_fields = len(lookup) field_names = lookup.keys() - init_names = raw_lookup.keys() classdict["fields"] = lookup classdict["lookup"] = raw_lookup classdict["__fields__"] = fields - classdict["__bool__"] = _patch_attributes(_make__bool__(num_fields), field_names, 1) + classdict["__bool__"] = _generate__bool__(field_names) if issubclass(cls, UnionMetaType) or isinstance(cls, UnionMetaType): - classdict["__init__"] = _patch_setattr_args_and_attributes( - _make_setattr__init__(len(init_names)), init_names - ) + classdict["__init__"] = _generate_union__init__(raw_lookup.values()) # Not a great way to do this but it works for now classdict["__eq__"] = Union.__eq__ else: - classdict["__init__"] = _patch_args_and_attributes(_make__init__(len(init_names)), init_names) - classdict["__eq__"] = _patch_attributes(_make__eq__(num_fields), field_names, 1) + classdict["__init__"] = _generate_structure__init__(raw_lookup.values()) + classdict["__eq__"] = _generate__eq__(field_names) - classdict["__hash__"] = _patch_attributes(_make__hash__(num_fields), field_names, 1) + classdict["__hash__"] = _generate__hash__(field_names) # If we're calling this as a class method or a function on the metaclass if issubclass(cls, type): @@ -229,7 +230,7 @@ def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) - # The structure size is whatever the currently calculated offset is return offset, alignment - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Structure: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Structure: bit_buffer = BitBuffer(stream, cls.cs.endian) struct_start = stream.tell() @@ -271,12 +272,14 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Structure: # Align the stream stream.seek(-stream.tell() & (cls.alignment - 1), io.SEEK_CUR) - obj = cls(**result) + # Using type.__call__ directly calls the __init__ method of the class + # This is faster than calling cls() and bypasses the metaclass __call__ method + obj = type.__call__(cls, **result) obj._sizes = sizes obj._values = result return obj - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[Structure]: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Structure]: result = [] while obj := cls._read(stream, context): @@ -322,7 +325,7 @@ def _write(cls, stream: BinaryIO, data: Structure) -> int: value = getattr(data, field.name, None) if value is None: - value = field_type() + value = field_type.default() if field.bits: if isinstance(field_type, EnumMetaType): @@ -350,7 +353,7 @@ def add_field(cls, name: str, type_: BaseType, bits: int | None = None, offset: cls.commit() @contextmanager - def start_update(cls) -> ContextManager: + def start_update(cls) -> Iterator[None]: try: cls.__updating__ = True yield @@ -397,11 +400,27 @@ class UnionMetaType(StructureMetaType): """Base metaclass for cstruct union type classes.""" def __call__(cls, *args, **kwargs) -> Union: - obj = super().__call__(*args, **kwargs) - if kwargs: - # Calling with kwargs means we are initializing with values - # Proxify all values + obj: Union = super().__call__(*args, **kwargs) + + # Calling with non-stream args or kwargs means we are initializing with values + if (args and not (len(args) == 1 and (_is_readable_type(args[0]) or _is_buffer_type(args[0])))) or kwargs: + # We don't support user initialization of dynamic unions yet + if cls.dynamic: + raise NotImplementedError("Initializing a dynamic union is not yet supported") + + # User (partial) initialization, rebuild the union + # First user-provided field is the one used to rebuild the union + arg_fields = (field.name for _, field in zip(args, cls.__fields__)) + kwarg_fields = (name for name in kwargs if name in cls.lookup) + if (first_field := next(chain(arg_fields, kwarg_fields), None)) is not None: + obj._rebuild(first_field) + elif not args and not kwargs: + # Initialized with default values + # Note that we proxify here in case we have a default initialization (cls()) + # We don't proxify in case we read from a stream, as we do that later on in _read at a more appropriate time + # Same with (partial) user initialization, we do that after rebuilding the union obj._proxify() + return obj def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]: @@ -425,7 +444,9 @@ def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) - return size, alignment - def _read_fields(cls, stream: BinaryIO, context: dict[str, Any] = None) -> tuple[dict[str, Any], dict[str, int]]: + def _read_fields( + cls, stream: BinaryIO, context: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, int]]: result = {} sizes = {} @@ -451,7 +472,7 @@ def _read_fields(cls, stream: BinaryIO, context: dict[str, Any] = None) -> tuple return result, sizes - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Union: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Union: if cls.size is None: start = stream.tell() result, sizes = cls._read_fields(stream, context) @@ -463,7 +484,12 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Union: sizes = {} buf = stream.read(cls.size) - obj: Union = cls(**result) + # Create the object and set the values + # Using type.__call__ directly calls the __init__ method of the class + # This is faster than calling cls() and bypasses the metaclass __call__ method + # It also makes it easier to differentiate between user-initialization of the class + # and initialization from a stream read + obj: Union = type.__call__(cls, **result) object.__setattr__(obj, "_values", result) object.__setattr__(obj, "_sizes", sizes) object.__setattr__(obj, "_buf", buf) @@ -471,14 +497,20 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Union: if cls.size is not None: obj._update() + # Proxify any nested structures + obj._proxify() + return obj def _write(cls, stream: BinaryIO, data: Union) -> int: + if cls.dynamic: + raise NotImplementedError("Writing dynamic unions is not yet supported") + offset = stream.tell() expected_offset = offset + len(cls) # Sort by largest field - fields = sorted(cls.__fields__, key=lambda e: len(e.type), reverse=True) + fields = sorted(cls.__fields__, key=lambda e: e.type.size or 0, reverse=True) anonymous_struct = False # Try to write by largest field @@ -488,12 +520,8 @@ def _write(cls, stream: BinaryIO, data: Union) -> int: anonymous_struct = field.type continue - # Skip empty values - if (value := getattr(data, field.name)) is None: - continue - - # We have a value, write it - field.type._write(stream, value) + # Write the value + field.type._write(stream, getattr(data, field.name)) break # If we haven't written anything yet and we initially skipped an anonymous struct, write it now @@ -527,14 +555,21 @@ def _rebuild(self, attr: str) -> None: cur_buf = b"\x00" * self.__class__.size buf = io.BytesIO(cur_buf) - field = self.__class__.fields[attr] + field = self.__class__.lookup[attr] if field.offset: buf.seek(field.offset) - field.type._write(buf, getattr(self, attr)) + + if (value := getattr(self, attr)) is None: + value = field.type.default() + + field.type._write(buf, value) object.__setattr__(self, "_buf", buf.getvalue()) self._update() + # (Re-)proxify all values + self._proxify() + def _update(self) -> None: result, sizes = self.__class__._read_fields(io.BytesIO(self._buf)) self.__dict__.update(result) @@ -596,65 +631,71 @@ def _func(obj: Any, value: Any) -> Any: def _codegen(func: FunctionType) -> FunctionType: - # Inspired by https://github.com/dabeaz/dataklasses - @lru_cache - def make_func_code(num_fields: int) -> FunctionType: - names = [f"_{n}" for n in range(num_fields)] - exec(func(names), {}, d := {}) - return d.popitem()[1] + """Decorator that generates a template function with a specified number of fields. - return make_func_code + This code is a little complex but allows use to cache generated functions for a specific number of fields. + For example, if we generate a structure with 10 fields, we can cache the generated code for that structure. + We can then reuse that code and patch it with the correct field names when we create a new structure with 10 fields. + The functions that are decorated with this decorator should take a list of field names and return a string of code. + The decorated function is needs to be called with the number of fields, instead of the field names. + The confusing part is that that the original function takes field names, but you then call it with + the number of fields instead. -def _patch_args_and_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: - return type(func)( - func.__code__.replace( - co_names=(*func.__code__.co_names[:start], *fields), - co_varnames=("self", *fields), - ), - func.__globals__, - argdefs=func.__defaults__, - ) + Inspired by https://github.com/dabeaz/dataklasses. + Args: + func: The decorated function that takes a list of field names and returns a string of code. -def _patch_setattr_args_and_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: - return type(func)( - func.__code__.replace( - co_consts=(None, *fields), - co_varnames=("self", *fields), - ), - func.__globals__, - argdefs=func.__defaults__, - ) + Returns: + A cached function that generates the desired function code, to be called with the number of fields. + """ + def make_func_code(num_fields: int) -> FunctionType: + exec(func([f"_{n}" for n in range(num_fields)]), {}, d := {}) + return d.popitem()[1] -def _patch_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: - return type(func)( - func.__code__.replace(co_names=(*func.__code__.co_names[:start], *fields)), - func.__globals__, - ) + make_func_code.__wrapped__ = func + return lru_cache(make_func_code) @_codegen -def _make__init__(fields: list[str]) -> str: +def _make_structure__init__(fields: list[str]) -> str: + """Generates an ``__init__`` method for a structure with the specified fields. + + Args: + fields: List of field names. + """ field_args = ", ".join(f"{field} = None" for field in fields) - field_init = "\n".join(f" self.{name} = {name}" for name in fields) + field_init = "\n".join(f" self.{name} = {name} if {name} is not None else {i}" for i, name in enumerate(fields)) - code = f"def __init__(self{', ' + field_args if field_args else ''}):\n" + code = f"def __init__(self{', ' + field_args or ''}):\n" return code + (field_init or " pass") @_codegen -def _make_setattr__init__(fields: list[str]) -> str: +def _make_union__init__(fields: list[str]) -> str: + """Generates an ``__init__`` method for a class with the specified fields using setattr. + + Args: + fields: List of field names. + """ field_args = ", ".join(f"{field} = None" for field in fields) - field_init = "\n".join(f" object.__setattr__(self, {name!r}, {name})" for name in fields) + field_init = "\n".join( + f" object.__setattr__(self, '{name}', {name} if {name} is not None else {i})" for i, name in enumerate(fields) + ) - code = f"def __init__(self{', ' + field_args if field_args else ''}):\n" + code = f"def __init__(self{', ' + field_args or ''}):\n" return code + (field_init or " pass") @_codegen def _make__eq__(fields: list[str]) -> str: + """Generates an ``__eq__`` method for a class with the specified fields. + + Args: + fields: List of field names. + """ self_vals = ",".join(f"self.{name}" for name in fields) other_vals = ",".join(f"other.{name}" for name in fields) @@ -676,6 +717,11 @@ def __eq__(self, other): @_codegen def _make__bool__(fields: list[str]) -> str: + """Generates a ``__bool__`` method for a class with the specified fields. + + Args: + fields: List of field names. + """ vals = ", ".join(f"self.{name}" for name in fields) code = f""" @@ -688,6 +734,11 @@ def __bool__(self): @_codegen def _make__hash__(fields: list[str]) -> str: + """Generates a ``__hash__`` method for a class with the specified fields. + + Args: + fields: List of field names. + """ vals = ", ".join(f"self.{name}" for name in fields) code = f""" @@ -696,3 +747,83 @@ def __hash__(self): """ return dedent(code) + + +def _patch_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: + """Patches a function's attributes. + + Args: + func: The function to patch. + fields: List of field names to add. + start: The starting index for patching. Defaults to 0. + """ + return type(func)( + func.__code__.replace(co_names=(*func.__code__.co_names[:start], *fields)), + func.__globals__, + ) + + +def _generate_structure__init__(fields: list[Field]) -> FunctionType: + """Generates an ``__init__`` method for a structure with the specified fields. + + Args: + fields: List of field names. + """ + field_names = [field.name for field in fields] + + template: FunctionType = _make_structure__init__(len(field_names)) + return type(template)( + template.__code__.replace( + co_consts=(None, *[field.type.default() for field in fields]), + co_names=(*field_names,), + co_varnames=("self", *field_names), + ), + template.__globals__, + argdefs=template.__defaults__, + ) + + +def _generate_union__init__(fields: list[Field]) -> FunctionType: + """Generates an ``__init__`` method for a union with the specified fields. + + Args: + fields: List of field names. + """ + field_names = [field.name for field in fields] + + template: FunctionType = _make_union__init__(len(field_names)) + return type(template)( + template.__code__.replace( + co_consts=(None, *sum([(field.name, field.type.default()) for field in fields], ())), + co_varnames=("self", *field_names), + ), + template.__globals__, + argdefs=template.__defaults__, + ) + + +def _generate__eq__(fields: list[str]) -> FunctionType: + """Generates an ``__eq__`` method for a class with the specified fields. + + Args: + fields: List of field names. + """ + return _patch_attributes(_make__eq__(len(fields)), fields, 1) + + +def _generate__bool__(fields: list[str]) -> FunctionType: + """Generates a ``__bool__`` method for a class with the specified fields. + + Args: + fields: List of field names. + """ + return _patch_attributes(_make__bool__(len(fields)), fields, 1) + + +def _generate__hash__(fields: list[str]) -> FunctionType: + """Generates a ``__hash__`` method for a class with the specified fields. + + Args: + fields: List of field names. + """ + return _patch_attributes(_make__hash__(len(fields)), fields, 1) diff --git a/dissect/cstruct/types/void.py b/dissect/cstruct/types/void.py index 09d5d8b..f3191d0 100644 --- a/dissect/cstruct/types/void.py +++ b/dissect/cstruct/types/void.py @@ -11,10 +11,17 @@ class Void(BaseType): def __bool__(self) -> bool: return False + def __eq__(self, value: object) -> bool: + return isinstance(value, Void) + @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Void: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Void: return cls.__new__(cls) + @classmethod + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Void: + return [cls.__new__(cls)] + @classmethod def _write(cls, stream: BinaryIO, data: Void) -> int: return 0 diff --git a/dissect/cstruct/types/wchar.py b/dissect/cstruct/types/wchar.py index 8799b8b..ecb98dc 100644 --- a/dissect/cstruct/types/wchar.py +++ b/dissect/cstruct/types/wchar.py @@ -10,7 +10,7 @@ class WcharArray(str, BaseType, metaclass=ArrayMetaType): """Wide-character array type for reading and writing UTF-16 strings.""" @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> WcharArray: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> WcharArray: return type.__call__(cls, ArrayMetaType._read(cls, stream, context)) @classmethod @@ -38,11 +38,11 @@ class Wchar(str, BaseType): } @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Wchar: + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Wchar: return cls._read_array(stream, 1, context) @classmethod - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Wchar: + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Wchar: if count == 0: return type.__call__(cls, "") @@ -56,7 +56,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non return type.__call__(cls, data.decode(cls.__encoding_map__[cls.cs.endian])) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Wchar: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Wchar: buf = [] while True: point = stream.read(2) diff --git a/tests/test_basic.py b/tests/test_basic.py index fd3dfca..869c039 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -111,101 +111,6 @@ def test_lookups(cs: cstruct, compiled: bool) -> None: assert cs.lookups["a"] == {1: 3, 2: 4} -def test_default_constructors(cs: cstruct, compiled: bool) -> None: - cdef = """ - enum Enum { - a = 0, - b = 1 - }; - - flag Flag { - a = 0, - b = 1 - }; - - struct test { - uint32 t_int; - uint32 t_int_array[2]; - uint24 t_bytesint; - uint24 t_bytesint_array[2]; - char t_char; - char t_char_array[2]; - wchar t_wchar; - wchar t_wchar_array[2]; - Enum t_enum; - Enum t_enum_array[2]; - Flag t_flag; - Flag t_flag_array[2]; - }; - """ - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - obj = cs.test() - assert obj.t_int == 0 - assert obj.t_int_array == [0, 0] - assert obj.t_bytesint == 0 - assert obj.t_bytesint_array == [0, 0] - assert obj.t_char == b"\x00" - assert obj.t_char_array == b"\x00\x00" - assert obj.t_wchar == "\x00" - assert obj.t_wchar_array == "\x00\x00" - assert obj.t_enum == cs.Enum(0) - assert obj.t_enum_array == [cs.Enum(0), cs.Enum(0)] - assert obj.t_flag == cs.Flag(0) - assert obj.t_flag_array == [cs.Flag(0), cs.Flag(0)] - - assert obj.dumps() == b"\x00" * 54 - - for name in obj.fields.keys(): - assert isinstance(getattr(obj, name), BaseType) - - -def test_default_constructors_dynamic(cs: cstruct, compiled: bool) -> None: - cdef = """ - enum Enum { - a = 0, - b = 1 - }; - flag Flag { - a = 0, - b = 1 - }; - struct test { - uint8 x; - uint32 t_int_array_n[]; - uint32 t_int_array_d[x]; - uint24 t_bytesint_array_n[]; - uint24 t_bytesint_array_d[x]; - char t_char_array_n[]; - char t_char_array_d[x]; - wchar t_wchar_array_n[]; - wchar t_wchar_array_d[x]; - Enum t_enum_array_n[]; - Enum t_enum_array_d[x]; - Flag t_flag_array_n[]; - Flag t_flag_array_d[x]; - }; - """ - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - obj = cs.test() - - assert obj.t_int_array_n == obj.t_int_array_d == [] - assert obj.t_bytesint_array_n == obj.t_bytesint_array_d == [] - assert obj.t_char_array_n == obj.t_char_array_d == b"" - assert obj.t_wchar_array_n == obj.t_wchar_array_d == "" - assert obj.t_enum_array_n == obj.t_enum_array_d == [] - assert obj.t_flag_array_n == obj.t_flag_array_d == [] - assert obj.dumps() == b"\x00" * 19 - - for name in obj.fields.keys(): - assert isinstance(getattr(obj, name), BaseType) - - def test_config_flag_nocompile(cs: cstruct, compiled: bool) -> None: cdef = """ struct compiled_global diff --git a/tests/test_types_base.py b/tests/test_types_base.py index 18483fa..8ef8708 100644 --- a/tests/test_types_base.py +++ b/tests/test_types_base.py @@ -87,7 +87,7 @@ def test_eof(cs: cstruct, compiled: bool) -> None: def test_custom_array_type(cs: cstruct, compiled: bool) -> None: class CustomType(BaseType): - def __init__(self, value): + def __init__(self, value: bytes = b""): self.value = value.upper() @classmethod @@ -98,7 +98,11 @@ def _read(cls, stream: BinaryIO, context: dict | None = None) -> CustomType: class ArrayType(BaseType, metaclass=ArrayMetaType): @classmethod - def _read(cls, stream: BinaryIO, context: dict | None = None) -> CustomType.ArrayType: + def default(cls) -> CustomType: + return cls.type() + + @classmethod + def _read(cls, stream: BinaryIO, context: dict | None = None) -> CustomType: value = cls.type._read(stream, context) if str(cls.num_entries) == "lower": value.value = value.value.lower() @@ -123,7 +127,7 @@ def _read(cls, stream: BinaryIO, context: dict | None = None) -> CustomType.Arra """ cs.load(cdef, compiled=compiled) - # We just don't want to compiler to blow up with a custom type + # We don't want the compiler to blow up with a custom type assert not cs.test.__compiled__ result = cs.test(b"\x04asdf\x04asdf") diff --git a/tests/test_types_char.py b/tests/test_types_char.py index bc93993..71b0775 100644 --- a/tests/test_types_char.py +++ b/tests/test_types_char.py @@ -43,3 +43,9 @@ def test_char_eof(cs: cstruct) -> None: cs.char[None](b"AAAA") assert cs.char[0](b"") == b"" + + +def test_char_default(cs: cstruct) -> None: + assert cs.char.default() == b"\x00" + assert cs.char[4].default() == b"\x00\x00\x00\x00" + assert cs.char[None].default() == b"" diff --git a/tests/test_types_custom.py b/tests/test_types_custom.py index e9fb98f..4cf624c 100644 --- a/tests/test_types_custom.py +++ b/tests/test_types_custom.py @@ -12,6 +12,10 @@ class EtwPointer(BaseType): type: MetaType size: int | None + @classmethod + def default(cls) -> int: + return cls.cs.uint64.default() + @classmethod def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> BaseType: return cls.type._read(stream, context) @@ -41,12 +45,12 @@ def test_adding_custom_type(cs: cstruct) -> None: cs.EtwPointer.as_64bit() assert cs.EtwPointer.type is cs.uint64 assert len(cs.EtwPointer) == 8 - assert cs.EtwPointer(b"\xDE\xAD\xBE\xEF" * 2).dumps() == b"\xDE\xAD\xBE\xEF" * 2 + assert cs.EtwPointer(b"\xde\xad\xbe\xef" * 2).dumps() == b"\xde\xad\xbe\xef" * 2 cs.EtwPointer.as_32bit() assert cs.EtwPointer.type is cs.uint32 assert len(cs.EtwPointer) == 4 - assert cs.EtwPointer(b"\xDE\xAD\xBE\xEF" * 2).dumps() == b"\xDE\xAD\xBE\xEF" + assert cs.EtwPointer(b"\xde\xad\xbe\xef" * 2).dumps() == b"\xde\xad\xbe\xef" def test_using_type_in_struct(cs: cstruct) -> None: @@ -62,13 +66,26 @@ def test_using_type_in_struct(cs: cstruct) -> None: cs.load(struct_definition) cs.EtwPointer.as_64bit() - assert len(cs.test().data) == 8 - with pytest.raises(EOFError): # Input too small - cs.test(b"\xDE\xAD\xBE\xEF" * 3) + cs.test(b"\xde\xad\xbe\xef" * 3) + + cs.EtwPointer.as_32bit() + + obj = cs.test(b"\xde\xad\xbe\xef" * 3) + assert obj.data == 0xEFBEADDE + assert obj.data2 == 0xEFBEADDEEFBEADDE + assert obj.data.dumps() == b"\xde\xad\xbe\xef" + + +def test_custom_default(cs: cstruct) -> None: + cs.add_custom_type("EtwPointer", EtwPointer) + + cs.EtwPointer.as_64bit() + assert cs.EtwPointer.default() == 0 cs.EtwPointer.as_32bit() - assert len(cs.test().data) == 4 + assert cs.EtwPointer.default() == 0 - assert cs.test(b"\xDE\xAD\xBE\xEF" * 3).data.dumps() == b"\xDE\xAD\xBE\xEF" + assert cs.EtwPointer[1].default() == [0] + assert cs.EtwPointer[None].default() == [] diff --git a/tests/test_types_enum.py b/tests/test_types_enum.py index 0a9c172..9045e2b 100644 --- a/tests/test_types_enum.py +++ b/tests/test_types_enum.py @@ -407,3 +407,17 @@ def test_enum_reference_own_member(cs: cstruct, compiled: bool) -> None: assert cs.test.A == 0 assert cs.test.B == 3 assert cs.test.C == 4 + + +def test_enum_default(cs: cstruct) -> None: + cdef = """ + enum test { + A, + B, + }; + """ + cs.load(cdef) + + assert cs.test.default() == cs.test.A == cs.test(0) + assert cs.test[1].default() == [cs.test.A] + assert cs.test[None].default() == [] diff --git a/tests/test_types_flag.py b/tests/test_types_flag.py index c9cc2c7..935aac9 100644 --- a/tests/test_types_flag.py +++ b/tests/test_types_flag.py @@ -255,3 +255,17 @@ def test_flag_anonymous_struct(cs: cstruct, compiled: bool) -> None: t = test(b"\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a\x00\x00\x00") assert t.arr == [255, 0, 0, 10] + + +def test_flag_default(cs: cstruct) -> None: + cdef = """ + flag test { + A, + B, + }; + """ + cs.load(cdef) + + assert cs.test.default() == cs.test(0) + assert cs.test[1].default() == [cs.test(0)] + assert cs.test[None].default() == [] diff --git a/tests/test_types_int.py b/tests/test_types_int.py index 056d22e..329c9d7 100644 --- a/tests/test_types_int.py +++ b/tests/test_types_int.py @@ -233,7 +233,7 @@ def test_int_eof(cs: cstruct) -> None: cs.int24[None](b"\x01\x00\x00") -def test_bytesinteger_range(cs: cstruct) -> None: +def test_int_range(cs: cstruct) -> None: int8 = cs._make_int_type("int8", 1, True) uint8 = cs._make_int_type("uint9", 1, False) int16 = cs._make_int_type("int16", 2, True) @@ -336,7 +336,7 @@ def test_int_struct_unsigned(cs: cstruct, compiled: bool) -> None: assert obj.dumps() == buf -def test_bytesinteger_struct_signed_be(cs: cstruct, compiled: bool) -> None: +def test_int_struct_signed_be(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { int24 a; @@ -370,7 +370,7 @@ def test_bytesinteger_struct_signed_be(cs: cstruct, compiled: bool) -> None: assert obj.dumps() == buf -def test_bytesinteger_struct_unsigned_be(cs: cstruct, compiled: bool) -> None: +def test_int_struct_unsigned_be(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint24 a; @@ -400,3 +400,13 @@ def test_bytesinteger_struct_unsigned_be(cs: cstruct, compiled: bool) -> None: assert obj.d == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF assert obj.e == [0x4141, 0x4242] assert obj.dumps() == buf + + +def test_int_default(cs: cstruct) -> None: + assert cs.int24.default() == 0 + assert cs.uint24.default() == 0 + assert cs.int128.default() == 0 + assert cs.uint128.default() == 0 + + assert cs.int24[1].default() == [0] + assert cs.int24[None].default() == [] diff --git a/tests/test_types_leb128.py b/tests/test_types_leb128.py index d66972a..30e6dae 100644 --- a/tests/test_types_leb128.py +++ b/tests/test_types_leb128.py @@ -189,3 +189,11 @@ def test_leb128_unsigned_write_amount_written(cs: cstruct) -> None: out3 = io.BytesIO() bytes_written3 = cs.uleb128(13371337).write(out3) assert bytes_written3 == out3.tell() + + +def test_leb128_default(cs: cstruct) -> None: + assert cs.uleb128.default() == 0 + assert cs.ileb128.default() == 0 + + assert cs.uleb128[1].default() == [0] + assert cs.uleb128[None].default() == [] diff --git a/tests/test_types_packed.py b/tests/test_types_packed.py index 9c7d667..5e4d2f2 100644 --- a/tests/test_types_packed.py +++ b/tests/test_types_packed.py @@ -7,12 +7,12 @@ def test_packed_read(cs: cstruct) -> None: assert cs.uint32(b"AAAA") == 0x41414141 - assert cs.uint32(b"\xFF\xFF\xFF\xFF") == 0xFFFFFFFF + assert cs.uint32(b"\xff\xff\xff\xff") == 0xFFFFFFFF - assert cs.int32(b"\xFF\x00\x00\x00") == 255 - assert cs.int32(b"\xFF\xFF\xFF\xFF") == -1 + assert cs.int32(b"\xff\x00\x00\x00") == 255 + assert cs.int32(b"\xff\xff\xff\xff") == -1 - assert cs.float16(b"\x00\x3C") == 1.0 + assert cs.float16(b"\x00\x3c") == 1.0 assert cs.float(b"\x00\x00\x80\x3f") == 1.0 @@ -21,13 +21,13 @@ def test_packed_read(cs: cstruct) -> None: def test_packed_write(cs: cstruct) -> None: assert cs.uint32(0x41414141).dumps() == b"AAAA" - assert cs.uint32(0xFFFFFFFF).dumps() == b"\xFF\xFF\xFF\xFF" + assert cs.uint32(0xFFFFFFFF).dumps() == b"\xff\xff\xff\xff" assert cs.uint32(b"AAAA").dumps() == b"AAAA" - assert cs.int32(255).dumps() == b"\xFF\x00\x00\x00" - assert cs.int32(-1).dumps() == b"\xFF\xFF\xFF\xFF" + assert cs.int32(255).dumps() == b"\xff\x00\x00\x00" + assert cs.int32(-1).dumps() == b"\xff\xff\xff\xff" - assert cs.float16(1.0).dumps() == b"\x00\x3C" + assert cs.float16(1.0).dumps() == b"\x00\x3c" assert cs.float(1.0).dumps() == b"\x00\x00\x80\x3f" @@ -38,8 +38,8 @@ def test_packed_array_read(cs: cstruct) -> None: assert cs.uint32[2](b"AAAABBBB") == [0x41414141, 0x42424242] assert cs.uint32[None](b"AAAABBBB\x00\x00\x00\x00") == [0x41414141, 0x42424242] - assert cs.int32[2](b"\x00\x00\x00\x00\xFF\xFF\xFF\xFF") == [0, -1] - assert cs.int32[None](b"\xFF\xFF\xFF\xFF\x00\x00\x00\x00") == [-1] + assert cs.int32[2](b"\x00\x00\x00\x00\xff\xff\xff\xff") == [0, -1] + assert cs.int32[None](b"\xff\xff\xff\xff\x00\x00\x00\x00") == [-1] assert cs.float[2](b"\x00\x00\x80\x3f\x00\x00\x00\x40") == [1.0, 2.0] assert cs.float[None](b"\x00\x00\x80\x3f\x00\x00\x00\x00") == [1.0] @@ -49,8 +49,8 @@ def test_packed_array_write(cs: cstruct) -> None: assert cs.uint32[2]([0x41414141, 0x42424242]).dumps() == b"AAAABBBB" assert cs.uint32[None]([0x41414141, 0x42424242]).dumps() == b"AAAABBBB\x00\x00\x00\x00" - assert cs.int32[2]([0, -1]).dumps() == b"\x00\x00\x00\x00\xFF\xFF\xFF\xFF" - assert cs.int32[None]([-1]).dumps() == b"\xFF\xFF\xFF\xFF\x00\x00\x00\x00" + assert cs.int32[2]([0, -1]).dumps() == b"\x00\x00\x00\x00\xff\xff\xff\xff" + assert cs.int32[None]([-1]).dumps() == b"\xff\xff\xff\xff\x00\x00\x00\x00" assert cs.float[2]([1.0, 2.0]).dumps() == b"\x00\x00\x80\x3f\x00\x00\x00\x40" assert cs.float[None]([1.0]).dumps() == b"\x00\x00\x80\x3f\x00\x00\x00\x00" @@ -60,12 +60,12 @@ def test_packed_be_read(cs: cstruct) -> None: cs.endian = ">" assert cs.uint32(b"AAA\x00") == 0x41414100 - assert cs.uint32(b"\xFF\xFF\xFF\x00") == 0xFFFFFF00 + assert cs.uint32(b"\xff\xff\xff\x00") == 0xFFFFFF00 - assert cs.int32(b"\x00\x00\x00\xFF") == 255 - assert cs.int32(b"\xFF\xFF\xFF\xFF") == -1 + assert cs.int32(b"\x00\x00\x00\xff") == 255 + assert cs.int32(b"\xff\xff\xff\xff") == -1 - assert cs.float16(b"\x3C\x00") == 1.0 + assert cs.float16(b"\x3c\x00") == 1.0 assert cs.float(b"\x3f\x80\x00\x00") == 1.0 @@ -76,12 +76,12 @@ def test_packed_be_write(cs: cstruct) -> None: cs.endian = ">" assert cs.uint32(0x41414100).dumps() == b"AAA\x00" - assert cs.uint32(0xFFFFFF00).dumps() == b"\xFF\xFF\xFF\x00" + assert cs.uint32(0xFFFFFF00).dumps() == b"\xff\xff\xff\x00" - assert cs.int32(255).dumps() == b"\x00\x00\x00\xFF" - assert cs.int32(-1).dumps() == b"\xFF\xFF\xFF\xFF" + assert cs.int32(255).dumps() == b"\x00\x00\x00\xff" + assert cs.int32(-1).dumps() == b"\xff\xff\xff\xff" - assert cs.float16(1.0).dumps() == b"\x3C\x00" + assert cs.float16(1.0).dumps() == b"\x3c\x00" assert cs.float(1.0).dumps() == b"\x3f\x80\x00\x00" @@ -94,8 +94,8 @@ def test_packed_be_array_read(cs: cstruct) -> None: assert cs.uint32[2](b"\x00\x00\x00\x01\x00\x00\x00\x02") == [1, 2] assert cs.uint32[None](b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00") == [1, 2] - assert cs.int32[2](b"\x00\x00\x00\x01\xFF\xFF\xFF\xFE") == [1, -2] - assert cs.int32[None](b"\xFF\xFF\xFF\xFE\x00\x00\x00\x00") == [-2] + assert cs.int32[2](b"\x00\x00\x00\x01\xff\xff\xff\xfe") == [1, -2] + assert cs.int32[None](b"\xff\xff\xff\xfe\x00\x00\x00\x00") == [-2] assert cs.float[2](b"\x3f\x80\x00\x00\x40\x00\x00\x00") == [1.0, 2.0] assert cs.float[None](b"\x3f\x80\x00\x00\x00\x00\x00\x00") == [1.0] @@ -107,8 +107,8 @@ def test_packed_be_array_write(cs: cstruct) -> None: assert cs.uint32[2]([1, 2]).dumps() == b"\x00\x00\x00\x01\x00\x00\x00\x02" assert cs.uint32[None]([1, 2]).dumps() == b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00" - assert cs.int32[2]([1, -2]).dumps() == b"\x00\x00\x00\x01\xFF\xFF\xFF\xFE" - assert cs.int32[None]([-2]).dumps() == b"\xFF\xFF\xFF\xFE\x00\x00\x00\x00" + assert cs.int32[2]([1, -2]).dumps() == b"\x00\x00\x00\x01\xff\xff\xff\xfe" + assert cs.int32[None]([-2]).dumps() == b"\xff\xff\xff\xfe\x00\x00\x00\x00" assert cs.float[2]([1.0, 2.0]).dumps() == b"\x3f\x80\x00\x00\x40\x00\x00\x00" assert cs.float[None]([1.0]).dumps() == b"\x3f\x80\x00\x00\x00\x00\x00\x00" @@ -169,3 +169,20 @@ def test_packed_float_struct_be(cs: cstruct, compiled: bool) -> None: assert obj.a == 0.388916015625 assert obj.b == 42069.69140625 + + +def test_packed_default(cs: cstruct) -> None: + assert cs.int8.default() == 0 + assert cs.uint8.default() == 0 + assert cs.int16.default() == 0 + assert cs.uint16.default() == 0 + assert cs.int32.default() == 0 + assert cs.uint32.default() == 0 + assert cs.int64.default() == 0 + assert cs.uint64.default() == 0 + assert cs.float16.default() == 0.0 + assert cs.float.default() == 0.0 + assert cs.double.default() == 0.0 + + assert cs.int8[2].default() == [0, 0] + assert cs.int8[None].default() == [] diff --git a/tests/test_types_pointer.py b/tests/test_types_pointer.py index d5940d4..65be5dd 100644 --- a/tests/test_types_pointer.py +++ b/tests/test_types_pointer.py @@ -13,9 +13,10 @@ def test_pointer(cs: cstruct) -> None: cs.pointer = cs.uint8 ptr = cs._make_pointer(cs.uint8) + assert issubclass(ptr, Pointer) assert ptr.__name__ == "uint8*" - obj = ptr(b"\x01\xFF") + obj = ptr(b"\x01\xff") assert repr(obj) == "" assert obj == 1 @@ -45,7 +46,7 @@ def test_pointer_operator(cs: cstruct) -> None: cs.pointer = cs.uint8 ptr = cs._make_pointer(cs.uint8) - obj = ptr(b"\x01\x00\xFF") + obj = ptr(b"\x01\x00\xff") assert obj == 1 assert obj.dumps() == b"\x01" @@ -235,3 +236,16 @@ def test_pointer_of_pointer(cs: cstruct, compiled: bool) -> None: assert obj.ptr == 1 assert obj.ptr.dereference() == 2 assert obj.ptr.dereference().dereference() == 0x41414141 + + +def test_pointer_default(cs: cstruct) -> None: + cs.pointer = cs.uint8 + + ptr = cs._make_pointer(cs.uint8) + assert isinstance(ptr.default(), Pointer) + assert ptr.default() == 0 + assert ptr[1].default() == [0] + assert ptr[None].default() == [] + + with pytest.raises(NullPointerDereference): + ptr.default().dereference() diff --git a/tests/test_types_structure.py b/tests/test_types_structure.py index cea0efb..d2d65d0 100644 --- a/tests/test_types_structure.py +++ b/tests/test_types_structure.py @@ -1,5 +1,6 @@ import inspect from io import BytesIO +from textwrap import dedent from types import MethodType from unittest.mock import MagicMock, call, patch @@ -7,6 +8,7 @@ from dissect.cstruct.cstruct import cstruct from dissect.cstruct.exceptions import ParserError +from dissect.cstruct.types import structure from dissect.cstruct.types.base import Array, BaseType from dissect.cstruct.types.pointer import Pointer from dissect.cstruct.types.structure import Field, Structure, StructureMetaType @@ -27,6 +29,7 @@ def test_structure(TestStruct: type[Structure]) -> None: assert len(TestStruct.fields) == 2 assert TestStruct.fields["a"].name == "a" assert TestStruct.fields["b"].name == "b" + assert repr(TestStruct.fields["a"]) == "" assert TestStruct.size == 8 assert TestStruct.alignment == 4 @@ -43,7 +46,7 @@ def test_structure(TestStruct: type[Structure]) -> None: obj = TestStruct(a=1) assert obj.a == 1 - assert obj.b is None + assert obj.b == 0 assert len(obj) == 8 # Test hashing of values @@ -72,6 +75,10 @@ def test_structure_write(TestStruct: type[Structure]) -> None: assert obj.dumps() == b"\x01\x00\x00\x00\x00\x00\x00\x00" obj = TestStruct() + assert obj.a == 0 + assert obj.dumps() == b"\x00\x00\x00\x00\x00\x00\x00\x00" + + obj.a = None assert obj.dumps() == b"\x00\x00\x00\x00\x00\x00\x00\x00" @@ -521,6 +528,17 @@ def test_structure_field_discard(cs: cstruct, compiled: bool) -> None: mock_char_new.assert_has_calls([call(cs.char, b"a"), call(cs.char, b"b")]) +def test_structure_field_duplicate(cs: cstruct) -> None: + cdef = """ + struct test { + uint8 a; + uint8 a; + }; + """ + with pytest.raises(ValueError, match="Duplicate field name: a"): + cs.load(cdef) + + def test_structure_definition_self(cs: cstruct) -> None: cdef = """ struct test { @@ -540,3 +558,206 @@ def test_align_struct_in_struct(cs: cstruct) -> None: _, kwargs = update_fields.call_args assert kwargs["align"] + + +def test_structure_default(cs: cstruct, compiled: bool) -> None: + cdef = """ + enum Enum { + a = 0, + b = 1 + }; + + flag Flag { + a = 0, + b = 1 + }; + + struct test { + uint32 t_int; + uint32 t_int_array[2]; + uint24 t_bytesint; + uint24 t_bytesint_array[2]; + char t_char; + char t_char_array[2]; + wchar t_wchar; + wchar t_wchar_array[2]; + Enum t_enum; + Enum t_enum_array[2]; + Flag t_flag; + Flag t_flag_array[2]; + uint8 *t_pointer; + uint8 *t_pointer_array[2]; + }; + + struct test_nested { + test t_struct; + test t_struct_array[2]; + }; + """ + cs.pointer = cs.uint8 + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + assert cs.test() == cs.test.default() + + obj = cs.test.default() + assert obj.t_int == 0 + assert obj.t_int_array == [0, 0] + assert obj.t_bytesint == 0 + assert obj.t_bytesint_array == [0, 0] + assert obj.t_char == b"\x00" + assert obj.t_char_array == b"\x00\x00" + assert obj.t_wchar == "\x00" + assert obj.t_wchar_array == "\x00\x00" + assert obj.t_enum == cs.Enum(0) + assert obj.t_enum_array == [cs.Enum(0), cs.Enum(0)] + assert obj.t_flag == cs.Flag(0) + assert obj.t_flag_array == [cs.Flag(0), cs.Flag(0)] + assert obj.t_pointer == 0 + assert isinstance(obj.t_pointer, Pointer) + assert obj.t_pointer_array == [0, 0] + assert isinstance(obj.t_pointer_array[0], Pointer) + assert isinstance(obj.t_pointer_array[1], Pointer) + + assert obj.dumps() == b"\x00" * 57 + + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + + assert cs.test_nested() == cs.test_nested.default() + + obj = cs.test_nested.default() + assert obj.t_struct == cs.test.default() + assert obj.t_struct_array == [cs.test.default(), cs.test.default()] + + assert obj.dumps() == b"\x00" * 171 + + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + + +def test_structure_default_dynamic(cs: cstruct, compiled: bool) -> None: + cdef = """ + enum Enum { + a = 0, + b = 1 + }; + + flag Flag { + a = 0, + b = 1 + }; + + struct test { + uint8 x; + uint32 t_int_array_n[]; + uint32 t_int_array_d[x]; + uint24 t_bytesint_array_n[]; + uint24 t_bytesint_array_d[x]; + char t_char_array_n[]; + char t_char_array_d[x]; + wchar t_wchar_array_n[]; + wchar t_wchar_array_d[x]; + Enum t_enum_array_n[]; + Enum t_enum_array_d[x]; + Flag t_flag_array_n[]; + Flag t_flag_array_d[x]; + uint8 *t_pointer_n[]; + uint8 *t_pointer_d[x]; + }; + + struct test_nested { + uint8 x; + test t_struct_n[]; + test t_struct_array_d[x]; + }; + """ + cs.pointer = cs.uint8 + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + assert cs.test() == cs.test.default() + + obj = cs.test() + assert obj.t_int_array_n == obj.t_int_array_d == [] + assert obj.t_bytesint_array_n == obj.t_bytesint_array_d == [] + assert obj.t_char_array_n == obj.t_char_array_d == b"" + assert obj.t_wchar_array_n == obj.t_wchar_array_d == "" + assert obj.t_enum_array_n == obj.t_enum_array_d == [] + assert obj.t_flag_array_n == obj.t_flag_array_d == [] + assert obj.t_pointer_n == obj.t_pointer_d == [] + + assert obj.dumps() == b"\x00" * 20 + + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + + assert cs.test_nested() == cs.test_nested.default() + + obj = cs.test_nested.default() + assert obj.t_struct_n == obj.t_struct_array_d == [] + + assert obj.dumps() == b"\x00" * 21 + + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + + +def test_structure_partial_initialization(cs: cstruct) -> None: + cdef = """ + struct test { + uint8 a; + uint8 b; + }; + """ + cs.load(cdef) + + obj = cs.test() + assert obj.a == 0 + assert obj.b == 0 + assert str(obj) == "" + + obj = cs.test(1, 1) + assert obj.a == 1 + assert obj.b == 1 + assert str(obj) == "" + + obj = cs.test(1) + assert obj.a == 1 + assert obj.b == 0 + assert str(obj) == "" + + obj = cs.test(a=1) + assert obj.a == 1 + assert obj.b == 0 + assert str(obj) == "" + + obj = cs.test(b=1) + assert obj.a == 0 + assert obj.b == 1 + assert str(obj) == "" + + +def test_codegen_make_init() -> None: + _make__init__ = structure._make_structure__init__.__wrapped__.__wrapped__ + + result = _make__init__([f"_{n}" for n in range(5)]) + expected = """ + def __init__(self, _0 = None, _1 = None, _2 = None, _3 = None, _4 = None): + self._0 = _0 if _0 is not None else 0 + self._1 = _1 if _1 is not None else 1 + self._2 = _2 if _2 is not None else 2 + self._3 = _3 if _3 is not None else 3 + self._4 = _4 if _4 is not None else 4 + """ + assert result == dedent(expected[1:].rstrip()) + + structure._make_structure__init__.cache_clear() + assert structure._make_structure__init__.cache_info() == (0, 0, 128, 0) + result = structure._make_structure__init__(5) + assert structure._make_structure__init__.cache_info() == (0, 1, 128, 1) + cached = structure._make_structure__init__(5) + assert structure._make_structure__init__.cache_info() == (1, 1, 128, 1) + assert result is cached diff --git a/tests/test_types_union.py b/tests/test_types_union.py index 44a4957..0caf59a 100644 --- a/tests/test_types_union.py +++ b/tests/test_types_union.py @@ -3,8 +3,8 @@ import pytest from dissect.cstruct.cstruct import cstruct -from dissect.cstruct.types.base import Array -from dissect.cstruct.types.structure import Field, Union +from dissect.cstruct.types.base import Array, BaseType +from dissect.cstruct.types.structure import Field, Union, UnionProxy from .utils import verify_compiled @@ -33,12 +33,17 @@ def test_union(TestUnion: type[Union]) -> None: obj = TestUnion(1, 2) assert isinstance(obj, TestUnion) assert obj.a == 1 - assert obj.b == 2 + assert obj.b == 1 assert len(obj) == 4 obj = TestUnion(a=1) assert obj.a == 1 - assert obj.b is None + assert obj.b == 1 + assert len(obj) == 4 + + obj = TestUnion(b=1) + assert obj.a == 1 + assert obj.b == 1 assert len(obj) == 4 assert hash((obj.a, obj.b)) == hash(obj) @@ -52,6 +57,20 @@ def test_union_read(TestUnion: type[Union]) -> None: assert obj.b == 1 +def test_union_read_offset(cs: cstruct, TestUnion: type[Union]) -> None: + TestUnion.add_field("c", cs.uint8, offset=3) + + obj = TestUnion(b"\x01\x00\x00\x02") + assert obj.a == 0x02000001 + assert obj.b == 0x0001 + assert obj.c == 0x02 + + obj = TestUnion(1) + assert obj.a == 1 + assert obj.b == 1 + assert obj.c == 0 + + def test_union_write(TestUnion: type[Union]) -> None: buf = b"\x01\x00\x00\x00" obj = TestUnion(buf) @@ -68,6 +87,35 @@ def test_union_write(TestUnion: type[Union]) -> None: obj = TestUnion() assert obj.dumps() == b"\x00\x00\x00\x00" + obj = TestUnion(5, 1) + assert obj.dumps() == b"\x05\x00\x00\x00" + obj.a = None + assert obj.a == 0 + assert obj.dumps() == b"\x00\x00\x00\x00" + + obj = TestUnion(None, 1) + assert obj.a == 0 + assert obj.dumps() == b"\x00\x00\x00\x00" + + obj = TestUnion(1, None) + assert obj.b == 1 + assert obj.dumps() == b"\x01\x00\x00\x00" + + +def test_union_write_anonymous(cs: cstruct) -> None: + cdef = """ + union test { + struct { + uint32 a; + }; + }; + """ + cs.load(cdef) + + obj = cs.test(b"\x01\x00\x00\x00") + assert obj.a == 1 + assert obj.dumps() == b"\x01\x00\x00\x00" + def test_union_array_read(TestUnion: type[Union]) -> None: TestUnionArray = TestUnion[2] @@ -155,7 +203,7 @@ def test_union_cmp(TestUnion: type[Union]) -> None: def test_union_repr(TestUnion: type[Union]) -> None: obj = TestUnion(1, 2) - assert repr(obj) == f"<{TestUnion.__name__} a=0x1 b=0x2>" + assert repr(obj) == f"<{TestUnion.__name__} a=0x1 b=0x1>" def test_union_eof(TestUnion: type[Union]) -> None: @@ -214,6 +262,11 @@ def test_union_definition_nested(cs: cstruct, compiled: bool) -> None: buf = b"zomgholybeef" obj = cs.test(buf) + assert isinstance(obj.c.a, UnionProxy) + assert len(obj.c.a) == 8 + assert bytes(obj.c.a) == b"holybeef" + assert repr(obj.c.a) == "" + assert obj.magic == b"zomg" assert obj.c.a.a == 0x796C6F68 assert obj.c.a.b == 0x66656562 @@ -273,6 +326,9 @@ def test_union_definition_dynamic(cs: cstruct) -> None: assert obj.a.data == b"aaaaaaaaa" assert obj.b == 0x6161616161616109 + with pytest.raises(NotImplementedError, match="Writing dynamic unions is not yet supported"): + obj.dumps() + def test_union_update(cs: cstruct) -> None: cdef = """ @@ -317,6 +373,10 @@ def test_union_nested_update(cs: cstruct) -> None: assert obj.c.a.b == 0x48474645 assert obj.dumps() == b"1337ABCDEFGH" + obj.c.b.b = b"AAAAAAAA" + assert obj.c.a.a == 0x41414141 + assert obj.dumps() == b"1337AAAAAAAA" + def test_union_anonymous_update(cs: cstruct) -> None: cdef = """ @@ -338,3 +398,135 @@ def test_union_anonymous_update(cs: cstruct) -> None: obj = cs.test() obj.a = 0x41414141 assert obj.b == b"AAA" + + +def test_union_default(cs: cstruct) -> None: + cdef = """ + union test { + uint32 a; + char b[8]; + }; + + struct test_nested { + test t_union; + test t_union_array[2]; + }; + """ + cs.load(cdef) + + assert cs.test() == cs.test.default() + + obj = cs.test() + assert obj.a == 0 + assert obj.b == b"\x00\x00\x00\x00\x00\x00\x00\x00" + + assert obj.dumps() == b"\x00" * 8 + + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + + assert cs.test_nested() == cs.test_nested.default() + + obj = cs.test_nested.default() + assert obj.t_union == cs.test.default() + assert obj.t_union_array == [cs.test.default(), cs.test.default()] + + assert obj.dumps() == b"\x00" * 24 + + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + + +def test_union_default_dynamic(cs: cstruct) -> None: + """initialization of a dynamic union is not yet supported""" + cdef = """ + union test { + uint8 x; + char b_n[]; + char b_d[x]; + }; + + struct test_nested { + uint8 x; + test t_union_n[]; + test t_union_d[x]; + }; + """ + cs.load(cdef) + + obj = cs.test() + assert obj.x == 0 + assert obj.b_n == b"" + assert obj.b_d == b"" + + obj = cs.test_nested() + assert obj.x == 0 + assert obj.t_union_n == [] + assert obj.t_union_d == [] + + +def test_union_partial_initialization(cs: cstruct) -> None: + """partial initialization of a union should fill in the rest with appropriate values""" + cdef = """ + union test { + uint8 a; + uint8 b; + }; + """ + cs.load(cdef) + + obj = cs.test() + assert obj.a == 0 + assert obj.b == 0 + assert str(obj) == "" + + obj = cs.test(1, 1) + assert obj.a == 1 + assert obj.b == 1 + assert str(obj) == "" + + obj = cs.test(1) + assert obj.a == 1 + assert obj.b == 1 + assert str(obj) == "" + + obj = cs.test(a=1) + assert obj.a == 1 + assert obj.b == 1 + assert str(obj) == "" + + obj = cs.test(b=1) + assert obj.a == 1 + assert obj.b == 1 + assert str(obj) == "" + + obj = cs.test(a=1, b=2) + assert obj.a == 1 + assert obj.b == 1 + assert str(obj) == "" + + obj = cs.test(b=2, a=1) + assert obj.a == 2 + assert obj.b == 2 + assert str(obj) == "" + + +def test_union_partial_initialization_dynamic(cs: cstruct) -> None: + """partial initialization of a dynamic union should fill in the rest with appropriate values""" + cdef = """ + union test { + uint8 x; + char b_n[]; + char b_d[x]; + }; + """ + cs.load(cdef) + + # Default initialization should already work + cs.test() + + with pytest.raises(NotImplementedError, match="Initializing a dynamic union is not yet supported"): + cs.test(1) + + with pytest.raises(NotImplementedError, match="Initializing a dynamic union is not yet supported"): + cs.test(x=1) diff --git a/tests/test_types_void.py b/tests/test_types_void.py index 6393f9b..6e82585 100644 --- a/tests/test_types_void.py +++ b/tests/test_types_void.py @@ -2,12 +2,64 @@ from dissect.cstruct.cstruct import cstruct +from .utils import verify_compiled -def test_void(cs: cstruct) -> None: + +def test_void_read(cs: cstruct) -> None: assert not cs.void stream = io.BytesIO(b"AAAA") assert not cs.void(stream) assert stream.tell() == 0 + + +def test_void_write(cs: cstruct) -> None: assert cs.void().dumps() == b"" + + +def test_void_array_read(cs: cstruct) -> None: + assert not cs.void[4] + + stream = io.BytesIO(b"AAAA") + assert not any(cs.void[4](stream)) + assert not any(cs.void[None](stream)) + assert stream.tell() == 0 + + +def test_void_array_write(cs: cstruct) -> None: + assert cs.void[4](b"AAAA").dumps() == b"" + assert cs.void[None](b"AAAA").dumps() == b"" + + +def test_void_default(cs: cstruct) -> None: + assert cs.void() == cs.void.default() + assert not cs.void() + assert not cs.void.default() + + assert cs.void[1].default() == [cs.void()] + assert cs.void[None].default() == [] + + +def test_void_struct(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + void a; + void b[4]; + void c[]; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + stream = io.BytesIO(b"AAAA") + + obj = cs.test(stream) + assert not obj.a + assert not any(obj.b) + assert not any(obj.c) + + assert stream.tell() == 0 + + assert obj.dumps() == b"" diff --git a/tests/test_types_wchar.py b/tests/test_types_wchar.py index a8086d6..65a63d1 100644 --- a/tests/test_types_wchar.py +++ b/tests/test_types_wchar.py @@ -75,3 +75,9 @@ def test_wchar_eof(cs: cstruct) -> None: cs.wchar[None](b"A\x00A\x00A\x00A\x00") assert cs.wchar[0](b"") == "" + + +def test_wchar_default(cs: cstruct) -> None: + assert cs.wchar.default() == "\x00" + assert cs.wchar[4].default() == "\x00\x00\x00\x00" + assert cs.wchar[None].default() == ""