Skip to content

Commit 04cc5bf

Browse files
authored
Merge pull request #388 from yukinarit/support-primitive-subclass
Support (de)serialize a subclass of primitive type
2 parents 39d7a71 + 290b374 commit 04cc5bf

File tree

6 files changed

+84
-5
lines changed

6 files changed

+84
-5
lines changed

examples/primitive_subclass.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from dataclasses import dataclass
2+
from serde.json import from_json, to_json
3+
from serde import field, serde
4+
from typing import Dict, List
5+
6+
7+
class Id(str):
8+
def __str__(self) -> str:
9+
return "ID " + self
10+
11+
12+
@serde
13+
@dataclass
14+
class Foo:
15+
a: Id = field(default_factory=Id)
16+
b: Dict[Id, float] = field(default_factory=dict)
17+
c: List[Id] = field(default_factory=list)
18+
19+
20+
def main() -> None:
21+
f = Foo(Id("a"), {Id("b"): 1.0}, [Id("c")])
22+
print(f)
23+
print(type(f.a))
24+
print(type(list(f.b.keys())[0]))
25+
print(type(f.c[0]))
26+
27+
d = to_json(f)
28+
print(d)
29+
30+
ff = from_json(Foo, d)
31+
print(ff)
32+
print(type(ff.a))
33+
print(type(list(ff.b.keys())[0]))
34+
print(type(ff.c[0]))
35+
36+
37+
if __name__ == "__main__":
38+
main()

examples/runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import alias
44
import any
55
import class_var
6+
import primitive_subclass
67
import collection
78
import custom_class_serializer
89
import custom_field_serializer
@@ -93,6 +94,7 @@ def run_all() -> None:
9394
run(plain_dataclass)
9495
run(plain_dataclass_class_attribute)
9596
run(msg_pack)
97+
run(primitive_subclass)
9698
if PY310:
9799
import union_operator
98100

serde/compat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,20 @@ def is_enum(typ: Type[Any]) -> TypeGuard[enum.Enum]:
705705
return isinstance(typ, enum.Enum)
706706

707707

708+
def is_primitive_subclass(typ: Type[Any]) -> bool:
709+
"""
710+
Test if the type is a subclass of primitive type.
711+
712+
>>> is_primitive_subclass(str)
713+
False
714+
>>> class Str(str):
715+
... pass
716+
>>> is_primitive_subclass(Str)
717+
True
718+
"""
719+
return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)
720+
721+
708722
def is_primitive(typ: Type[Any]) -> bool:
709723
"""
710724
Test if the type is primitive.

serde/de.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
is_none,
4444
is_opt,
4545
is_primitive,
46+
is_primitive_subclass,
4647
is_set,
4748
is_str_serializable,
4849
is_tuple,
@@ -245,8 +246,13 @@ def wrap(cls: Type[T]) -> Type[T]:
245246
# We call deserialize and not wrap to make sure that we will use the default serde
246247
# configuration for generating the deserialization function.
247248
deserialize(typ)
248-
if is_primitive(typ) and not is_enum(typ):
249+
250+
# We don't want to add primitive class e.g "str" into the scope, but primitive
251+
# compatible types such as IntEnum and a subclass of primitives are added,
252+
# so that generated code can use those types.
253+
if is_primitive(typ) and not is_enum(typ) and not is_primitive_subclass(typ):
249254
continue
255+
250256
if is_generic(typ):
251257
g[typename(typ)] = get_origin(typ)
252258
else:
@@ -624,6 +630,7 @@ class Renderer:
624630
custom: Optional[DeserializeFunc] = None # Custom class level deserializer.
625631
import_numpy: bool = False
626632
suppress_coerce: bool = False
633+
""" Disable type coercing in codegen """
627634

628635
def render(self, arg: DeField[Any]) -> str:
629636
"""
@@ -655,8 +662,6 @@ def render(self, arg: DeField[Any]) -> str:
655662
elif is_numpy_array(arg.type):
656663
self.import_numpy = True
657664
res = deserialize_numpy_array(arg)
658-
elif is_primitive(arg.type):
659-
res = self.primitive(arg)
660665
elif is_union(arg.type):
661666
res = self.union_func(arg)
662667
elif is_str_serializable(arg.type):
@@ -669,6 +674,9 @@ def render(self, arg: DeField[Any]) -> str:
669674
res = "None"
670675
elif is_any(arg.type) or is_ellipsis(arg.type):
671676
res = arg.data
677+
elif is_primitive(arg.type):
678+
# For subclasses for primitives e.g. class FooStr(str), coercing is always enabled
679+
res = self.primitive(arg, not is_primitive_subclass(arg.type))
672680
elif isinstance(arg.type, TypeVar):
673681
index = find_generic_arg(self.cls, arg.type)
674682
res = (
@@ -876,6 +884,8 @@ def primitive(self, arg: DeField[Any], suppress_coerce: bool = False) -> str:
876884
"""
877885
Render rvalue for primitives.
878886
887+
* `suppress_coerce`: Overrides "suppress_coerce" in the Renderer's field
888+
879889
>>> Renderer('foo').render(DeField(int, 'i', datavar='data'))
880890
'coerce(int, data["i"])'
881891
@@ -890,7 +900,7 @@ def primitive(self, arg: DeField[Any], suppress_coerce: bool = False) -> str:
890900
if arg.alias:
891901
aliases = (f'"{s}"' for s in [arg.name, *arg.alias])
892902
dat = f"_get_by_aliases(data, [{','.join(aliases)}])"
893-
if self.suppress_coerce:
903+
if self.suppress_coerce and suppress_coerce:
894904
return dat
895905
else:
896906
return f"coerce({typ}, {dat})"

tests/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def param(val, typ, filter: Optional[Callable] = None):
7171
return (val, typ, filter or (lambda se, de, opt: False))
7272

7373

74-
def toml_not_supported(se, de, opt) -> bool:
74+
def toml_not_supported(se: Any, de: Any, opt: Any) -> bool:
7575
return se is to_toml
7676

7777

78+
def yaml_not_supported(se: Any, de: Any, opt: Any) -> bool:
79+
return se is to_yaml
80+
81+
7882
types: List = [
7983
param(10, int), # Primitive
8084
param("foo", str),
@@ -103,6 +107,7 @@ def toml_not_supported(se, de, opt) -> bool:
103107
param({"a": [1]}, DefaultDict[str, List[int]]),
104108
param(data.Pri(10, "foo", 100.0, True), data.Pri), # dataclass
105109
param(data.Pri(10, "foo", 100.0, True), Optional[data.Pri]),
110+
param(data.PrimitiveSubclass(data.StrSubclass("a")), data.PrimitiveSubclass, yaml_not_supported),
106111
param(None, Optional[data.Pri], toml_not_supported),
107112
param(data.Recur(data.Recur(None, None, None), None, None), data.Recur, toml_not_supported),
108113
param(

tests/data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,13 @@ class Init:
276276

277277
def __post_init__(self) -> None:
278278
self.b = self.a * 10
279+
280+
281+
class StrSubclass(str):
282+
pass
283+
284+
285+
@serde
286+
@dataclass
287+
class PrimitiveSubclass:
288+
v: StrSubclass

0 commit comments

Comments
 (0)