-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathhomogeneous.py
139 lines (114 loc) · 4.75 KB
/
homogeneous.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import itertools as it
from typing import Any, Dict, List, Optional, Union
import pydantic as pd
from pydantic_core import core_schema as pcs
from pydantic_xml import errors, utils
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.serializer import TYPE_FAMILY, SchemaTypeFamily, Serializer
from pydantic_xml.typedefs import EntityLocation, Location
# Core-schema variants handled by this module: tuple, list, set and frozenset
# schemas whose items are described by a single items schema.
HomogeneousCollectionTypeSchema = Union[pcs.TupleSchema, pcs.ListSchema, pcs.SetSchema, pcs.FrozenSetSchema]
class ElementSerializer(Serializer):
    """Serializer for a homogeneous collection field (tuple, list, set or frozenset).

    Every collection item is handled by one shared inner serializer built from
    the collection's items schema.

    - Serialization hands each item to that inner serializer with the same
      parent ``element``.
    - Deserialization calls the inner serializer repeatedly until it returns
      ``None``.
    """

    @classmethod
    def from_core_schema(cls, schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Context) -> 'ElementSerializer':
        """Create an ``ElementSerializer`` from a collection core schema.

        :param schema: collection core schema. Its ``'items_schema'`` entry is
            either a single schema or a one-element list holding it.
        :param ctx: serializer context supplying the model name and the
            computed-field flag.
        :return: a serializer that delegates each item to a serializer parsed
            from the items schema.
        :raises AssertionError: if ``'items_schema'`` is a list whose length is
            not 1 (only while assertions are enabled).
        """
        name_of_model = ctx.model_name
        is_computed = ctx.field_computed

        item_schema = schema['items_schema']
        if isinstance(item_schema, list):
            # A list-valued items schema must hold exactly one entry here,
            # i.e. a single item type shared by all elements.
            assert len(item_schema) == 1, "unexpected items schema type"
            item_schema = item_schema[0]

        return cls(name_of_model, is_computed, Serializer.parse_core_schema(item_schema, ctx))

    def __init__(self, model_name: str, computed: bool, inner_serializer: Serializer):
        """Store the serializer configuration.

        :param model_name: name of the owning model. It is used as the title of
            the aggregated validation error raised by :meth:`deserialize`.
        :param computed: whether the field is computed. Computed fields are
            never deserialized.
        :param inner_serializer: serializer applied to every collection item.
        """
        self._model_name = model_name
        self._computed = computed
        self._inner_serializer = inner_serializer

    def serialize(
            self,
            element: XmlElementWriter,
            value: List[Any],
            encoded: List[Any],
            *,
            skip_empty: bool = False,
            exclude_none: bool = False,
            exclude_unset: bool = False,
    ) -> Optional[XmlElementWriter]:
        """Write every item of ``value`` into ``element`` via the inner serializer.

        ``value`` and ``encoded`` are walked in lockstep. Iteration stops at the
        shorter of the two.

        :param element: parent element the items are written into.
        :param value: the collection to serialize; ``None`` writes nothing.
        :param encoded: encoded counterparts of the items of ``value``.
        :param skip_empty: when true, an empty collection and ``None`` items
            are skipped.
        :param exclude_none: forwarded to the inner serializer.
        :param exclude_unset: forwarded to the inner serializer.
        :return: ``element``, unchanged as an object.
        """
        # Nothing to write: a missing collection, or an empty one when empty
        # entities are skipped.
        if value is None or (skip_empty and len(value) == 0):
            return element

        inner = self._inner_serializer
        for item, item_encoded in zip(value, encoded):
            if item is None and skip_empty:
                continue
            inner.serialize(
                element,
                item,
                item_encoded,
                skip_empty=skip_empty,
                exclude_none=exclude_none,
                exclude_unset=exclude_unset,
            )

        return element

    def deserialize(
            self,
            element: Optional[XmlElementReader],
            *,
            context: Optional[Dict[str, Any]],
            sourcemap: Dict[Location, int],
            loc: Location,
    ) -> Optional[List[Any]]:
        """Read collection items from ``element`` until the inner serializer yields ``None``.

        Item ``i`` is deserialized at location ``loc + (i,)``. A
        ``pd.ValidationError`` raised for an item is recorded under that item's
        index and reading continues with the next index.

        NOTE(review): termination relies on the inner serializer advancing
        through ``element`` even when it raises. If it does not, a failing item
        would be retried forever. TODO confirm against the inner serializers.

        :return: the deserialized items, or ``None`` when the field is
            computed, ``element`` is ``None``, or no item was found.
        :raises pd.ValidationError: aggregating every per-item error, titled
            with the model name.
        """
        if self._computed or element is None:
            return None

        inner = self._inner_serializer
        items: List[Any] = []
        failures: Dict[Union[None, str, int], pd.ValidationError] = {}

        for index in it.count():
            try:
                item = inner.deserialize(element, context=context, sourcemap=sourcemap, loc=loc + (index,))
            except pd.ValidationError as err:
                failures[index] = err
            else:
                # ``None`` from the inner serializer marks the end of the items.
                if item is None:
                    break
                items.append(item)

        if failures:
            raise utils.into_validation_error(title=self._model_name, errors_map=failures)

        # An empty collection is reported as ``None`` rather than ``[]``.
        return items if items else None
def from_core_schema(schema: HomogeneousCollectionTypeSchema, ctx: Serializer.Context) -> Serializer:
    """Choose and build the serializer for a homogeneous collection field.

    :param schema: collection core schema. Its ``'items_schema'`` entry is
        either a single schema or a one-element list holding it.
    :param ctx: serializer context. It is replaced by the context returned from
        ``Serializer.preprocess_schema`` for the items schema.
    :return: an ``ElementSerializer`` when the entity location is
        ``EntityLocation.ELEMENT`` or unset.
    :raises errors.ModelFieldError: if the item type family is not supported,
        if an entity name is required but missing, or if the field is declared
        as an attribute.
    :raises AssertionError: for an unexpected entity location, or a list-valued
        items schema whose length is not 1 (the latter only while assertions
        are enabled).
    """
    items_schema = schema['items_schema']
    if isinstance(items_schema, list):
        # Only a single, shared item schema is expected for these collections.
        assert len(items_schema) == 1, "unexpected items schema type"
        items_schema = items_schema[0]

    items_schema, ctx = Serializer.preprocess_schema(items_schema, ctx)
    family = TYPE_FAMILY.get(items_schema['type'])

    # Item type families a collection may contain.
    allowed_families = (
        SchemaTypeFamily.PRIMITIVE,
        SchemaTypeFamily.MODEL,
        SchemaTypeFamily.MAPPING,
        SchemaTypeFamily.TYPED_MAPPING,
        SchemaTypeFamily.UNION,
        SchemaTypeFamily.TAGGED_UNION,
        SchemaTypeFamily.IS_INSTANCE,
        SchemaTypeFamily.CALL,
        SchemaTypeFamily.TUPLE,
    )
    if family not in allowed_families:
        raise errors.ModelFieldError(
            ctx.model_name, ctx.field_name, "collection item must be of primitive, model, mapping, union or tuple type",
        )

    # Families accepted without an explicit entity location. Presumably they
    # supply their own entity names; TODO confirm.
    unnamed_ok_families = (
        SchemaTypeFamily.MODEL,
        SchemaTypeFamily.UNION,
        SchemaTypeFamily.TAGGED_UNION,
        SchemaTypeFamily.TUPLE,
        SchemaTypeFamily.CALL,
    )
    if family not in unnamed_ok_families and ctx.entity_location is None:
        raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "entity name is not provided")

    location = ctx.entity_location
    if location is EntityLocation.ATTRIBUTE:
        raise errors.ModelFieldError(ctx.model_name, ctx.field_name, "attributes of collection types are not supported")
    if location is None or location is EntityLocation.ELEMENT:
        return ElementSerializer.from_core_schema(schema, ctx)
    raise AssertionError("unreachable")