Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix default initialization for various types #104

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dissect/cstruct/cstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _make_type(
size: int | None,
*,
alignment: int | None = None,
attrs: dict[str, Any] = None,
attrs: dict[str, Any] | None = None,
) -> type[BaseType]:
"""Create a new type class bound to this cstruct instance.

Expand Down
34 changes: 20 additions & 14 deletions dissect/cstruct/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def __call__(cls, *args, **kwargs) -> MetaType | BaseType:
if len(args) == 1 and not isinstance(args[0], cls):
stream = args[0]

if hasattr(stream, "read"):
if _is_readable_type(stream):
return cls._read(stream)

if issubclass(cls, bytes) and isinstance(stream, bytes) and len(stream) == cls.size:
# Shortcut for char/bytes type
return type.__call__(cls, *args, **kwargs)

if isinstance(stream, (bytes, memoryview, bytearray)):
if _is_buffer_type(stream):
return cls.reads(stream)

return type.__call__(cls, *args, **kwargs)
Expand Down Expand Up @@ -83,7 +83,7 @@ def read(cls, obj: BinaryIO | bytes) -> BaseType:
Returns:
The parsed value of this type.
"""
if isinstance(obj, (bytes, memoryview, bytearray)):
if _is_buffer_type(obj):
return cls.reads(obj)

return cls._read(obj)
Expand Down Expand Up @@ -113,7 +113,7 @@ def dumps(cls, value: Any) -> bytes:
cls._write(out, value)
return out.getvalue()

def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> BaseType:
"""Internal function for reading value.

Must be implemented per type.
Expand All @@ -124,7 +124,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType:
"""
raise NotImplementedError()

def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[BaseType]:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[BaseType]:
"""Internal function for reading array values.

Allows type implementations to do optimized reading for their type.
Expand All @@ -145,7 +145,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non

return [cls._read(stream, context) for _ in range(count)]

def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[BaseType]:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]:
"""Internal function for reading null-terminated data.

"Null" is type specific, so must be implemented per type.
Expand Down Expand Up @@ -179,7 +179,7 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
stream: The stream to read from.
array: The array to write.
"""
return cls._write_array(stream, array + [cls()])
return cls._write_array(stream, array + [cls.default()])


class _overload:
Expand Down Expand Up @@ -225,7 +225,10 @@ class ArrayMetaType(MetaType):
num_entries: int | Expression | None
null_terminated: bool

def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:
def default(cls) -> BaseType:
return type.__call__(cls, [cls.type.default()] * (cls.num_entries if isinstance(cls.num_entries, int) else 0))

def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array:
if cls.null_terminated:
return cls.type._read_0(stream, context)

Expand All @@ -243,11 +246,6 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:

return cls.type._read_array(stream, num, context)

def default(cls) -> BaseType:
return type.__call__(
cls, [cls.type.default() for _ in range(0 if cls.dynamic or cls.null_terminated else cls.num_entries)]
)


class Array(list, BaseType, metaclass=ArrayMetaType):
"""Implements a fixed or dynamically sized array type.
Expand All @@ -261,7 +259,7 @@ class Array(list, BaseType, metaclass=ArrayMetaType):
"""

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array:
return cls(ArrayMetaType._read(cls, stream, context))

@classmethod
Expand All @@ -275,5 +273,13 @@ def _write(cls, stream: BinaryIO, data: list[Any]) -> int:
return cls.type._write_array(stream, data)


def _is_readable_type(value: Any) -> bool:
return hasattr(value, "read")


def _is_buffer_type(value: Any) -> bool:
return isinstance(value, (bytes, memoryview, bytearray))


# As mentioned in the BaseType class, we correctly set the type here
MetaType.ArrayType = Array
8 changes: 4 additions & 4 deletions dissect/cstruct/types/char.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CharArray(bytes, BaseType, metaclass=ArrayMetaType):
"""Character array type for reading and writing byte strings."""

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> CharArray:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> CharArray:
return type.__call__(cls, ArrayMetaType._read(cls, stream, context))

@classmethod
Expand All @@ -35,11 +35,11 @@ class Char(bytes, BaseType):
ArrayType = CharArray

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char:
return cls._read_array(stream, 1, context)

@classmethod
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Char:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Char:
if count == 0:
return type.__call__(cls, b"")

Expand All @@ -50,7 +50,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non
return type.__call__(cls, data)

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char:
buf = []
while True:
byte = stream.read(1)
Expand Down
10 changes: 5 additions & 5 deletions dissect/cstruct/types/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __call__(
) -> EnumMetaType:
if name is None:
if value is None:
value = cls.type()
value = cls.type.default()

if not isinstance(value, int):
# value is a parsable value
Expand Down Expand Up @@ -64,13 +64,13 @@ def __contains__(cls, value: Any) -> bool:
return True
return value in cls._value2member_map_

def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Enum:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Enum:
return cls(cls.type._read(stream, context))

def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Enum]:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Enum]:
return list(map(cls, cls.type._read_array(stream, count, context)))

def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[Enum]:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Enum]:
return list(map(cls, cls.type._read_0(stream, context)))

def _write(cls, stream: BinaryIO, data: Enum) -> int:
Expand All @@ -82,7 +82,7 @@ def _write_array(cls, stream: BinaryIO, array: list[Enum]) -> int:

def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
data = [entry.value if isinstance(entry, Enum) else entry for entry in array]
return cls._write_array(stream, data + [cls.type()])
return cls._write_array(stream, data + [cls.type.default()])


def _fix_alias_members(cls: type[Enum]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions dissect/cstruct/types/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Int(int, BaseType):
signed: bool

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int:
data = stream.read(cls.size)

if len(data) != cls.size:
Expand All @@ -21,7 +21,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int:
return cls.from_bytes(data, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed)

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int:
result = []

while True:
Expand Down
4 changes: 2 additions & 2 deletions dissect/cstruct/types/leb128.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LEB128(int, BaseType):
signed: bool

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> LEB128:
result = 0
shift = 0
while True:
Expand All @@ -35,7 +35,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128:
return cls.__new__(cls, result)

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> LEB128:
result = []

while True:
Expand Down
6 changes: 3 additions & 3 deletions dissect/cstruct/types/packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ class Packed(BaseType):
packchar: str

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed:
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed:
return cls._read_array(stream, 1, context)[0]

@classmethod
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Packed]:
def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Packed]:
if count == EOF:
data = stream.read()
length = len(data)
Expand All @@ -39,7 +39,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = Non
return [cls.__new__(cls, value) for value in fmt.unpack(data)]

@classmethod
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed:
def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Packed:
result = []

fmt = _struct(cls.cs.endian, cls.packchar)
Expand Down
14 changes: 9 additions & 5 deletions dissect/cstruct/types/pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ class Pointer(int, BaseType):
"""Pointer to some other type."""

type: MetaType
_stream: BinaryIO
_context: dict[str, Any]
_stream: BinaryIO | None
_context: dict[str, Any] | None
_value: BaseType

def __new__(cls, value: int, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer:
def __new__(cls, value: int, stream: BinaryIO | None, context: dict[str, Any] | None = None) -> Pointer:
obj = super().__new__(cls, value)
obj._stream = stream
Schamper marked this conversation as resolved.
Show resolved Hide resolved
obj._context = context
Expand Down Expand Up @@ -66,15 +66,19 @@ def __or__(self, other: int) -> Pointer:
return type.__call__(self.__class__, int.__or__(self, other), self._stream, self._context)

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer:
def default(cls) -> Pointer:
return cls.__new__(cls, cls.cs.pointer.default(), None, None)

@classmethod
def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Pointer:
return cls.__new__(cls, cls.cs.pointer._read(stream, context), stream, context)

@classmethod
def _write(cls, stream: BinaryIO, data: int) -> int:
return cls.cs.pointer._write(stream, data)

def dereference(self) -> Any:
if self == 0:
if self == 0 or self._stream is None:
Schamper marked this conversation as resolved.
Show resolved Hide resolved
raise NullPointerDereference()

if self._value is None and not issubclass(self.type, Void):
Expand Down
Loading