Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

raw element deserialization bug fixed #240

Merged
merged 2 commits into from
Feb 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions pydantic_xml/element/element.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from enum import Enum
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, TypeVar

from pydantic_xml.typedefs import NsMap

Expand All @@ -21,6 +21,13 @@ def tag(self) -> str:
Xml element tag.
"""

@property
@abc.abstractmethod
def nsmap(self) -> Optional[NsMap]:
"""
Xml element namespace map.
"""

@abc.abstractmethod
def is_empty(self) -> bool:
"""
Expand Down Expand Up @@ -64,6 +71,15 @@ def pop_text(self) -> Optional[str]:
:return: element text
"""

@abc.abstractmethod
def pop_tail(self) -> Optional[str]:
"""
Extracts the tail from the xml element.
All subsequent calls return `None`.

:return: element tail
"""

@abc.abstractmethod
def pop_attrib(self, name: str) -> Optional[str]:
"""
Expand All @@ -83,12 +99,22 @@ def pop_attributes(self) -> Optional[Dict[str, str]]:
"""

@abc.abstractmethod
def pop_element(self, tag: str, search_mode: 'SearchMode') -> Optional['XmlElementReader']:
def pop_elements(self) -> Tuple['XmlElementReader', ...]:
"""
Extracts all sub-elements from the xml element.
All subsequent calls return empty list.

:return: element sub-elements
"""

@abc.abstractmethod
def pop_element(self, tag: str, search_mode: 'SearchMode', remove: bool = False) -> Optional['XmlElementReader']:
"""
Extracts a sub-element from the xml element matching `tag`.

:param tag: element tag
:param search_mode: element search mode
:param remove: remove all entities from the element
:return: sub-element
"""

Expand Down Expand Up @@ -280,7 +306,7 @@ def __init__(
text: Optional[str] = None,
tail: Optional[str] = None,
attributes: Optional[Dict[str, str]] = None,
elements: Optional[List['XmlElement[NativeElement]']] = None,
elements: Optional[Iterable['XmlElement[NativeElement]']] = None,
nsmap: Optional[NsMap] = None,
sourceline: int = -1,
):
Expand All @@ -290,7 +316,7 @@ def __init__(
text=text,
tail=tail,
attrib=dict(attributes) if attributes is not None else None,
elements=elements or [],
elements=list(elements) if elements is not None else [],
next_element_idx=0,
)
self._sourceline = sourceline
Expand All @@ -303,6 +329,10 @@ def get_sourceline(self) -> int:
def tag(self) -> str:
return self._tag

@property
def nsmap(self) -> Optional[NsMap]:
return self._nsmap

def create_snapshot(self) -> 'XmlElement[NativeElement]':
element = self.__class__(
tag=self._tag,
Expand Down Expand Up @@ -359,6 +389,11 @@ def pop_text(self) -> Optional[str]:

return result

def pop_tail(self) -> Optional[str]:
result, self._state.tail = self._state.tail, None

return result

def pop_attrib(self, name: str) -> Optional[str]:
return self._state.attrib.pop(name, None) if self._state.attrib else None

Expand All @@ -367,10 +402,33 @@ def pop_attributes(self) -> Optional[Dict[str, str]]:

return result

def pop_element(self, tag: str, search_mode: 'SearchMode') -> Optional['XmlElement[NativeElement]']:
def pop_elements(self) -> Tuple['XmlElement[NativeElement]', ...]:
elements, self._state.elements = self._state.elements, []
self._state.next_element_idx = 0

return tuple(elements)

def pop_element(
self,
tag: str,
search_mode: 'SearchMode',
remove: bool = False,
) -> Optional['XmlElement[NativeElement]']:
searcher: Searcher[NativeElement] = get_searcher(search_mode)

return searcher(self._state, tag, False, True)
element = searcher(self._state, tag, False, True)
if element is not None and remove:
return self.__class__(
tag=element.tag,
nsmap=element.nsmap,
text=element.pop_text(),
tail=element.pop_tail(),
attributes=element.pop_attributes(),
elements=element.pop_elements(),
sourceline=element.get_sourceline(),
)

return element

def find_sub_element(self, path: Sequence[str], search_mode: 'SearchMode') -> PathT['XmlElement[NativeElement]']:
assert len(path) > 0, "path can't be empty"
Expand Down
4 changes: 2 additions & 2 deletions pydantic_xml/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContex
return self._pydantic_model_metaclass_marker_callback
return super().get_metaclass_hook(fullname)

def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool:
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
transformer = PydanticXmlModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config)
return transformer.transform()
transformer.transform()


class PydanticXmlModelTransformer(PydanticModelTransformer):
Expand Down
2 changes: 1 addition & 1 deletion pydantic_xml/serializers/factories/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def deserialize(
if element is None:
return None

if (sub_element := element.pop_element(self._element_name, self._search_mode)) is not None:
if (sub_element := element.pop_element(self._element_name, self._search_mode, remove=True)) is not None:
sourcemap[loc] = sub_element.get_sourceline()
return sub_element.to_native()
else:
Expand Down
52 changes: 51 additions & 1 deletion tests/test_extra.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict
from typing import Dict, Optional

import pydantic as pd
import pytest

from pydantic_xml import BaseXmlModel, attr, element, wrapped
from pydantic_xml.element.native import ElementT
from tests.helpers import fmt_sourceline


Expand Down Expand Up @@ -230,3 +231,52 @@ class TestModel(BaseXmlModel, tag='model', extra='forbid', search_mode=search_mo
},
},
]


@pytest.mark.parametrize('search_mode', ['strict', 'ordered', 'unordered'])
def test_raw_extra_forbid(search_mode: str):
class TestModel(
BaseXmlModel,
tag='model',
extra='forbid',
arbitrary_types_allowed=True,
search_mode=search_mode,
):
field1: ElementT = element("field1")
field2: Optional[ElementT] = element("field2", default=None)

xml = '''
<model>
<field1>field value 1<nested>nested element field</nested></field1>
<field2>field value 2</field2>
<extra>undefined field<nested>nested undefined field</nested></extra>
</model>
'''
with pytest.raises(pd.ValidationError) as exc:
TestModel.from_xml(xml)

err = exc.value
assert err.title == 'TestModel'
assert err.error_count() == 2
assert err.errors() == [
{
'input': 'undefined field',
'loc': ('extra',),
'msg': f'[line {fmt_sourceline(5)}]: Extra inputs are not permitted',
'type': 'extra_forbidden',
'ctx': {
'orig': 'Extra inputs are not permitted',
'sourceline': fmt_sourceline(5),
},
},
{
'input': 'nested undefined field',
'loc': ('extra', 'nested'),
'msg': f'[line {fmt_sourceline(5)}]: Extra inputs are not permitted',
'type': 'extra_forbidden',
'ctx': {
'orig': 'Extra inputs are not permitted',
'sourceline': fmt_sourceline(5),
},
},
]
10 changes: 5 additions & 5 deletions tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def test_raw_primitive_element_serialization():
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
element1: ElementT = element()
element2: ElementT = element()

Expand Down Expand Up @@ -43,7 +43,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):


def test_optional_raw_primitive_element_serialization():
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
element1: Optional[ElementT] = element(default=None)
element2: ElementT = element()

Expand All @@ -66,7 +66,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):


def test_raw_element_homogeneous_collection_serialization():
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
field1: List[ElementT] = element(tag="element1")

xml = '''
Expand Down Expand Up @@ -97,7 +97,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):


def test_raw_element_heterogeneous_collection_serialization():
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
field1: Tuple[ElementT, ElementT] = element(tag="element1")

xml = '''
Expand Down Expand Up @@ -128,7 +128,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):


def test_wrapped_raw_element_serialization():
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
field1: ElementT = wrapped('wrapper', element(tag="element1"))

xml = '''
Expand Down