Skip to content

Commit

Permalink
Fix default initialization for various types (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
Schamper authored Oct 24, 2024
1 parent 09830d2 commit b5ce35d
Show file tree
Hide file tree
Showing 25 changed files with 880 additions and 252 deletions.
2 changes: 1 addition & 1 deletion dissect/cstruct/cstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _make_type(
size: int | None,
*,
alignment: int | None = None,
attrs: dict[str, Any] = None,
attrs: dict[str, Any] | None = None,
) -> type[BaseType]:
"""Create a new type class bound to this cstruct instance.
Expand Down
34 changes: 20 additions & 14 deletions dissect/cstruct/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def __call__(cls, *args, **kwargs) -> MetaType | BaseType:
if len(args) == 1 and not isinstance(args[0], cls):
stream = args[0]

if hasattr(stream, "read"):
if _is_readable_type(stream):
return cls._read(stream)

if issubclass(cls, bytes) and isinstance(stream, bytes) and len(stream) == cls.size:
# Shortcut for char/bytes type
return type.__call__(cls, *args, **kwargs)

if isinstance(stream, (bytes, memoryview, bytearray)):
if _is_buffer_type(stream):
return cls.reads(stream)

return type.__call__(cls, *args, **kwargs)
Expand Down Expand Up @@ -83,7 +83,7 @@ def read(cls, obj: BinaryIO | bytes) -> BaseType:
Returns:
The parsed value of this type.
"""
if isinstance(obj, (bytes, memoryview, bytearray)):
if _is_buffer_type(obj):
return cls.reads(obj)

return cls._read(obj)
Expand Down Expand Up @@ -113,7 +113,7 @@ def dumps(cls, value: Any) -> bytes:
cls._write(out, value)
return out.getvalue()

def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> BaseType:
"""Internal function for reading value.
Must be implemented per type.
Expand All @@ -124,7 +124,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType:
"""
raise NotImplementedError()

def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[BaseType]:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[BaseType]:
"""Internal function for reading array values.
Allows type implementations to do optimized reading for their type.
Expand All @@ -145,7 +145,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non

return [cls._read(stream, context) for _ in range(count)]

def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[BaseType]:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]:
"""Internal function for reading null-terminated data.
"Null" is type specific, so must be implemented per type.
Expand Down Expand Up @@ -179,7 +179,7 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
stream: The stream to read from.
array: The array to write.
"""
return cls._write_array(stream, array + [cls()])
return cls._write_array(stream, array + [cls.default()])


class _overload:
Expand Down Expand Up @@ -225,7 +225,10 @@ class ArrayMetaType(MetaType):
num_entries: int | Expression | None
null_terminated: bool

def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:
def default(cls) -> BaseType:
return type.__call__(cls, [cls.type.default()] * (cls.num_entries if isinstance(cls.num_entries, int) else 0))

def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array:
if cls.null_terminated:
return cls.type._read_0(stream, context)

Expand All @@ -243,11 +246,6 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:

return cls.type._read_array(stream, num, context)

def default(cls) -> BaseType:
return type.__call__(
cls, [cls.type.default() for _ in range(0 if cls.dynamic or cls.null_terminated else cls.num_entries)]
)


class Array(list, BaseType, metaclass=ArrayMetaType):
"""Implements a fixed or dynamically sized array type.
Expand All @@ -261,7 +259,7 @@ class Array(list, BaseType, metaclass=ArrayMetaType):
"""

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array:
return cls(ArrayMetaType._read(cls, stream, context))

@classmethod
Expand All @@ -275,5 +273,13 @@ def _write(cls, stream: BinaryIO, data: list[Any]) -> int:
return cls.type._write_array(stream, data)


def _is_readable_type(value: Any) -> bool:
return hasattr(value, "read")


def _is_buffer_type(value: Any) -> bool:
return isinstance(value, (bytes, memoryview, bytearray))


# As mentioned in the BaseType class, we correctly set the type here
MetaType.ArrayType = Array
8 changes: 4 additions & 4 deletions dissect/cstruct/types/char.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CharArray(bytes, BaseType, metaclass=ArrayMetaType):
"""Character array type for reading and writing byte strings."""

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> CharArray:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> CharArray:
return type.__call__(cls, ArrayMetaType._read(cls, stream, context))

@classmethod
Expand All @@ -35,11 +35,11 @@ class Char(bytes, BaseType):
ArrayType = CharArray

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char:
return cls._read_array(stream, 1, context)

@classmethod
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Char:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Char:
if count == 0:
return type.__call__(cls, b"")

Expand All @@ -50,7 +50,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non
return type.__call__(cls, data)

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char:
buf = []
while True:
byte = stream.read(1)
Expand Down
10 changes: 5 additions & 5 deletions dissect/cstruct/types/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __call__(
) -> EnumMetaType:
if name is None:
if value is None:
value = cls.type()
value = cls.type.default()

if not isinstance(value, int):
# value is a parsable value
Expand Down Expand Up @@ -64,13 +64,13 @@ def __contains__(cls, value: Any) -> bool:
return True
return value in cls._value2member_map_

def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Enum:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Enum:
return cls(cls.type._read(stream, context))

def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Enum]:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Enum]:
return list(map(cls, cls.type._read_array(stream, count, context)))

def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[Enum]:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Enum]:
return list(map(cls, cls.type._read_0(stream, context)))

def _write(cls, stream: BinaryIO, data: Enum) -> int:
Expand All @@ -82,7 +82,7 @@ def _write_array(cls, stream: BinaryIO, array: list[Enum]) -> int:

def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
data = [entry.value if isinstance(entry, Enum) else entry for entry in array]
return cls._write_array(stream, data + [cls.type()])
return cls._write_array(stream, data + [cls.type.default()])


def _fix_alias_members(cls: type[Enum]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions dissect/cstruct/types/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Int(int, BaseType):
signed: bool

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int:
data = stream.read(cls.size)

if len(data) != cls.size:
Expand All @@ -21,7 +21,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int:
return cls.from_bytes(data, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed)

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int:
result = []

while True:
Expand Down
4 changes: 2 additions & 2 deletions dissect/cstruct/types/leb128.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LEB128(int, BaseType):
signed: bool

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> LEB128:
result = 0
shift = 0
while True:
Expand All @@ -35,7 +35,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128:
return cls.__new__(cls, result)

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> LEB128:
result = []

while True:
Expand Down
6 changes: 3 additions & 3 deletions dissect/cstruct/types/packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ class Packed(BaseType):
packchar: str

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed:
return cls._read_array(stream, 1, context)[0]

@classmethod
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Packed]:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Packed]:
if count == EOF:
data = stream.read()
length = len(data)
Expand All @@ -39,7 +39,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non
return [cls.__new__(cls, value) for value in fmt.unpack(data)]

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed:
result = []

fmt = _struct(cls.cs.endian, cls.packchar)
Expand Down
14 changes: 9 additions & 5 deletions dissect/cstruct/types/pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ class Pointer(int, BaseType):
"""Pointer to some other type."""

type: MetaType
_stream: BinaryIO
_context: dict[str, Any]
_stream: BinaryIO | None
_context: dict[str, Any] | None
_value: BaseType

def __new__(cls, value: int, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer:
def __new__(cls, value: int, stream: BinaryIO | None, context: dict[str, Any] | None = None) -> Pointer:
obj = super().__new__(cls, value)
obj._stream = stream
obj._context = context
Expand Down Expand Up @@ -66,15 +66,19 @@ def __or__(self, other: int) -> Pointer:
return type.__call__(self.__class__, int.__or__(self, other), self._stream, self._context)

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer:
def default(cls) -> Pointer:
return cls.__new__(cls, cls.cs.pointer.default(), None, None)

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Pointer:
return cls.__new__(cls, cls.cs.pointer._read(stream, context), stream, context)

@classmethod
def _write(cls, stream: BinaryIO, data: int) -> int:
return cls.cs.pointer._write(stream, data)

def dereference(self) -> Any:
if self == 0:
if self == 0 or self._stream is None:
raise NullPointerDereference()

if self._value is None and not issubclass(self.type, Void):
Expand Down
Loading

0 comments on commit b5ce35d

Please sign in to comment.