Skip to content

Commit cdbd28f

Browse files
authored
Merge pull request #377 from yukinarit/feature/fix-nested-generic-classes
Fix nested generic class deserialization
2 parents 9272f87 + 5b843f5 commit cdbd28f

File tree

5 files changed

+138
-30
lines changed

5 files changed

+138
-30
lines changed

serde/compat.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,28 @@ def is_ellipsis(typ: Any) -> bool:
812812
return typ is Ellipsis
813813

814814

815+
def get_type_var_names(cls: Type[Any]) -> Optional[List[str]]:
816+
"""
817+
Get type argument names of a generic class.
818+
819+
>>> T = typing.TypeVar('T')
820+
>>> class GenericFoo(typing.Generic[T]):
821+
... pass
822+
>>> get_type_var_names(GenericFoo)
823+
['T']
824+
>>> get_type_var_names(int)
825+
"""
826+
bases = getattr(cls, "__orig_bases__", ())
827+
if not bases:
828+
return None
829+
830+
type_arg_names: List[str] = []
831+
for base in bases:
832+
type_arg_names.extend(arg.__name__ for arg in get_args(base))
833+
834+
return type_arg_names
835+
836+
815837
def find_generic_arg(cls: Type[Any], field: TypeVar) -> int:
816838
"""
817839
Find a type in generic parameters.
@@ -843,26 +865,47 @@ def find_generic_arg(cls: Type[Any], field: TypeVar) -> int:
843865
return -1
844866

845867

846-
def get_generic_arg(typ: Any, index: int) -> Any:
868+
def get_generic_arg(
869+
typ: Any,
870+
maybe_generic_type_vars: Optional[List[str]],
871+
variable_type_args: Optional[List[str]],
872+
index: int,
873+
) -> Any:
847874
"""
848-
Get generic type argument by index.
875+
Get generic type argument.
849876
850877
>>> T = typing.TypeVar('T')
851878
>>> U = typing.TypeVar('U')
852879
>>> class GenericFoo(typing.Generic[T, U]):
853880
... pass
854-
>>> get_generic_arg(GenericFoo[int, str], 0).__name__
881+
>>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 0).__name__
855882
'int'
856-
>>> get_generic_arg(GenericFoo[int, str], 1).__name__
883+
>>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['T', 'U'], 1).__name__
884+
'str'
885+
>>> get_generic_arg(GenericFoo[int, str], ['T', 'U'], ['U'], 0).__name__
857886
'str'
858887
"""
859-
if not is_generic(typ):
888+
if not is_generic(typ) or maybe_generic_type_vars is None or variable_type_args is None:
860889
return typing.Any
861-
else:
862-
args = get_args(typ)
863-
if index + 1 > len(args):
864-
return typing.Any
865-
return args[index]
890+
891+
args = get_args(typ)
892+
893+
if len(args) != len(maybe_generic_type_vars):
894+
raise SerdeError(
895+
f"Number of type args for {typ} does not match number of generic type vars: "
896+
f"\n type args: {args}\n type_vars: {maybe_generic_type_vars}"
897+
)
898+
899+
# Get the name of the type var used for this field in the parent class definition
900+
type_var_name = variable_type_args[index]
901+
902+
try:
903+
# Find the slot of that type var in the original generic class definition
904+
orig_index = maybe_generic_type_vars.index(type_var_name)
905+
except ValueError:
906+
return typing.Any
907+
908+
return args[orig_index]
866909

867910

868911
def has_default(field: dataclasses.Field) -> bool:

serde/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ class Field:
398398
deserializer: Optional[Func] = None # Custom field deserializer.
399399
flatten: Optional[FlattenOpts] = None
400400
parent: Optional[Type[Any]] = None
401+
type_args: Optional[List[str]] = None
401402

402403
@classmethod
403404
def from_dataclass(cls, f: dataclasses.Field, parent: Optional[Type[Any]] = None) -> "Field":

serde/de.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import functools
1010
import typing
1111
from dataclasses import dataclass, is_dataclass
12-
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic, overload
12+
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, overload
1313

1414
import jinja2
1515
from typing_extensions import Type, dataclass_transform
@@ -24,6 +24,7 @@
2424
get_args,
2525
get_generic_arg,
2626
get_origin,
27+
get_type_var_names,
2728
has_default,
2829
has_default_factory,
2930
is_any,
@@ -602,6 +603,10 @@ def render(self, arg: DeField) -> str:
602603
"""
603604
if arg.deserializer and arg.deserializer.inner is not default_deserializer:
604605
res = self.custom_field_deserializer(arg)
606+
elif is_generic(arg.type):
607+
arg.type_args = get_args(arg.type)
608+
arg.type = get_origin(arg.type)
609+
res = self.render(arg)
605610
elif is_dataclass(arg.type):
606611
res = self.dataclass(arg)
607612
elif is_opt(arg.type):
@@ -639,12 +644,10 @@ def render(self, arg: DeField) -> str:
639644
elif isinstance(arg.type, TypeVar):
640645
index = find_generic_arg(self.cls, arg.type)
641646
res = (
642-
f"from_obj(get_generic_arg(maybe_generic, {index}), "
643-
f" {arg.data}, named={not arg.iterbased}, reuse_instances=reuse_instances)"
647+
f"from_obj(get_generic_arg(maybe_generic, maybe_generic_type_vars, "
648+
f"variable_type_args, {index}), {arg.data}, named={not arg.iterbased}, "
649+
"reuse_instances=reuse_instances)"
644650
)
645-
elif is_generic(arg.type):
646-
arg.type = get_origin(arg.type)
647-
res = self.render(arg)
648651
elif is_literal(arg.type):
649652
res = self.literal(arg)
650653
else:
@@ -690,7 +693,12 @@ def dataclass(self, arg: DeField) -> str:
690693
else:
691694
var = arg.datavar
692695

693-
opts = "maybe_generic=maybe_generic, reuse_instances=reuse_instances"
696+
type_args_str = [str(t).lstrip("~") for t in arg.type_args] if arg.type_args else None
697+
698+
opts = (
699+
"maybe_generic=maybe_generic, maybe_generic_type_vars=maybe_generic_type_vars, "
700+
f"variable_type_args={type_args_str}, reuse_instances=reuse_instances"
701+
)
694702

695703
if arg.is_self_referencing():
696704
class_name = "cls"
@@ -718,6 +726,7 @@ def opt(self, arg: DeField) -> str:
718726
... o: Optional[List[int]]
719727
>>> Renderer('foo').render(DeField(Optional[Foo], 'f', datavar='data'))
720728
'(Foo.__serde__.funcs[\\'foo\\'](data=data["f"], maybe_generic=maybe_generic, \
729+
maybe_generic_type_vars=maybe_generic_type_vars, variable_type_args=None, \
721730
reuse_instances=reuse_instances)) if data.get("f") is not None else None'
722731
"""
723732
value = arg[0]
@@ -771,13 +780,16 @@ def tuple(self, arg: DeField) -> str:
771780
>>> Renderer('foo').render(DeField(Tuple[str, int, List[int], Foo], 'd', datavar='data'))
772781
'(coerce(str, data["d"][0]), coerce(int, data["d"][1]), \
773782
[coerce(int, v) for v in data["d"][2]], \
774-
Foo.__serde__.funcs[\\'foo\\'](data=data["d"][3], maybe_generic=maybe_generic, reuse_instances=reuse_instances),)'
783+
Foo.__serde__.funcs[\\'foo\\'](data=data["d"][3], maybe_generic=maybe_generic, \
784+
maybe_generic_type_vars=maybe_generic_type_vars, variable_type_args=None, \
785+
reuse_instances=reuse_instances),)'
775786
776787
>>> field = DeField(Tuple[str, int, List[int], Foo], 'd', datavar='data', index=0, iterbased=True)
777788
>>> Renderer('foo').render(field)
778789
"(coerce(str, data[0][0]), coerce(int, data[0][1]), \
779790
[coerce(int, v) for v in data[0][2]], Foo.__serde__.funcs['foo'](data=data[0][3], \
780-
maybe_generic=maybe_generic, reuse_instances=reuse_instances),)"
791+
maybe_generic=maybe_generic, maybe_generic_type_vars=maybe_generic_type_vars, \
792+
variable_type_args=None, reuse_instances=reuse_instances),)"
781793
"""
782794
if is_bare_tuple(arg.type) or is_variable_tuple(arg.type):
783795
return f"tuple({arg.data})"
@@ -799,9 +811,10 @@ def dict(self, arg: DeField) -> str:
799811
>>> @deserialize
800812
... class Foo: pass
801813
>>> Renderer('foo').render(DeField(Dict[Foo, List[Foo]], 'f', datavar='data'))
802-
'{Foo.__serde__.funcs[\\'foo\\'](data=k, maybe_generic=maybe_generic, reuse_instances=reuse_instances): \
803-
[Foo.__serde__.funcs[\\'foo\\'](data=v, maybe_generic=maybe_generic, reuse_instances=reuse_instances) for v in v] \
804-
for k, v in data["f"].items()}'
814+
'{Foo.__serde__.funcs[\\'foo\\'](data=k, maybe_generic=maybe_generic, \
815+
maybe_generic_type_vars=maybe_generic_type_vars, variable_type_args=None, reuse_instances=reuse_instances): \
816+
[Foo.__serde__.funcs[\\'foo\\'](data=v, maybe_generic=maybe_generic, maybe_generic_type_vars=maybe_generic_type_vars, \
817+
variable_type_args=None, reuse_instances=reuse_instances) for v in v] for k, v in data["f"].items()}'
805818
"""
806819
if is_bare_dict(arg.type):
807820
return arg.data
@@ -898,13 +911,16 @@ def renderable(f: DeField) -> bool:
898911

899912
def render_from_iter(cls: Type[Any], custom: Optional[DeserializeFunc] = None, type_check: TypeCheck = NoCheck) -> str:
900913
template = """
901-
def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
914+
def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None,
915+
variable_type_args=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
902916
if reuse_instances is Ellipsis:
903917
reuse_instances = {{serde_scope.reuse_instances_default}}
904918
905919
if data is None:
906920
return None
907921
922+
maybe_generic_type_vars = maybe_generic_type_vars or {{cls_type_vars}}
923+
908924
{% for f in fields %}
909925
__{{f.name}} = {{f|arg(loop.index-1)|rvalue}}
910926
{% endfor %}
@@ -924,7 +940,12 @@ def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_s
924940
env.filters.update({"rvalue": renderer.render})
925941
env.filters.update({"arg": to_iter_arg})
926942
fields = list(filter(renderable, defields(cls)))
927-
res = env.get_template("iter").render(func=FROM_ITER, serde_scope=getattr(cls, SERDE_SCOPE), fields=fields)
943+
res = env.get_template("iter").render(
944+
func=FROM_ITER,
945+
serde_scope=getattr(cls, SERDE_SCOPE),
946+
fields=fields,
947+
cls_type_vars=get_type_var_names(cls),
948+
)
928949

929950
if renderer.import_numpy:
930951
res = "import numpy\n" + res
@@ -939,14 +960,16 @@ def render_from_dict(
939960
type_check: TypeCheck = NoCheck,
940961
) -> str:
941962
template = """
942-
def {{func}}(cls=cls, maybe_generic=None, data=None,
943-
reuse_instances = {{serde_scope.reuse_instances_default}}):
963+
def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None,
964+
variable_type_args=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
944965
if reuse_instances is Ellipsis:
945966
reuse_instances = {{serde_scope.reuse_instances_default}}
946967
947968
if data is None:
948969
return None
949970
971+
maybe_generic_type_vars = maybe_generic_type_vars or {{cls_type_vars}}
972+
950973
{% for f in fields %}
951974
__{{f.name}} = {{f|arg(loop.index-1)|rvalue}}
952975
{% endfor %}
@@ -973,7 +996,11 @@ def {{func}}(cls=cls, maybe_generic=None, data=None,
973996
env.filters.update({"arg": functools.partial(to_arg, rename_all=rename_all)})
974997
fields = list(filter(renderable, defields(cls)))
975998
res = env.get_template("dict").render(
976-
func=FROM_DICT, serde_scope=getattr(cls, SERDE_SCOPE), fields=fields, type_check=type_check
999+
func=FROM_DICT,
1000+
serde_scope=getattr(cls, SERDE_SCOPE),
1001+
fields=fields,
1002+
type_check=type_check,
1003+
cls_type_vars=get_type_var_names(cls),
9771004
)
9781005

9791006
if renderer.import_numpy:
@@ -984,7 +1011,8 @@ def {{func}}(cls=cls, maybe_generic=None, data=None,
9841011

9851012
def render_union_func(cls: Type[Any], union_args: List[Type[Any]], tagging: Tagging = DefaultTagging) -> str:
9861013
template = """
987-
def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
1014+
def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None,
1015+
variable_type_args=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
9881016
errors = []
9891017
{% for t in union_args %}
9901018
try:
@@ -1042,7 +1070,8 @@ def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_s
10421070

10431071
def render_literal_func(cls: Type[Any], literal_args: List[Any], tagging: Tagging = DefaultTagging) -> str:
10441072
template = """
1045-
def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
1073+
def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None,
1074+
variable_type_args=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
10461075
if data in ({%- for v in literal_args -%}{{v|repr}},{%- endfor -%}):
10471076
return data
10481077
raise SerdeError("Can not deserialize " + repr(data) + " as {{literal_name}}.")

tests/common.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any, Callable, DefaultDict, Dict, FrozenSet, Generic, List, NewType, Optional, Set, Tuple, TypeVar
1010

1111
import more_itertools
12+
1213
from serde import from_dict, from_tuple, serde, to_dict, to_tuple
1314
from serde.json import from_json, to_json
1415
from serde.msgpack import from_msgpack, to_msgpack
@@ -39,13 +40,26 @@
3940

4041
U = TypeVar("U")
4142

43+
V = TypeVar("V")
44+
4245

4346
@serde
4447
class GenericClass(Generic[T, U]):
4548
a: T
4649
b: U
4750

4851

52+
@serde
53+
class Inner(Generic[T]):
54+
c: T
55+
56+
57+
@serde
58+
class NestedGenericClass(Generic[U, V]):
59+
a: U
60+
b: Inner[V]
61+
62+
4963
def param(val, typ, filter: Optional[Callable] = None):
5064
"""
5165
Create a test parameter
@@ -100,6 +114,7 @@ def toml_not_supported(se, de, opt) -> bool:
100114
param(10, NewType("Int", int)), # NewType
101115
param({"a": 1}, Any), # Any
102116
param(GenericClass[str, int]("foo", 10), GenericClass[str, int]), # Generic
117+
param(NestedGenericClass[str, int]("foo", Inner[int](10)), NestedGenericClass[str, int]),
103118
param(pathlib.Path("/tmp/foo"), pathlib.Path), # Extended types
104119
param(pathlib.Path("/tmp/foo"), Optional[pathlib.Path]),
105120
param(None, Optional[pathlib.Path], toml_not_supported),

tests/test_compat.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import sys
22
from dataclasses import dataclass
33
from datetime import datetime
4-
from typing import Dict, Generic, List, NewType, Optional, Set, Tuple, TypeVar, Union
4+
from typing import Any, Dict, Generic, List, NewType, Optional, Set, Tuple, TypeVar, Union
5+
6+
import pytest
57

68
import serde
79
from serde.compat import (
810
Literal,
11+
get_generic_arg,
912
is_dict,
1013
is_generic,
1114
is_list,
@@ -25,6 +28,7 @@
2528
from .data import Bool, Float, Int, Pri, PriOpt, Str
2629

2730
T = TypeVar("T")
31+
U = TypeVar("U")
2832

2933

3034
def test_types():
@@ -257,3 +261,19 @@ def test_is_generic():
257261
assert not serde.is_serializable(GenericFoo[List[int]])
258262
assert serde.is_deserializable(GenericFoo)
259263
assert not serde.is_deserializable(GenericFoo[List[int]])
264+
265+
266+
def test_get_generic_arg():
267+
class GenericFoo(Generic[T, U]):
268+
pass
269+
270+
assert get_generic_arg(GenericFoo[int, str], ["T", "U"], ["T", "U"], 0) == int
271+
assert get_generic_arg(GenericFoo[int, str], ["T", "U"], ["T", "U"], 1) == str
272+
assert get_generic_arg(GenericFoo[int, str], ["T", "U"], ["U"], 0) == str
273+
assert get_generic_arg(GenericFoo[int, str], ["T", "U"], ["V"], 0) == Any
274+
275+
with pytest.raises(serde.SerdeError):
276+
get_generic_arg(GenericFoo[int, str], ["T"], ["T"], 0)
277+
278+
with pytest.raises(serde.SerdeError):
279+
get_generic_arg(GenericFoo[int, str], ["T"], ["U"], 0)

0 commit comments

Comments
 (0)