Skip to content

Commit

Permalink
Fix default initialization for various types
Browse files Browse the repository at this point in the history
  • Loading branch information
Schamper committed Oct 22, 2024
1 parent 09830d2 commit 6b45db8
Show file tree
Hide file tree
Showing 19 changed files with 774 additions and 211 deletions.
26 changes: 17 additions & 9 deletions dissect/cstruct/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def __call__(cls, *args, **kwargs) -> MetaType | BaseType:
if len(args) == 1 and not isinstance(args[0], cls):
stream = args[0]

if hasattr(stream, "read"):
if _is_readable_type(stream):
return cls._read(stream)

if issubclass(cls, bytes) and isinstance(stream, bytes) and len(stream) == cls.size:
# Shortcut for char/bytes type
return type.__call__(cls, *args, **kwargs)

if isinstance(stream, (bytes, memoryview, bytearray)):
if _is_buffer_type(stream):
return cls.reads(stream)

return type.__call__(cls, *args, **kwargs)
Expand Down Expand Up @@ -83,7 +83,7 @@ def read(cls, obj: BinaryIO | bytes) -> BaseType:
Returns:
The parsed value of this type.
"""
if isinstance(obj, (bytes, memoryview, bytearray)):
if _is_buffer_type(obj):
return cls.reads(obj)

return cls._read(obj)
Expand Down Expand Up @@ -179,7 +179,7 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
stream: The stream to read from.
array: The array to write.
"""
return cls._write_array(stream, array + [cls()])
return cls._write_array(stream, array + [cls.default()])


class _overload:
Expand Down Expand Up @@ -225,6 +225,11 @@ class ArrayMetaType(MetaType):
num_entries: int | Expression | None
null_terminated: bool

def default(cls) -> BaseType:
return type.__call__(
cls, [cls.type.default() for _ in range(cls.num_entries if isinstance(cls.num_entries, int) else 0)]
)

def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:
if cls.null_terminated:
return cls.type._read_0(stream, context)
Expand All @@ -243,11 +248,6 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:

return cls.type._read_array(stream, num, context)

def default(cls) -> BaseType:
return type.__call__(
cls, [cls.type.default() for _ in range(0 if cls.dynamic or cls.null_terminated else cls.num_entries)]
)


class Array(list, BaseType, metaclass=ArrayMetaType):
"""Implements a fixed or dynamically sized array type.
Expand Down Expand Up @@ -275,5 +275,13 @@ def _write(cls, stream: BinaryIO, data: list[Any]) -> int:
return cls.type._write_array(stream, data)


def _is_readable_type(value: Any) -> bool:
return hasattr(value, "read")


def _is_buffer_type(value: Any) -> bool:
return isinstance(value, (bytes, memoryview, bytearray))


# As mentioned in the BaseType class, we correctly set the type here
MetaType.ArrayType = Array
4 changes: 2 additions & 2 deletions dissect/cstruct/types/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __call__(
) -> EnumMetaType:
if name is None:
if value is None:
value = cls.type()
value = cls.type.default()

if not isinstance(value, int):
# value is a parsable value
Expand Down Expand Up @@ -82,7 +82,7 @@ def _write_array(cls, stream: BinaryIO, array: list[Enum]) -> int:

def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
data = [entry.value if isinstance(entry, Enum) else entry for entry in array]
return cls._write_array(stream, data + [cls.type()])
return cls._write_array(stream, data + [cls.type.default()])


def _fix_alias_members(cls: type[Enum]) -> None:
Expand Down
8 changes: 6 additions & 2 deletions dissect/cstruct/types/pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Pointer(int, BaseType):
_context: dict[str, Any]
_value: BaseType

def __new__(cls, value: int, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer:
def __new__(cls, value: int, stream: BinaryIO | None, context: dict[str, Any] = None) -> Pointer:
obj = super().__new__(cls, value)
obj._stream = stream
obj._context = context
Expand Down Expand Up @@ -65,6 +65,10 @@ def __xor__(self, other: int) -> Pointer:
def __or__(self, other: int) -> Pointer:
return type.__call__(self.__class__, int.__or__(self, other), self._stream, self._context)

@classmethod
def default(cls) -> Pointer:
return cls.__new__(cls, cls.cs.pointer.default(), None, None)

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer:
return cls.__new__(cls, cls.cs.pointer._read(stream, context), stream, context)
Expand All @@ -74,7 +78,7 @@ def _write(cls, stream: BinaryIO, data: int) -> int:
return cls.cs.pointer._write(stream, data)

def dereference(self) -> Any:
if self == 0:
if self == 0 or self._stream is None:
raise NullPointerDereference()

if self._value is None and not issubclass(self.type, Void):
Expand Down
Loading

0 comments on commit 6b45db8

Please sign in to comment.