Skip to content

Commit

Permalink
Adding DtypeList class.
Browse files Browse the repository at this point in the history
Represents a sequence of Dtypes.
  • Loading branch information
scott-griffiths committed Oct 20, 2024
1 parent 35f463d commit 099c46f
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 5 deletions.
4 changes: 2 additions & 2 deletions bitformat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from ._bits import Bits
from ._array import Array, BitsProxy
from ._dtypes import DtypeDefinition, Register, Dtype
from ._dtypes import DtypeDefinition, Register, Dtype, DtypeList
from ._fieldtype import FieldType
from ._field import Field
from ._format import Format
Expand Down Expand Up @@ -111,7 +111,7 @@ def bool_bits2chars(_: Literal[1]):
Register().add_dtype(dt, aliases.get(dt.name, None))


__all__ = ('Bits', 'Dtype', 'Format', 'FieldType', 'Field', 'Array', 'BitsProxy', 'Expression', 'Options',
__all__ = ('Bits', 'Dtype', 'DtypeList', 'Format', 'FieldType', 'Field', 'Array', 'BitsProxy', 'Expression', 'Options',
'Register', 'Endianness', 'If', 'Pass')

# Set the __module__ of each of the types in __all__ to 'bitformat' so that they appear as bitformat.Bits instead of bitformat._bits.Bits etc.
Expand Down
82 changes: 80 additions & 2 deletions bitformat/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import functools
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, Sequence
import inspect
import bitformat
from bitformat import _utils
from ._common import Expression, Endianness, byteorder


__all__ = ['Dtype', 'DtypeDefinition', 'Register']
__all__ = ['Dtype', 'DtypeList', 'DtypeDefinition', 'Register']

CACHE_SIZE = 256

Expand Down Expand Up @@ -272,6 +272,8 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}('[{self._name}{self._endianness.value}{size_str};{items_str}]')"

def __eq__(self, other: Any) -> bool:
if isinstance(other, str):
other = Dtype.from_string(other)
if isinstance(other, Dtype):
return (self._name == other._name and
self._size == other._size and
Expand Down Expand Up @@ -562,3 +564,79 @@ def __str__(self) -> str:
hide_items = self.base_dtype.items == 0 and self.items_expression is None
items_str = '' if hide_items else (self.items_expression if self.items_expression else str(self.base_dtype.items))
return f"[{self.base_dtype.name}{self.base_dtype.endianness.value}{size_str}; {items_str}]"


class DtypeList:
    """An immutable, ordered sequence of :class:`Dtype` objects.

    A DtypeList describes several consecutive interpretations of binary data. Instances are
    often created implicitly elsewhere via a comma-separated token string.

    >>> a = DtypeList('u12, u8, bool')
    >>> b = DtypeList.from_params(['u12', 'u8', 'bool'])
    """

    def __new__(cls, s: str) -> DtypeList:
        """Create a DtypeList from a comma-separated string; equivalent to :meth:`from_string`."""
        return cls.from_string(s)

    @classmethod
    def from_params(cls, dtypes: Sequence[Dtype | str]) -> DtypeList:
        """Create a DtypeList from a sequence of Dtype objects and/or dtype strings.

        :param dtypes: The items to store. Anything that is not already a Dtype is converted
            with ``Dtype.from_string``.
        :return: A new DtypeList holding the converted dtypes in the given order.
        """
        # Bypass our own __new__, which only accepts a string.
        x = super().__new__(cls)
        x._dtypes = [dtype if isinstance(dtype, Dtype) else Dtype.from_string(dtype) for dtype in dtypes]
        # Cached once at construction time, as the stored dtypes are never changed afterwards.
        x._bitlength = sum(dtype.bitlength for dtype in x._dtypes)
        return x

    @classmethod
    def from_string(cls, s: str, /) -> DtypeList:
        """Create a DtypeList from a comma-separated string of dtype tokens.

        Each token has surrounding whitespace removed and is parsed with ``Dtype.from_string``.

        :param s: The string to parse, for example ``'u12, u8, bool'``.
        :return: A new DtypeList.
        """
        tokens = [t.strip() for t in s.split(',')]
        dtypes = [Dtype.from_string(token) for token in tokens]
        return cls.from_params(dtypes)

    def pack(self, values: Sequence[Any]) -> bitformat.Bits:
        """Pack one value per dtype and join the results into a single Bits.

        :param values: The values to pack; there must be exactly ``len(self)`` of them.
        :return: The joined result of ``dtype.pack(value)`` for each dtype/value pair, in order.
        :raises ValueError: If the number of values differs from the number of dtypes.
        """
        if len(values) != len(self):
            raise ValueError(f"Expected {len(self)} values, but got {len(values)}.")
        return bitformat.Bits.join(dtype.pack(value) for dtype, value in zip(self._dtypes, values))

    def unpack(self, b: bitformat.Bits | str | Iterable[Any] | bytearray | bytes | memoryview, /) -> list[Any | tuple[Any]]:
        """Unpack a Bits to find its values, one per dtype.

        The b parameter should be a Bits of the appropriate length, or an object that can be
        converted to a Bits.

        :param b: The data to interpret; it is converted with ``Bits.from_auto``.
        :return: A list with the unpacked value for each dtype, in order.
        :raises ValueError: If the total bitlength is non-zero and differs from the length of b.
        """
        b = bitformat.Bits.from_auto(b)
        # NOTE(review): a total bitlength of 0 skips the length check. Presumably that denotes
        # dtypes without a fixed size - confirm against Dtype.unpack.
        if self.bitlength not in (0, len(b)):
            raise ValueError(f"{self!r} is {self.bitlength} bits long, but got {len(b)} bits to unpack.")
        vals = []
        pos = 0
        for dtype in self._dtypes:
            vals.append(dtype.unpack(b[pos:pos + dtype.bitlength]))
            pos += dtype.bitlength
        return vals

    def _getbitlength(self) -> int:
        return self._bitlength

    bitlength = property(_getbitlength, doc="The total length of all the dtypes in bits.")

    def __len__(self) -> int:
        """Return the number of dtypes held."""
        return len(self._dtypes)

    def __eq__(self, other: Any) -> bool:
        """Return True if other is a DtypeList holding equal dtypes in the same order."""
        if isinstance(other, DtypeList):
            return self._dtypes == other._dtypes
        # Let Python try the reflected comparison; '==' still ends up False for unrelated types.
        return NotImplemented

    def __hash__(self) -> int:
        """Return a hash consistent with __eq__.

        Defining __eq__ alone would set __hash__ to None and make instances unhashable.
        """
        # NOTE(review): relies on the stored Dtype objects being hashable - confirm.
        return hash(tuple(self._dtypes))

    def __getitem__(self, key: int | slice) -> Dtype | list[Dtype]:
        """Return the dtype at an integer index, or a new list of dtypes for a slice."""
        return self._dtypes[key]

    def __iter__(self):
        """Iterate over the dtypes in order."""
        return iter(self._dtypes)

    def __str__(self) -> str:
        """Return the dtypes' string forms joined with ', '."""
        return ', '.join(str(dtype) for dtype in self._dtypes)

    def __repr__(self) -> str:
        """Return a string of the form ``DtypeList('<str(self)>')``."""
        return f"{self.__class__.__name__}('{str(self)}')"
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The others in this section build upon them to provide more specialised structure

* :ref:`Bits <bits>` -- An immutable container for storing binary data.
* :ref:`Dtype <dtype>` -- A data type used to interpret binary data.
* :ref:`DtypeList <dtypelist>` -- A sequence of Dtypes.
* :ref:`Array <array>` -- A mutable container for contiguously allocated objects with the same `Dtype`.
* :ref:`Field <field>` -- Represents an optionally named, well-defined amount of binary data with a single data type.
* :ref:`Format <format>` -- A sequence of :class:`FieldType` objects, such as :class:`Field` or :class:`Format` instances.
Expand Down
14 changes: 14 additions & 0 deletions doc/dtypelist.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. currentmodule:: bitformat
.. _dtypelist:

DtypeList
=========

The ``DtypeList`` class represents a sequence of :class:`Dtype` objects.

----

.. autoclass:: bitformat.DtypeList
:members:
:undoc-members:
:member-order: groupwise
29 changes: 28 additions & 1 deletion tests/test_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import pytest
import sys
from bitformat import Dtype, Bits, Endianness
from bitformat import Dtype, Bits, Endianness, DtypeList
from bitformat._dtypes import DtypeDefinition, Register

sys.path.insert(0, '..')
Expand Down Expand Up @@ -191,3 +191,30 @@ def test_endianness_errors():
_ = Dtype.from_params('bool', endianness=Endianness.LITTLE)
with pytest.raises(ValueError):
_ = Dtype.from_params('bytes', 16, endianness=Endianness.LITTLE)


def test_dtype_list_creation():
    """Check construction from a string and from params, plus length, bitlength and equality."""
    original = DtypeList('u8, u16, u32, bool')
    assert len(original) == 4
    assert original.bitlength == 8 + 16 + 32 + 1

    # Building from an existing DtypeList gives an equal object.
    clone = DtypeList.from_params(original)
    assert original == clone

    # Swap the first dtype for a string token and keep the rest via slicing.
    modified = DtypeList.from_params(['i5', *original[1:]])
    assert modified[0] == 'i5'
    assert modified.bitlength == 5 + 16 + 32 + 1
    assert modified != clone

def test_dtype_list_packing():
    """Check that pack joins one value per dtype and rejects a wrong number of values."""
    dtypes = DtypeList('bool, u8, f16')
    packed = dtypes.pack([1, 254, 0.5])
    assert packed == '0b1, 0xfe, 0x3800'
    # One value too many, then one too few.
    for wrong_count in ([0, 0, 0, 0], [0, 0]):
        with pytest.raises(ValueError):
            _ = dtypes.pack(wrong_count)

def test_dtype_list_unpacking():
    """Check that unpack returns one value per dtype, in order."""
    dtypes = DtypeList('bool, u8, f16')
    unpacked = dtypes.unpack('0b1, 0xfe, 0x3800')
    assert unpacked == [1, 254, 0.5]

0 comments on commit 099c46f

Please sign in to comment.