From ac55e29641374744d8746d2e02d1777eb3bc7dc9 Mon Sep 17 00:00:00 2001 From: Vincent Maurin Date: Thu, 27 Nov 2025 17:28:22 +0100 Subject: [PATCH] Simplify flexible versions Flexible versions are a protocol feature of newer versions of the API. When an API is flexible, it uses more compact structures and also allows additional "dynamic" fields that can be added without the need to introduce a new API version. This commit moves the flexible versions support to the protocol layer, so it is more transparent and easier to use when defining Struct classes and schemas. When defining the schema, we can specify a tagged field with a tuple containing the field name and the field tag. --- CHANGES.rst | 8 + aiokafka/conn.py | 6 +- aiokafka/protocol/abstract.py | 4 +- aiokafka/protocol/admin.py | 135 +++------- aiokafka/protocol/api.py | 50 ++-- aiokafka/protocol/message.py | 38 +-- aiokafka/protocol/struct.py | 7 +- aiokafka/protocol/types.py | 319 +++++++++++------------ tests/test_protocol.py | 58 +++-- tests/test_protocol_object_conversion.py | 51 +++- tests/test_requests.py | 17 +- 11 files changed, 331 insertions(+), 362 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 2e67045c5..1eff90a88 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,14 @@ Breaking changes: `api_version` parameter has been removed from the different clients (admin/consumer/producer) (pr #1136 by @vmaurin) +New features: + +* Simplify flexible versions in schema. + Defining API request or response schemas that should support + flexible versions (KIP-482) is now achieved by setting `FLEXIBLE_VERSION` to True. + Tagged fields can be expressed with ("name", tag) instead of just a name. 
+ (pr #1139 by @vmaurin) + Improved Documentation: * Fix incomplete documentation for `AIOKafkaConsumer.offset_for_times`` diff --git a/aiokafka/conn.py b/aiokafka/conn.py index fdf5d8785..da0a9b730 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -428,10 +428,11 @@ def send(self, request, expect_response=True): ) from err log.debug( - "Request to %s:%d %d: %s", + "Request to %s:%d %d: %s, %s", self._host, self._port, correlation_id, + header, request_struct, ) @@ -565,10 +566,11 @@ def _handle_frame(self, resp): if not fut.done(): response = resp_type.decode(resp) log.debug( - "Response from %s:%d %d: %s", + "Response from %s:%d %d: %s, %s", self._host, self._port, correlation_id, + response_header, response, ) fut.set_result(response) diff --git a/aiokafka/protocol/abstract.py b/aiokafka/protocol/abstract.py index c466357e0..4f963710e 100644 --- a/aiokafka/protocol/abstract.py +++ b/aiokafka/protocol/abstract.py @@ -8,11 +8,11 @@ class AbstractType(Generic[T], metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def encode(cls, value: T) -> bytes: ... + def encode(cls, value: T, flexible: bool) -> bytes: ... @classmethod @abc.abstractmethod - def decode(cls, data: BytesIO) -> T: ... + def decode(cls, data: BytesIO, flexible: bool) -> T: ... 
@classmethod def repr(cls, value: T) -> str: diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py index 8a2d6e89c..e48012942 100644 --- a/aiokafka/protocol/admin.py +++ b/aiokafka/protocol/admin.py @@ -8,8 +8,6 @@ Array, Boolean, Bytes, - CompactArray, - CompactString, Float64, Int8, Int16, @@ -17,7 +15,6 @@ Int64, Schema, String, - TaggedFields, ) @@ -1453,53 +1450,48 @@ def build( class AlterPartitionReassignmentsResponse_v0(Response): API_KEY = 45 API_VERSION = 0 + FLEXIBLE_VERSION = True SCHEMA = Schema( ("throttle_time_ms", Int32), ("error_code", Int16), - ("error_message", CompactString("utf-8")), + ("error_message", String("utf-8")), ( "responses", - CompactArray( - ("name", CompactString("utf-8")), + Array( + ("name", String("utf-8")), ( "partitions", - CompactArray( + Array( ("partition_index", Int32), ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("tags", TaggedFields), + ("error_message", String("utf-8")), ), ), - ("tags", TaggedFields), ), ), - ("tags", TaggedFields), ) class AlterPartitionReassignmentsRequest_v0(RequestStruct): - FLEXIBLE_VERSION = True API_KEY = 45 API_VERSION = 0 + FLEXIBLE_VERSION = True RESPONSE_TYPE = AlterPartitionReassignmentsResponse_v0 SCHEMA = Schema( ("timeout_ms", Int32), ( "topics", - CompactArray( - ("name", CompactString("utf-8")), + Array( + ("name", String("utf-8")), ( "partitions", - CompactArray( + Array( ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("tags", TaggedFields), + ("replicas", Array(Int32)), ), ), - ("tags", TaggedFields), ), ), - ("tags", TaggedFields), ) @@ -1516,44 +1508,40 @@ class AlterPartitionReassignmentsRequest( def __init__( self, timeout_ms: int, - topics: list[tuple[str, tuple[int, list[int], TaggedFields], TaggedFields]], - tags: TaggedFields, + topics: list[tuple[str, tuple[int, list[int]]]], ): self._timeout_ms = timeout_ms self._topics = topics - self._tags = tags def build( self, request_struct_class: 
type[AlterPartitionReassignmentsRequestStruct] ) -> AlterPartitionReassignmentsRequestStruct: - return request_struct_class(self._timeout_ms, self._topics, self._tags) + return request_struct_class(self._timeout_ms, self._topics) class ListPartitionReassignmentsResponse_v0(Response): API_KEY = 46 API_VERSION = 0 + FLEXIBLE_VERSION = True SCHEMA = Schema( ("throttle_time_ms", Int32), ("error_code", Int16), - ("error_message", CompactString("utf-8")), + ("error_message", String("utf-8")), ( "topics", - CompactArray( - ("name", CompactString("utf-8")), + Array( + ("name", String("utf-8")), ( "partitions", - CompactArray( + Array( ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("adding_replicas", CompactArray(Int32)), - ("removing_replicas", CompactArray(Int32)), - ("tags", TaggedFields), + ("replicas", Array(Int32)), + ("adding_replicas", Array(Int32)), + ("removing_replicas", Array(Int32)), ), ), - ("tags", TaggedFields), ), ), - ("tags", TaggedFields), ) @@ -1566,13 +1554,11 @@ class ListPartitionReassignmentsRequest_v0(RequestStruct): ("timeout_ms", Int32), ( "topics", - CompactArray( - ("name", CompactString("utf-8")), - ("partition_index", CompactArray(Int32)), - ("tags", TaggedFields), + Array( + ("name", String("utf-8")), + ("partition_index", Array(Int32)), ), ), - ("tags", TaggedFields), ) @@ -1589,17 +1575,15 @@ class ListPartitionReassignmentsRequest( def __init__( self, timeout_ms: int, - topics: list[tuple[str, tuple[int, list[int], TaggedFields], TaggedFields]], - tags: TaggedFields, + topics: list[tuple[str, tuple[int, list[int]]]], ): self._timeout_ms = timeout_ms self._topics = topics - self._tags = tags def build( self, request_struct_class: type[ListPartitionReassignmentsRequestStruct] ) -> ListPartitionReassignmentsRequestStruct: - return request_struct_class(self._timeout_ms, self._topics, self._tags) + return request_struct_class(self._timeout_ms, self._topics) class DeleteRecordsResponse_v0(Response): @@ -1633,26 +1617,8 @@ 
class DeleteRecordsResponse_v1(Response): class DeleteRecordsResponse_v2(Response): API_KEY = 21 API_VERSION = 2 - SCHEMA = Schema( - ("throttle_time_ms", Int32), - ( - "topics", - CompactArray( - ("name", CompactString("utf-8")), - ( - "partitions", - CompactArray( - ("partition_index", Int32), - ("low_watermark", Int64), - ("error_code", Int16), - ("tags", TaggedFields), - ), - ), - ("tags", TaggedFields), - ), - ), - ("tags", TaggedFields), - ) + FLEXIBLE_VERSION = True + SCHEMA = DeleteRecordsResponse_v0.SCHEMA class DeleteRecordsRequest_v0(RequestStruct): @@ -1689,25 +1655,7 @@ class DeleteRecordsRequest_v2(RequestStruct): API_VERSION = 2 FLEXIBLE_VERSION = True RESPONSE_TYPE = DeleteRecordsResponse_v2 - SCHEMA = Schema( - ( - "topics", - CompactArray( - ("name", CompactString("utf-8")), - ( - "partitions", - CompactArray( - ("partition_index", Int32), - ("offset", Int64), - ("tags", TaggedFields), - ), - ), - ("tags", TaggedFields), - ), - ), - ("timeout_ms", Int32), - ("tags", TaggedFields), - ) + SCHEMA = DeleteRecordsRequest_v0.SCHEMA DeleteRecordsRequestStruct: TypeAlias = ( @@ -1722,43 +1670,20 @@ def __init__( self, topics: Iterable[tuple[str, Iterable[tuple[int, int]]]], timeout_ms: int, - tags: dict[int, bytes] | None = None, ) -> None: self._topics = topics self._timeout_ms = timeout_ms - self._tags = tags def build( self, request_struct_class: type[DeleteRecordsRequestStruct] ) -> DeleteRecordsRequestStruct: - if request_struct_class.API_VERSION < 2: - if self._tags is not None: - raise IncompatibleBrokerVersion( - "tags requires DeleteRecordsRequest >= v2" - ) - - return request_struct_class( - [ - ( - topic, - list(partitions), - ) - for (topic, partitions) in self._topics - ], - self._timeout_ms, - ) return request_struct_class( [ ( topic, - [ - (partition, before_offset, {}) - for partition, before_offset in partitions - ], - {}, + list(partitions), ) for (topic, partitions) in self._topics ], self._timeout_ms, - self._tags or {}, ) diff --git 
a/aiokafka/protocol/api.py b/aiokafka/protocol/api.py index 1ac170540..7f29e0fe2 100644 --- a/aiokafka/protocol/api.py +++ b/aiokafka/protocol/api.py @@ -8,10 +8,10 @@ from aiokafka.errors import IncompatibleBrokerVersion from .struct import Struct -from .types import Array, Int16, Int32, Schema, String, TaggedFields +from .types import Array, Int16, Int32, Schema, String -class RequestHeader_v0(Struct): +class RequestHeader_v1(Struct): SCHEMA = Schema( ("api_key", Int16), ("api_version", Int16), @@ -30,25 +30,23 @@ def __init__( ) -class RequestHeader_v1(Struct): - # Flexible response / request headers end in field buffer +class RequestHeader_v2(Struct): SCHEMA = Schema( ("api_key", Int16), ("api_version", Int16), ("correlation_id", Int32), - ("client_id", String("utf-8")), - ("tags", TaggedFields), + ("client_id", String("utf-8", allow_flexible=False)), ) + FLEXIBLE_VERSION = True def __init__( self, request: RequestStruct, correlation_id: int = 0, client_id: str = "aiokafka", - tags: dict[int, bytes] | None = None, ): super().__init__( - request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} + request.API_KEY, request.API_VERSION, correlation_id, client_id ) @@ -61,8 +59,8 @@ class ResponseHeader_v0(Struct): class ResponseHeader_v1(Struct): SCHEMA = Schema( ("correlation_id", Int32), - ("tags", TaggedFields), ) + FLEXIBLE_VERSION = True T = TypeVar("T", bound="RequestStruct") @@ -150,7 +148,7 @@ class RequestStruct(Struct, metaclass=abc.ABCMeta): Attributes ---------- FLEXIBLE_VERSION : bool - Use request header with flexible tags + Support flexible versions/compact format API_KEY : int The unique API key identifying the request. API_VERSION : int @@ -161,11 +159,9 @@ class RequestStruct(Struct, metaclass=abc.ABCMeta): An instance of Schema() representing the request structure. 
""" - FLEXIBLE_VERSION: ClassVar[bool] = False API_KEY: ClassVar[int] API_VERSION: ClassVar[int] RESPONSE_TYPE: ClassVar[type[Response]] - SCHEMA: ClassVar[Schema] def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -185,12 +181,12 @@ def to_object(self) -> dict[str, Any]: def build_request_header( self, correlation_id: int, client_id: str - ) -> RequestHeader_v0 | RequestHeader_v1: + ) -> RequestHeader_v1 | RequestHeader_v2: if self.FLEXIBLE_VERSION: - return RequestHeader_v1( + return RequestHeader_v2( self, correlation_id=correlation_id, client_id=client_id ) - return RequestHeader_v0( + return RequestHeader_v1( self, correlation_id=correlation_id, client_id=client_id ) @@ -203,15 +199,23 @@ def parse_response_header( class Response(Struct, metaclass=abc.ABCMeta): - @property - @abc.abstractmethod - def API_KEY(self) -> int: - """Integer identifier for api request/response""" + """ + Base structure for API responses. - @property - @abc.abstractmethod - def API_VERSION(self) -> int: - """Integer of api request/response version""" + Attributes + ---------- + FLEXIBLE_VERSION : bool + Support flexible versions/compact format + API_KEY : int + The unique API key identifying the response. + API_VERSION : int + Which API version the Response class is. + SCHEMA : Schema + An instance of Schema() representing the response structure. 
+ """ + + API_KEY: ClassVar[int] + API_VERSION: ClassVar[int] def to_object(self) -> dict[str, Any]: return _to_object(self.SCHEMA, self) diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index 67f9d4ed7..f4d7803a9 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -132,11 +132,13 @@ def encode(self, recalc_crc: bool = True) -> bytes: self.timestamp, self.key, self.value, - ) + ), + flexible=False, ) elif version == 0: message = Message.SCHEMAS[version].encode( - (self.crc, self.magic, self.attributes, self.key, self.value) + (self.crc, self.magic, self.attributes, self.key, self.value), + flexible=False, ) else: raise ValueError(f"Unrecognized message version: {version}") @@ -144,7 +146,7 @@ def encode(self, recalc_crc: bool = True) -> bytes: return message self.crc = crc32(message[4:]) crc_field = self.BASE_FIELDS[0][1] - return crc_field.encode(self.crc) + message[4:] + return crc_field.encode(self.crc, flexible=False) + message[4:] @classmethod def decode(cls, data: io.BytesIO | bytes) -> Self: @@ -154,16 +156,16 @@ def decode(cls, data: io.BytesIO | bytes) -> Self: data = io.BytesIO(data) # Partial decode required to determine message version crc, magic, attributes = ( - cls.BASE_FIELDS[0][1].decode(data), - cls.BASE_FIELDS[1][1].decode(data), - cls.BASE_FIELDS[2][1].decode(data), + cls.BASE_FIELDS[0][1].decode(data, flexible=False), + cls.BASE_FIELDS[1][1].decode(data, flexible=False), + cls.BASE_FIELDS[2][1].decode(data, flexible=False), ) if magic == 1: magic = cast(Literal[1], magic) timestamp, key, value = ( - cls.MAGIC1_FIELDS[0][1].decode(data), - cls.MAGIC1_FIELDS[1][1].decode(data), - cls.MAGIC1_FIELDS[2][1].decode(data), + cls.MAGIC1_FIELDS[0][1].decode(data, flexible=False), + cls.MAGIC1_FIELDS[1][1].decode(data, flexible=False), + cls.MAGIC1_FIELDS[2][1].decode(data, flexible=False), ) msg = cls( value=value, @@ -176,8 +178,8 @@ def decode(cls, data: io.BytesIO | bytes) -> Self: elif magic == 0: 
magic = cast(Literal[0], magic) key, value = ( - cls.MAGIC0_FIELDS[0][1].decode(data), - cls.MAGIC0_FIELDS[1][1].decode(data), + cls.MAGIC0_FIELDS[0][1].decode(data, flexible=False), + cls.MAGIC0_FIELDS[1][1].decode(data, flexible=False), ) msg = cls( value=value, @@ -247,7 +249,7 @@ def encode( ) -> bytes: # RecordAccumulator encodes messagesets internally if isinstance(items, io.BytesIO): - size = Int32.decode(items) + size = Int32.decode(items, flexible=False) if prepend_size: # rewind and return all the bytes items.seek(items.tell() - 4) @@ -256,11 +258,11 @@ def encode( encoded_values: list[bytes] = [] for offset, message in items: - encoded_values.append(Int64.encode(offset)) - encoded_values.append(Bytes.encode(message)) + encoded_values.append(Int64.encode(offset, flexible=False)) + encoded_values.append(Bytes.encode(message, flexible=False)) encoded = b"".join(encoded_values) if prepend_size: - return Bytes.encode(encoded) + return Bytes.encode(encoded, flexible=False) else: return encoded @@ -274,7 +276,7 @@ def decode( if isinstance(data, bytes): data = io.BytesIO(data) if bytes_to_read is None: - bytes_to_read = Int32.decode(data) + bytes_to_read = Int32.decode(data, flexible=False) # if FetchRequest max_bytes is smaller than the available message set # the server returns partial data for the final message @@ -284,8 +286,8 @@ def decode( items: list[tuple[int, int, Message] | tuple[None, None, PartialMessage]] = [] try: while bytes_to_read: - offset = Int64.decode(raw) - msg_bytes = Bytes.decode(raw) + offset = Int64.decode(raw, flexible=False) + msg_bytes = Bytes.decode(raw, flexible=False) assert msg_bytes is not None bytes_to_read -= 8 + 4 + len(msg_bytes) items.append( diff --git a/aiokafka/protocol/struct.py b/aiokafka/protocol/struct.py index 38649d09e..5cb6d09ba 100644 --- a/aiokafka/protocol/struct.py +++ b/aiokafka/protocol/struct.py @@ -8,6 +8,7 @@ class Struct: SCHEMA: ClassVar = Schema() + FLEXIBLE_VERSION: ClassVar[bool] = False def 
__init__(self, *args: Any, **kwargs: Any) -> None: if len(args) == len(self.SCHEMA.fields): @@ -26,13 +27,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: ) def encode(self) -> bytes: - return self.SCHEMA.encode([self.__dict__[name] for name in self.SCHEMA.names]) + return self.SCHEMA.encode( + [self.__dict__[name] for name in self.SCHEMA.names], self.FLEXIBLE_VERSION + ) @classmethod def decode(cls, data: BytesIO | bytes) -> Self: if isinstance(data, bytes): data = BytesIO(data) - return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) + return cls(*cls.SCHEMA.decode(data, cls.FLEXIBLE_VERSION)) def get_item(self, name: str) -> Any: if name not in self.SCHEMA.names: diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py index 5a315dd71..48a1a7794 100644 --- a/aiokafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -19,6 +19,10 @@ ValueT: TypeAlias = Union[type[AbstractType[Any]], "String", "Array", "Schema"] +TaggedFieldId: TypeAlias = tuple[str, int] + +FieldId: TypeAlias = TaggedFieldId | str + def _pack(f: Callable[[T], bytes], value: T) -> bytes: try: @@ -47,11 +51,11 @@ class Int8(AbstractType[int]): _unpack = struct.Struct(">b").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(1)) @@ -60,11 +64,11 @@ class Int16(AbstractType[int]): _unpack = struct.Struct(">h").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(2)) @@ -73,11 +77,11 @@ class Int32(AbstractType[int]): _unpack = struct.Struct(">i").unpack 
@classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(4)) @@ -86,11 +90,11 @@ class UInt32(AbstractType[int]): _unpack = struct.Struct(">I").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(4)) @@ -99,11 +103,11 @@ class Int64(AbstractType[int]): _unpack = struct.Struct(">q").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(8)) @@ -112,26 +116,39 @@ class Float64(AbstractType[float]): _unpack = struct.Struct(">d").unpack @classmethod - def encode(cls, value: float) -> bytes: + def encode(cls, value: float, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> float: + def decode(cls, data: BytesIO, flexible: bool = False) -> float: return _unpack(cls._unpack, data.read(8)) class String: - def __init__(self, encoding: str = "utf-8"): + def __init__(self, encoding: str = "utf-8", allow_flexible: bool = True): self.encoding = encoding + self.allow_flexible = allow_flexible - def encode(self, value: str | None) -> bytes: + def encode(self, value: str | None, flexible: bool) -> bytes: if value is None: - return Int16.encode(-1) + return ( + UnsignedVarInt32.encode(0) + if flexible and self.allow_flexible + else Int16.encode(-1, flexible) + ) 
encoded_value = str(value).encode(self.encoding) - return Int16.encode(len(encoded_value)) + encoded_value + return ( + UnsignedVarInt32.encode(len(encoded_value) + 1) + encoded_value + if flexible and self.allow_flexible + else Int16.encode(len(encoded_value), flexible) + encoded_value + ) - def decode(self, data: BytesIO) -> str | None: - length = Int16.decode(data) + def decode(self, data: BytesIO, flexible: bool) -> str | None: + length = ( + UnsignedVarInt32.decode(data) - 1 + if flexible and self.allow_flexible + else Int16.decode(data, flexible) + ) if length < 0: return None value = data.read(length) @@ -146,15 +163,25 @@ def repr(cls, value: str) -> str: class Bytes(AbstractType[bytes | None]): @classmethod - def encode(cls, value: bytes | None) -> bytes: + def encode(cls, value: bytes | None, flexible: bool) -> bytes: if value is None: - return Int32.encode(-1) + return ( + UnsignedVarInt32.encode(0) if flexible else Int32.encode(-1, flexible) + ) else: - return Int32.encode(len(value)) + value + return ( + UnsignedVarInt32.encode(len(value) + 1) + value + if flexible + else Int32.encode(len(value), flexible) + value + ) @classmethod - def decode(cls, data: BytesIO) -> bytes | None: - length = Int32.decode(data) + def decode(cls, data: BytesIO, flexible: bool) -> bytes | None: + length = ( + UnsignedVarInt32.decode(data) - 1 + if flexible + else Int32.decode(data, flexible) + ) if length < 0: return None value = data.read(length) @@ -174,33 +201,94 @@ class Boolean(AbstractType[bool]): _unpack = struct.Struct(">?").unpack @classmethod - def encode(cls, value: bool) -> bytes: + def encode(cls, value: bool, flexible: bool) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> bool: + def decode(cls, data: BytesIO, flexible: bool) -> bool: return _unpack(cls._unpack, data.read(1)) class Schema: names: tuple[str, ...] + tags: tuple[int, ...] fields: tuple[ValueT, ...] 
- def __init__(self, *fields: tuple[str, ValueT]): + def __init__(self, *fields: tuple[FieldId, ValueT]): if fields: - self.names, self.fields = zip(*fields, strict=False) + tagged_names, values = zip( + *( + (key, value) if isinstance(key, tuple) else ((key, None), value) + for key, value in fields + ), + strict=False, + ) + self.names = tuple(name for name, _ in tagged_names) + self.tags = tuple(tag for _, tag in tagged_names) + self.fields = tuple(values) else: - self.names, self.fields = (), () + self.names, self.tags, self.fields = (), (), () - def encode(self, item: Sequence[Any]) -> bytes: + def encode(self, item: Sequence[Any], flexible: bool) -> bytes: if len(item) != len(self.fields): raise ValueError("Item field count does not match Schema") - return b"".join(field.encode(item[i]) for i, field in enumerate(self.fields)) + return b"".join( + field.encode(item[i], flexible) + for i, field in enumerate(self.fields) + if self.tags[i] is None + ) + ( + self._encode_tagged_fields( + { + self.tags[i]: field.encode(item[i], flexible) + for i, field in enumerate(self.fields) + if self.tags[i] is not None + } + ) + if flexible + else b"" + ) def decode( - self, data: BytesIO + self, data: BytesIO, flexible: bool ) -> tuple[Any | str | None | list[Any | tuple[Any, ...]], ...]: - return tuple(field.decode(data) for field in self.fields) + result = [ + field.decode(data, flexible) if self.tags[i] is None else None + for i, field in enumerate(self.fields) + ] + if flexible: + tagged_fields = self._decode_tagged_fields(data) + for i, tag in enumerate(self.tags): + if tag is not None: + encoded_value = tagged_fields.get(tag) + if encoded_value is not None: + result[i] = self.fields[i].decode( + BytesIO(encoded_value), flexible + ) + + return tuple(result) + + @staticmethod + def _encode_tagged_fields(value: dict[int, bytes]) -> bytes: + ret = UnsignedVarInt32.encode(len(value)) + for k, v in value.items(): + assert isinstance(k, int) and k >= 0, f"Key {k} is not a 
positive integer" + ret += UnsignedVarInt32.encode(k) + ret += UnsignedVarInt32.encode(len(v)) + ret += v + return ret + + @staticmethod + def _decode_tagged_fields(data: BytesIO) -> dict[int, bytes]: + num_fields = UnsignedVarInt32.decode(data) + ret: dict[int, bytes] = {} + if not num_fields: + return ret + for _ in range(num_fields): + tag = UnsignedVarInt32.decode(data) + size = UnsignedVarInt32.decode(data) + val = data.read(size) + ret[tag] = val + return ret def __len__(self) -> int: return len(self.fields) @@ -227,16 +315,18 @@ def __init__(self, array_of_0: ValueT): ... @overload def __init__( - self, array_of_0: tuple[str, ValueT], *array_of: tuple[str, ValueT] + self, + array_of_0: tuple[FieldId, ValueT], + *array_of: tuple[FieldId, ValueT], ): ... def __init__( self, - array_of_0: ValueT | tuple[str, ValueT], - *array_of: tuple[str, ValueT], + array_of_0: ValueT | tuple[FieldId, ValueT], + *array_of: tuple[FieldId, ValueT], ) -> None: if array_of: - array_of_0 = cast(tuple[str, ValueT], array_of_0) + array_of_0 = cast(tuple[FieldId, ValueT], array_of_0) self.array_of = Schema(array_of_0, *array_of) else: array_of_0 = cast(ValueT, array_of_0) @@ -247,19 +337,33 @@ def __init__( else: raise ValueError("Array instantiated with no array_of type") - def encode(self, items: Sequence[Any] | None) -> bytes: + def encode(self, items: Sequence[Any] | None, flexible: bool) -> bytes: if items is None: - return Int32.encode(-1) - encoded_items = (self.array_of.encode(item) for item in items) - return b"".join( - (Int32.encode(len(items)), *encoded_items), + return ( + UnsignedVarInt32.encode(0) if flexible else Int32.encode(-1, flexible) + ) + encoded_items = (self.array_of.encode(item, flexible) for item in items) + return ( + b"".join( + (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), + ) + if flexible + else b"".join( + (Int32.encode(len(items), flexible), *encoded_items), + ) ) - def decode(self, data: BytesIO) -> list[Any | tuple[Any, ...]] | None: - 
length = Int32.decode(data) + def decode( + self, data: BytesIO, flexible: bool + ) -> list[Any | tuple[Any, ...]] | None: + length = ( + UnsignedVarInt32.decode(data) - 1 + if flexible + else Int32.decode(data, flexible) + ) if length == -1: return None - return [self.array_of.decode(data) for _ in range(length)] + return [self.array_of.decode(data, flexible) for _ in range(length)] def repr(self, list_of_items: Sequence[Any] | None) -> str: if list_of_items is None: @@ -267,7 +371,7 @@ def repr(self, list_of_items: Sequence[Any] | None) -> str: return "[" + ", ".join(self.array_of.repr(item) for item in list_of_items) + "]" -class UnsignedVarInt32(AbstractType[int]): +class UnsignedVarInt32: @classmethod def decode(cls, data: BytesIO) -> int: value, i = 0, 0 @@ -293,128 +397,3 @@ def encode(cls, value: int) -> bytes: value >>= 7 ret += struct.pack("B", value) return ret - - -class VarInt32(AbstractType[int]): - @classmethod - def decode(cls, data: BytesIO) -> int: - value = UnsignedVarInt32.decode(data) - return (value >> 1) ^ -(value & 1) - - @classmethod - def encode(cls, value: int) -> bytes: - # bring it in line with the java binary repr - value &= 0xFFFFFFFF - return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) - - -class VarInt64(AbstractType[int]): - @classmethod - def decode(cls, data: BytesIO) -> int: - value, i = 0, 0 - b: int - while True: - (b,) = struct.unpack("B", data.read(1)) - if not (b & 0x80): - break - value |= (b & 0x7F) << i - i += 7 - if i > 63: - raise ValueError(f"Invalid value {value}") - value |= b << i - return (value >> 1) ^ -(value & 1) - - @classmethod - def encode(cls, value: int) -> bytes: - # bring it in line with the java binary repr - value &= 0xFFFFFFFFFFFFFFFF - v = (value << 1) ^ (value >> 63) - ret = b"" - while (v & 0xFFFFFFFFFFFFFF80) != 0: - b = (value & 0x7F) | 0x80 - ret += struct.pack("B", b) - v >>= 7 - ret += struct.pack("B", v) - return ret - - -class CompactString(String): - def decode(self, data: 
BytesIO) -> str | None: - length = UnsignedVarInt32.decode(data) - 1 - if length < 0: - return None - value = data.read(length) - if len(value) != length: - raise ValueError("Buffer underrun decoding string") - return value.decode(self.encoding) - - def encode(self, value: str | None) -> bytes: - if value is None: - return UnsignedVarInt32.encode(0) - encoded_value = str(value).encode(self.encoding) - return UnsignedVarInt32.encode(len(encoded_value) + 1) + encoded_value - - -class TaggedFields(AbstractType[dict[int, bytes]]): - @classmethod - def decode(cls, data: BytesIO) -> dict[int, bytes]: - num_fields = UnsignedVarInt32.decode(data) - ret: dict[int, bytes] = {} - if not num_fields: - return ret - prev_tag = -1 - for _ in range(num_fields): - tag = UnsignedVarInt32.decode(data) - if tag <= prev_tag: - raise ValueError(f"Invalid or out-of-order tag {tag}") - prev_tag = tag - size = UnsignedVarInt32.decode(data) - val = data.read(size) - ret[tag] = val - return ret - - @classmethod - def encode(cls, value: dict[int, bytes]) -> bytes: - ret = UnsignedVarInt32.encode(len(value)) - for k, v in value.items(): - # do we allow for other data types ?? 
It could get complicated really fast - assert isinstance(v, bytes), f"Value {v!r} is not a byte array" - assert isinstance(k, int) and k > 0, f"Key {k} is not a positive integer" - ret += UnsignedVarInt32.encode(k) - ret += v - return ret - - -class CompactBytes(AbstractType[bytes | None]): - @classmethod - def decode(cls, data: BytesIO) -> bytes | None: - length = UnsignedVarInt32.decode(data) - 1 - if length < 0: - return None - value = data.read(length) - if len(value) != length: - raise ValueError("Buffer underrun decoding Bytes") - return value - - @classmethod - def encode(cls, value: bytes | None) -> bytes: - if value is None: - return UnsignedVarInt32.encode(0) - else: - return UnsignedVarInt32.encode(len(value) + 1) + value - - -class CompactArray(Array): - def encode(self, items: Sequence[Any] | None) -> bytes: - if items is None: - return UnsignedVarInt32.encode(0) - encoded_items = (self.array_of.encode(item) for item in items) - return b"".join( - (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), - ) - - def decode(self, data: BytesIO) -> list[Any | tuple[Any, ...]] | None: - length = UnsignedVarInt32.decode(data) - 1 - if length == -1: - return None - return [self.array_of.decode(data) for _ in range(length)] diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 68239739b..fb48abbb5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -3,15 +3,14 @@ import pytest -from aiokafka.protocol.api import RequestHeader_v0, RequestStruct, Response +from aiokafka.protocol.api import RequestHeader_v1, RequestStruct, Response from aiokafka.protocol.coordination import FindCoordinatorRequest_v0 from aiokafka.protocol.fetch import FetchRequest_v0, FetchResponse_v0 from aiokafka.protocol.message import Message, MessageSet, PartialMessage from aiokafka.protocol.metadata import MetadataRequest_v0 from aiokafka.protocol.types import ( - CompactArray, - CompactBytes, - CompactString, + Array, + Bytes, Int16, Int32, Int64, @@ -192,7 
+191,7 @@ def test_encode_message_header() -> None: ) req = FindCoordinatorRequest_v0("foo") - header = RequestHeader_v0(req, correlation_id=4, client_id="client3") + header = RequestHeader_v1(req, correlation_id=4, client_id="client3") assert header.encode() == expect @@ -239,7 +238,7 @@ def test_decode_fetch_response_partial() -> None: encoded = b"".join( [ Int32.encode(1), # Num Topics (Array) - String("utf-8").encode("foobar"), + String("utf-8").encode("foobar", flexible=False), Int32.encode(2), # Num Partitions (Array) Int32.encode(0), # Partition id Int16.encode(0), # Error Code @@ -328,33 +327,38 @@ def test_unsigned_varint_serde() -> None: def test_compact_data_structs() -> None: - cs = CompactString() - encoded = cs.encode(None) + cs = String() + encoded = cs.encode(None, flexible=True) assert encoded == struct.pack("B", 0) - decoded = cs.decode(io.BytesIO(encoded)) + decoded = cs.decode(io.BytesIO(encoded), flexible=True) assert decoded is None - assert cs.encode("") == b"\x01" - assert cs.decode(io.BytesIO(b"\x01")) == "" - encoded = cs.encode("foobarbaz") - assert cs.decode(io.BytesIO(encoded)) == "foobarbaz" - - arr = CompactArray(CompactString()) - assert arr.encode(None) == b"\x00" - assert arr.decode(io.BytesIO(b"\x00")) is None - enc = arr.encode([]) + assert cs.encode("", flexible=True) == b"\x01" + assert cs.decode(io.BytesIO(b"\x01"), flexible=True) == "" + encoded = cs.encode("foobarbaz", flexible=True) + assert cs.decode(io.BytesIO(encoded), flexible=True) == "foobarbaz" + + arr = Array(String()) + assert arr.encode(None, flexible=True) == b"\x00" + assert arr.decode(io.BytesIO(b"\x00"), flexible=True) is None + enc = arr.encode([], flexible=True) assert enc == b"\x01" - assert arr.decode(io.BytesIO(enc)) == [] - encoded = arr.encode(["foo", "bar", "baz", "quux"]) - assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"] + assert arr.decode(io.BytesIO(enc), flexible=True) == [] + encoded = arr.encode(["foo", "bar", "baz", 
"quux"], flexible=True) + assert arr.decode(io.BytesIO(encoded), flexible=True) == [ + "foo", + "bar", + "baz", + "quux", + ] - enc = CompactBytes.encode(None) + enc = Bytes.encode(None, flexible=True) assert enc == b"\x00" - assert CompactBytes.decode(io.BytesIO(b"\x00")) is None - enc = CompactBytes.encode(b"") + assert Bytes.decode(io.BytesIO(b"\x00"), flexible=True) is None + enc = Bytes.encode(b"", flexible=True) assert enc == b"\x01" - assert CompactBytes.decode(io.BytesIO(b"\x01")) == b"" - enc = CompactBytes.encode(b"foo") - assert CompactBytes.decode(io.BytesIO(enc)) == b"foo" + assert Bytes.decode(io.BytesIO(b"\x01"), flexible=True) == b"" + enc = Bytes.encode(b"foo", flexible=True) + assert Bytes.decode(io.BytesIO(enc), flexible=True) == b"foo" attr_names = [ diff --git a/tests/test_protocol_object_conversion.py b/tests/test_protocol_object_conversion.py index cdfb9705c..1d00e25fc 100644 --- a/tests/test_protocol_object_conversion.py +++ b/tests/test_protocol_object_conversion.py @@ -10,7 +10,7 @@ def _make_test_class( - klass: type[RequestStruct | Response], schema: Schema + klass: type[RequestStruct | Response], schema: Schema, flexible: bool = False ) -> type[RequestStruct | Response]: if klass is RequestStruct: @@ -19,6 +19,7 @@ class RequestTestClass(RequestStruct): API_VERSION = 0 RESPONSE_TYPE = Response SCHEMA = schema + FLEXIBLE_VERSION = flexible return RequestTestClass else: @@ -27,6 +28,7 @@ class ResponseTestClass(Response): API_KEY = 0 API_VERSION = 0 SCHEMA = schema + FLEXIBLE_VERSION = flexible return ResponseTestClass @@ -188,6 +190,53 @@ def test_with_complex_nested_array( assert myarray[1]["subarray"][0]["innertest"] == "hello" assert myarray[1]["subarray"][0]["otherinnertest"] == "hello again" + def test_flexible_version(self, superclass: type[RequestStruct | Response]) -> None: + TestClass = _make_test_class( + superclass, + Schema( + ("name", String("utf-8")), + ("myarray", Array(Int16)), + (("tagged_field1", 0), String("utf-8")), + 
(("tagged_field2", 42), Int16), + ( + ("tagged_field3", 53), + Array( + ("name", String("utf-8")), + (("tag1", 0), Int16), + (("tag2", 1), Int16), + ), + ), + ), + flexible=True, + ) + + tc = TestClass( + name="foo", + myarray=[1, 2, 3], + tagged_field1="bar", + tagged_field2=23, + tagged_field3=[("hello", 1, 2), ("world", 3, 4)], + ) + encoded = tc.encode() + assert tc.to_object()["name"] == "foo" + assert tc.to_object()["myarray"] == [1, 2, 3] + assert tc.to_object()["tagged_field1"] == "bar" + assert tc.to_object()["tagged_field2"] == 23 + assert tc.to_object()["tagged_field3"] == [ + {"name": "hello", "tag1": 1, "tag2": 2}, + {"name": "world", "tag1": 3, "tag2": 4}, + ] + + tc = TestClass.decode(encoded) + assert tc.to_object()["name"] == "foo" + assert tc.to_object()["myarray"] == [1, 2, 3] + assert tc.to_object()["tagged_field1"] == "bar" + assert tc.to_object()["tagged_field2"] == 23 + assert tc.to_object()["tagged_field3"] == [ + {"name": "hello", "tag1": 1, "tag2": 2}, + {"name": "world", "tag1": 3, "tag2": 4}, + ] + def test_with_metadata_response() -> None: tc = MetadataResponse_v5( diff --git a/tests/test_requests.py b/tests/test_requests.py index 0b1e7178a..64300b1fb 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -357,22 +357,22 @@ def __init__(self, expected, min_version=None, max_version=None): ], ), ( - AlterPartitionReassignmentsRequest(timeout_ms=100, topics=[], tags={}), + AlterPartitionReassignmentsRequest(timeout_ms=100, topics=[]), [ Versions(expected=IncompatibleBrokerVersion), Versions( max_version=0, - expected=AlterPartitionReassignmentsRequest_v0(100, [], {}), + expected=AlterPartitionReassignmentsRequest_v0(100, []), ), ], ), ( - ListPartitionReassignmentsRequest(timeout_ms=200, topics=[], tags={}), + ListPartitionReassignmentsRequest(timeout_ms=200, topics=[]), [ Versions(expected=IncompatibleBrokerVersion), Versions( max_version=0, - expected=ListPartitionReassignmentsRequest_v0(200, [], {}), + 
expected=ListPartitionReassignmentsRequest_v0(200, []), ), ], ), @@ -388,16 +388,9 @@ def __init__(self, expected, min_version=None, max_version=None): max_version=1, expected=DeleteRecordsRequest_v1([("t1", [(0, 123)])], 50), ), - ], - ), - ( - DeleteRecordsRequest(topics=[("t1", [(0, 123)])], timeout_ms=50, tags={}), - [ - Versions(expected=IncompatibleBrokerVersion), - Versions(max_version=1, expected=IncompatibleBrokerVersion), Versions( max_version=2, - expected=DeleteRecordsRequest_v2([("t1", [(0, 123, {})], {})], 50, {}), + expected=DeleteRecordsRequest_v2([("t1", [(0, 123)])], 50), ), ], ),