diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 870cb13..e7c876f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,12 @@ Changelog ========= +2.11.0 (2024-05-11) +------------------ + +- named tuple support added. See https://github.com/dapper91/pydantic-xml/issues/172 + + 2.10.0 (2024-05-09) ------------------ diff --git a/README.rst b/README.rst index 29eb883..b661d50 100644 --- a/README.rst +++ b/README.rst @@ -43,6 +43,7 @@ What is not supported? ______________________ - `dataclasses <https://docs.pydantic.dev/usage/dataclasses/>`_ +- `callable discriminators <https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator>`_ Getting started --------------- diff --git a/pydantic_xml/serializers/factories/__init__.py b/pydantic_xml/serializers/factories/__init__.py index d159961..429512c 100644 --- a/pydantic_xml/serializers/factories/__init__.py +++ b/pydantic_xml/serializers/factories/__init__.py @@ -1,2 +1,2 @@ -from . import heterogeneous, homogeneous, is_instance, mapping, model, primitive, raw, tagged_union, tuple -from . import typed_mapping, union, wrapper +from . import call, heterogeneous, homogeneous, is_instance, mapping, model, named_tuple, primitive, raw, tagged_union +from . import tuple, typed_mapping, union, wrapper diff --git a/pydantic_xml/serializers/factories/call.py b/pydantic_xml/serializers/factories/call.py new file mode 100644 index 0000000..a5f70be --- /dev/null +++ b/pydantic_xml/serializers/factories/call.py @@ -0,0 +1,16 @@ +import inspect + +from pydantic_core import core_schema as pcs + +from pydantic_xml import errors +from pydantic_xml.serializers.factories import named_tuple +from pydantic_xml.serializers.serializer import Serializer + + +def from_core_schema(schema: pcs.CallSchema, ctx: Serializer.Context) -> Serializer: + func = schema['function'] + + if inspect.isclass(func) and issubclass(func, tuple): + return named_tuple.from_core_schema(schema, ctx) + else: + raise errors.ModelError("type call is not supported") diff --git a/pydantic_xml/serializers/factories/heterogeneous.py b/pydantic_xml/serializers/factories/heterogeneous.py index 1977c6a..3ae5eb5 100644 --- a/pydantic_xml/serializers/factories/heterogeneous.py +++ b/pydantic_xml/serializers/factories/heterogeneous.py @@ -85,6 +85,7 @@ def from_core_schema(schema: pcs.TupleSchema, ctx: Serializer.Context) -> Serial SchemaTypeFamily.TYPED_MAPPING, SchemaTypeFamily.UNION, SchemaTypeFamily.IS_INSTANCE, + SchemaTypeFamily.CALL, ): raise errors.ModelFieldError( ctx.model_name, ctx.field_name, "collection item must be of primitive, model, mapping or union type", diff --git a/pydantic_xml/serializers/factories/homogeneous.py b/pydantic_xml/serializers/factories/homogeneous.py index c46a8f3..73f6243 100644 --- a/pydantic_xml/serializers/factories/homogeneous.py +++ b/pydantic_xml/serializers/factories/homogeneous.py @@ -103,6 +103,7 @@ def from_core_schema(schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Co SchemaTypeFamily.TYPED_MAPPING, SchemaTypeFamily.UNION, SchemaTypeFamily.IS_INSTANCE, + SchemaTypeFamily.CALL, SchemaTypeFamily.TUPLE, ): raise errors.ModelFieldError( @@ -113,6 +114,7 @@ def from_core_schema(schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Co SchemaTypeFamily.MODEL, SchemaTypeFamily.UNION, SchemaTypeFamily.TUPLE, + SchemaTypeFamily.CALL, ) and ctx.entity_location is None: raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided") diff --git a/pydantic_xml/serializers/factories/named_tuple.py b/pydantic_xml/serializers/factories/named_tuple.py new file mode 100644 index 0000000..7049b2e --- /dev/null +++ b/pydantic_xml/serializers/factories/named_tuple.py @@ -0,0 +1,74 @@ +import typing +from typing import Any, Dict, List, Optional, Tuple + +from pydantic_core import core_schema as pcs + +from pydantic_xml import errors +from pydantic_xml.element import XmlElementReader, XmlElementWriter +from pydantic_xml.serializers.factories import heterogeneous +from pydantic_xml.serializers.serializer import TYPE_FAMILY, SchemaTypeFamily, Serializer +from pydantic_xml.typedefs import EntityLocation, Location + + +class ElementSerializer(Serializer): + @classmethod + def from_core_schema(cls, schema: pcs.ArgumentsSchema, ctx: Serializer.Context) -> 'ElementSerializer': + model_name = ctx.model_name + computed = ctx.field_computed + inner_serializers: List[Serializer] = [] + for argument_schema in schema['arguments_schema']: + param_schema = argument_schema['schema'] + inner_serializers.append(Serializer.parse_core_schema(param_schema, ctx)) + + return cls(model_name, computed, tuple(inner_serializers)) + + def __init__(self, model_name: str, computed: bool, inner_serializers: Tuple[Serializer, ...]): + self._inner_serializer = heterogeneous.ElementSerializer(model_name, computed, inner_serializers) + + def serialize( + self, element: XmlElementWriter, value: List[Any], encoded: List[Any], *, skip_empty: bool = False, + ) -> Optional[XmlElementWriter]: + return self._inner_serializer.serialize(element, value, encoded, skip_empty=skip_empty) + + def deserialize( + self, + element: Optional[XmlElementReader], + *, + context: Optional[Dict[str, Any]], + sourcemap: Dict[Location, int], + loc: Location, + ) -> Optional[List[Any]]: + return self._inner_serializer.deserialize(element, context=context, sourcemap=sourcemap, loc=loc) + + +def from_core_schema(schema: pcs.CallSchema, ctx: Serializer.Context) -> Serializer: + arguments_schema = typing.cast(pcs.ArgumentsSchema, schema['arguments_schema']) + for argument_schema in arguments_schema['arguments_schema']: + param_schema = argument_schema['schema'] + param_schema, ctx = Serializer.preprocess_schema(param_schema, ctx) + + param_type_family = TYPE_FAMILY.get(param_schema['type']) + if param_type_family not in ( + SchemaTypeFamily.PRIMITIVE, + SchemaTypeFamily.MODEL, + SchemaTypeFamily.MAPPING, + SchemaTypeFamily.TYPED_MAPPING, + SchemaTypeFamily.UNION, + SchemaTypeFamily.IS_INSTANCE, + SchemaTypeFamily.CALL, + ): + raise errors.ModelFieldError( + ctx.model_name, ctx.field_name, "tuple item must be of primitive, model, mapping or union type", + ) + + if param_type_family not in (SchemaTypeFamily.MODEL, SchemaTypeFamily.UNION) and ctx.entity_location is None: + raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided") + + if ctx.entity_location is EntityLocation.ELEMENT: + return ElementSerializer.from_core_schema(arguments_schema, ctx) + elif ctx.entity_location is None: + return ElementSerializer.from_core_schema(arguments_schema, ctx) + elif ctx.entity_location is EntityLocation.ATTRIBUTE: + raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "attributes of tuple types are not supported") + else: + raise AssertionError("unreachable") diff --git a/pydantic_xml/serializers/serializer.py b/pydantic_xml/serializers/serializer.py index 9e09217..14729af 100644 --- a/pydantic_xml/serializers/serializer.py +++ b/pydantic_xml/serializers/serializer.py @@ -39,6 +39,7 @@ class SchemaTypeFamily(IntEnum): DEFINITION_REF = 10 JSON_OR_PYTHON = 11 IS_INSTANCE = 12 + CALL = 13 TYPE_FAMILY = { @@ -87,6 +88,8 @@ class SchemaTypeFamily(IntEnum): 'definition-ref': SchemaTypeFamily.DEFINITION_REF, 'json-or-python': SchemaTypeFamily.JSON_OR_PYTHON, + + 'call': SchemaTypeFamily.CALL, } @@ -265,6 +268,10 @@ def select_serializer(cls, schema: pcs.CoreSchema, ctx: Context) -> 'Serializer' schema = typing.cast(pcs.IsInstanceSchema, schema) return factories.is_instance.from_core_schema(schema, ctx) + elif type_family is SchemaTypeFamily.CALL: + schema = typing.cast(pcs.CallSchema, schema) + return factories.call.from_core_schema(schema, ctx) + else: raise AssertionError("unreachable") diff --git a/pyproject.toml b/pyproject.toml index f7099de..c1742c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pydantic-xml" -version = "2.10.0" +version = "2.11.0" description = "pydantic xml extension" authors = ["Dmitry Pershin <dapper1291@gmail.com>"] license = "Unlicense" diff --git a/tests/test_named_tuple.py b/tests/test_named_tuple.py new file mode 100644 index 0000000..3e31374 --- /dev/null +++ b/tests/test_named_tuple.py @@ -0,0 +1,168 @@ +from typing import List, NamedTuple, Optional, Union + +from helpers import assert_xml_equal + +from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element + + +def test_named_tuple_of_primitives_extraction(): + class TestTuple(NamedTuple): + field1: int + field2: float + field3: str + field4: Optional[str] + + class TestModel(BaseXmlModel, tag='model1'): + elements: TestTuple = element(tag='element') + + xml = ''' + <model1> + <element>1</element> + <element>2.2</element> + <element>string3</element> + </model1> + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel(elements=(1, 2.2, "string3", None)) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml(skip_empty=True) + assert_xml_equal(actual_xml, xml) + + +def test_named_tuple_of_mixed_types_extraction(): + class TestSubModel1(BaseXmlModel): + attr1: int = attr() + element1: float = element() + + class TestTuple(NamedTuple): + field1: TestSubModel1 + field2: int + + class TestModel(BaseXmlModel, tag='model1'): + submodels: TestTuple = element(tag='submodel') + + xml = ''' + <model1> + <submodel attr1="1"> + <element1>2.2</element1> + </submodel> + <submodel>1</submodel> + </model1> + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel( + submodels=[ + TestSubModel1(attr1=1, element1=2.2), + 1, + ], + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_list_of_named_tuples_extraction(): + class TestTuple(NamedTuple): + field1: int + field2: Optional[float] = None + + class RootModel(BaseXmlModel, tag='model'): + elements: List[TestTuple] = element(tag='element') + + xml = ''' + <model> + <element>1</element> + <element>1.1</element> + <element>2</element> + <element></element> + <element>3</element> + <element>3.3</element> + </model> + ''' + + actual_obj = RootModel.from_xml(xml) + expected_obj = RootModel( + elements=[ + (1, 1.1), + (2, None), + (3, 3.3), + ], + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_list_of_named_tuples_of_models_extraction(): + class SubModel1(RootXmlModel[str], tag='text'): + pass + + class SubModel2(RootXmlModel[int], tag='number'): + pass + + class TestTuple(NamedTuple): + field1: SubModel1 + field2: Optional[SubModel2] = None + + class RootModel(BaseXmlModel, tag='model'): + elements: List[TestTuple] + + xml = ''' + <model> + <text>text1</text> + <number>1</number> + <text>text2</text> + <text>text3</text> + <number>3</number> + </model> + ''' + + actual_obj = RootModel.from_xml(xml) + expected_obj = RootModel( + elements=[ + (SubModel1('text1'), SubModel2(1)), + (SubModel1('text2'), None), + (SubModel1('text3'), SubModel2(3)), + ], + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_primitive_union_named_tuple(): + class TestTuple(NamedTuple): + field1: Union[int, float] + field2: str + field3: Union[int, float] + + class TestModel(BaseXmlModel, tag='model'): + sublements: TestTuple = element(tag='model1') + + xml = ''' + <model> + <model1>1.1</model1> + <model1>text</model1> + <model1>1</model1> + </model> + ''' + + actual_obj = TestModel.from_xml(xml) + expected_obj = TestModel( + sublements=(float('1.1'), 'text', 1), + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml)