Skip to content

Commit 8cd7e26

Browse files
authored
Merge pull request #322 from yukinarit/fix-more-type-errors
Fix more type errors
2 parents 0a597e5 + f0b076b commit 8cd7e26

File tree

4 files changed

+61
-60
lines changed

4 files changed

+61
-60
lines changed

serde/compat.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
import typing_extensions
3535
import typing_inspect
36-
from typing_extensions import Type
36+
from typing_extensions import Type, TypeGuard
3737

3838
if sys.version_info[:2] == (3, 7):
3939
Literal = typing_extensions.Literal
@@ -130,7 +130,7 @@ class SerdeSkip(Exception):
130130
"""
131131

132132

133-
def get_origin(typ: Type[Any]) -> Optional[Any]:
133+
def get_origin(typ: Any) -> Optional[Any]:
134134
"""
135135
Provide `get_origin` that works in all python versions.
136136
"""
@@ -288,7 +288,7 @@ def union_args(typ: Union) -> Tuple[Type[Any], ...]:
288288
return tuple(types)
289289

290290

291-
def dataclass_fields(cls: Type[Any]) -> Iterator[dataclasses.Field]:
291+
def dataclass_fields(cls: Type[Any]) -> Iterator[dataclasses.Field]: # type: ignore
292292
raw_fields = dataclasses.fields(cls)
293293

294294
try:
@@ -341,33 +341,33 @@ def recursive(cls: TypeLike) -> None:
341341
lst.add(cls)
342342
elif is_opt(cls):
343343
lst.add(Optional)
344-
arg = type_args(cls)
345-
if arg:
346-
recursive(arg[0])
344+
args = type_args(cls)
345+
if args:
346+
recursive(args[0])
347347
elif is_union(cls):
348348
lst.add(Union)
349349
for arg in type_args(cls):
350350
recursive(arg)
351351
elif is_list(cls) or is_set(cls):
352352
lst.add(List)
353-
arg = type_args(cls)
354-
if arg:
355-
recursive(arg[0])
353+
args = type_args(cls)
354+
if args:
355+
recursive(args[0])
356356
elif is_set(cls):
357357
lst.add(Set)
358-
arg = type_args(cls)
359-
if arg:
360-
recursive(arg[0])
358+
args = type_args(cls)
359+
if args:
360+
recursive(args[0])
361361
elif is_tuple(cls):
362362
lst.add(Tuple)
363363
for arg in type_args(cls):
364364
recursive(arg)
365365
elif is_dict(cls):
366366
lst.add(Dict)
367-
arg = type_args(cls)
368-
if arg and len(arg) >= 2:
369-
recursive(arg[0])
370-
recursive(arg[1])
367+
args = type_args(cls)
368+
if args and len(args) >= 2:
369+
recursive(args[0])
370+
recursive(args[1])
371371
else:
372372
lst.add(cls)
373373

@@ -393,21 +393,21 @@ def recursive(cls: TypeLike) -> None:
393393
for f in dataclass_fields(cls):
394394
recursive(f.type)
395395
elif is_opt(cls):
396-
arg = type_args(cls)
397-
if arg:
398-
recursive(arg[0])
396+
args = type_args(cls)
397+
if args:
398+
recursive(args[0])
399399
elif is_list(cls) or is_set(cls):
400-
arg = type_args(cls)
401-
if arg:
402-
recursive(arg[0])
400+
args = type_args(cls)
401+
if args:
402+
recursive(args[0])
403403
elif is_tuple(cls):
404404
for arg in type_args(cls):
405405
recursive(arg)
406406
elif is_dict(cls):
407-
arg = type_args(cls)
408-
if arg and len(arg) >= 2:
409-
recursive(arg[0])
410-
recursive(arg[1])
407+
args = type_args(cls)
408+
if args and len(args) >= 2:
409+
recursive(args[0])
410+
recursive(args[1])
411411

412412
recursive(cls)
413413
return list(lst)
@@ -433,21 +433,21 @@ def recursive(cls: TypeLike) -> None:
433433
for f in dataclass_fields(cls):
434434
recursive(f.type)
435435
elif is_opt(cls):
436-
arg = type_args(cls)
437-
if arg:
438-
recursive(arg[0])
436+
args = type_args(cls)
437+
if args:
438+
recursive(args[0])
439439
elif is_list(cls) or is_set(cls):
440-
arg = type_args(cls)
441-
if arg:
442-
recursive(arg[0])
440+
args = type_args(cls)
441+
if args:
442+
recursive(args[0])
443443
elif is_tuple(cls):
444444
for arg in type_args(cls):
445445
recursive(arg)
446446
elif is_dict(cls):
447-
arg = type_args(cls)
448-
if arg and len(arg) >= 2:
449-
recursive(arg[0])
450-
recursive(arg[1])
447+
args = type_args(cls)
448+
if args and len(args) >= 2:
449+
recursive(args[0])
450+
recursive(args[1])
451451

452452
recursive(cls)
453453
return list(lst)
@@ -695,7 +695,7 @@ def is_none(typ: Type[Any]) -> bool:
695695
PRIMITIVES = [int, float, bool, str]
696696

697697

698-
def is_enum(typ: Type[Any]) -> bool:
698+
def is_enum(typ: Type[Any]) -> TypeGuard[enum.Enum]:
699699
"""
700700
Test if the type is `enum.Enum`.
701701
"""

serde/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
SETTINGS = dict(debug=False)
6464

6565

66-
def init(debug: bool = False):
66+
def init(debug: bool = False) -> None:
6767
SETTINGS['debug'] = debug
6868

6969

@@ -73,7 +73,7 @@ class SerdeScope:
7373
Container to store types and functions used in code generation context.
7474
"""
7575

76-
cls: Type
76+
cls: Type[Any]
7777
""" The exact class this scope is for (needed to distinguish scopes between inherited classes) """
7878

7979
funcs: Dict[str, Callable] = dataclasses.field(default_factory=dict)
@@ -581,7 +581,7 @@ def conv(f: Field, case: Optional[str] = None) -> str:
581581
return name
582582

583583

584-
def union_func_name(prefix: str, union_args: List[Type]) -> str:
584+
def union_func_name(prefix: str, union_args: List[Type[Any]]) -> str:
585585
"""
586586
Generate a function name that contains all union types
587587

serde/de.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import functools
1010
import typing
1111
from dataclasses import dataclass, is_dataclass
12-
from typing import Any, Callable, Dict, List, Optional, TypeVar
12+
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar
1313

1414
import jinja2
1515
from typing_extensions import Type, dataclass_transform
@@ -87,10 +87,10 @@
8787
__all__ = ['deserialize', 'is_deserializable', 'from_dict', 'from_tuple']
8888

8989
# Interface of Custom deserialize function.
90-
DeserializeFunc = Callable[[Type, Any], Any]
90+
DeserializeFunc = Callable[[Type[Any], Any], Any]
9191

9292

93-
def serde_custom_class_deserializer(cls: Type, datavar, value, custom: DeserializeFunc, default: Callable):
93+
def serde_custom_class_deserializer(cls: Type[Any], datavar, value, custom: DeserializeFunc, default: Callable):
9494
"""
9595
Handle custom deserialization. Use default deserialization logic if it receives `SerdeSkip` exception.
9696
@@ -106,7 +106,7 @@ def serde_custom_class_deserializer(cls: Type, datavar, value, custom: Deseriali
106106
return default()
107107

108108

109-
def default_deserializer(_cls: Type, obj):
109+
def default_deserializer(_cls: Type[Any], obj):
110110
"""
111111
Marker function to tell serde to use the default deserializer. It's used when custom deserializer is specified
112112
at the class but you want to override a field with the default deserializer.
@@ -590,7 +590,8 @@ def data(self, d):
590590
self.datavar = d
591591

592592

593-
defields = functools.partial(fields, DeField)
593+
def defields(cls: Type[Any]) -> List[DeField]:
594+
return fields(DeField, cls)
594595

595596

596597
@dataclass
@@ -884,7 +885,7 @@ def default(self, arg: DeField, code: str) -> str:
884885
return code
885886

886887

887-
def to_arg(f: DeField, index, rename_all: Optional[str] = None) -> DeField:
888+
def to_arg(f: DeField, index: int, rename_all: Optional[str] = None) -> DeField:
888889
f.index = index
889890
f.data = 'data'
890891
f.case = f.case or rename_all
@@ -897,7 +898,7 @@ def to_iter_arg(f: DeField, *args, **kwargs) -> DeField:
897898
return f
898899

899900

900-
def render_from_iter(cls: Type, custom: Optional[DeserializeFunc] = None, type_check: TypeCheck = NoCheck) -> str:
901+
def render_from_iter(cls: Type[Any], custom: Optional[DeserializeFunc] = None, type_check: TypeCheck = NoCheck) -> str:
901902
template = """
902903
def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
903904
if reuse_instances is Ellipsis:
@@ -933,7 +934,7 @@ def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_s
933934

934935

935936
def render_from_dict(
936-
cls: Type,
937+
cls: Type[Any],
937938
rename_all: Optional[str] = None,
938939
custom: Optional[DeserializeFunc] = None,
939940
type_check: TypeCheck = NoCheck,
@@ -981,7 +982,7 @@ def {{func}}(cls=cls, maybe_generic=None, data=None,
981982
return res
982983

983984

984-
def render_union_func(cls: Type, union_args: List[Type], tagging: Tagging = DefaultTagging) -> str:
985+
def render_union_func(cls: Type[Any], union_args: List[Type], tagging: Tagging = DefaultTagging) -> str:
985986
template = """
986987
def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
987988
errors = []
@@ -1039,7 +1040,7 @@ def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_s
10391040
)
10401041

10411042

1042-
def render_literal_func(cls: Type, literal_args: List[Any], tagging: Tagging = DefaultTagging) -> str:
1043+
def render_literal_func(cls: Type[Any], literal_args: List[Any], tagging: Tagging = DefaultTagging) -> str:
10431044
template = """
10441045
def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_scope.reuse_instances_default}}):
10451046
if data in ({%- for v in literal_args -%}{{v|repr}},{%- endfor -%}):

serde/se.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import functools
1010
import typing
1111
from dataclasses import dataclass, is_dataclass
12-
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar
12+
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, TypeVar
1313

1414
import jinja2
1515
from typing_extensions import dataclass_transform
@@ -29,7 +29,6 @@
2929
is_datetime,
3030
is_datetime_instance,
3131
is_dict,
32-
is_ellipsis,
3332
is_enum,
3433
is_generic,
3534
is_list,
@@ -381,14 +380,14 @@ def serializable_to_obj(object):
381380
raise SerdeError(e)
382381

383382

384-
def astuple(v):
383+
def astuple(v: Any) -> Tuple[Any, ...]:
385384
"""
386385
Serialize object into tuple.
387386
"""
388387
return to_tuple(v, reuse_instances=False, convert_sets=False)
389388

390389

391-
def to_tuple(o, reuse_instances: bool = ..., convert_sets: bool = ...) -> Any:
390+
def to_tuple(o, reuse_instances: bool = ..., convert_sets: bool = ...) -> Tuple[Any, ...]:
392391
"""
393392
Serialize object into tuple.
394393
@@ -411,14 +410,14 @@ def to_tuple(o, reuse_instances: bool = ..., convert_sets: bool = ...) -> Any:
411410
return to_obj(o, named=False, reuse_instances=reuse_instances, convert_sets=convert_sets)
412411

413412

414-
def asdict(v: Any) -> Dict[str, Any]:
413+
def asdict(v: Any) -> Dict[Any, Any]:
415414
"""
416415
Serialize object into dictionary.
417416
"""
418417
return to_dict(v, reuse_instances=False, convert_sets=False)
419418

420419

421-
def to_dict(o, reuse_instances: bool = ..., convert_sets: bool = ...) -> Any:
420+
def to_dict(o, reuse_instances: bool = ..., convert_sets: bool = ...) -> Dict[Any, Any]:
422421
"""
423422
Serialize object into dictionary.
424423
@@ -462,7 +461,7 @@ def varname(self) -> str:
462461
raise SerdeError("Field name is None.")
463462
return self.name
464463

465-
def __getitem__(self, n) -> "SeField":
464+
def __getitem__(self, n: int) -> "SeField":
466465
typ = type_args(self.type)[n]
467466
return SeField(typ, name=None)
468467

@@ -473,7 +472,6 @@ def sefields(cls: Type[Any], serialize_class_var: bool = False) -> Iterator[SeFi
473472
"""
474473
for f in fields(SeField, cls, serialize_class_var=serialize_class_var):
475474
f.parent = SeField(None, "obj") # type: ignore
476-
assert isinstance(f, SeField)
477475
yield f
478476

479477

@@ -725,7 +723,9 @@ def render(self, arg: SeField) -> str:
725723
elif is_any(arg.type) or isinstance(arg.type, TypeVar):
726724
res = f"to_obj({arg.varname}, True, False, False)"
727725
elif is_generic(arg.type):
728-
arg.type = get_origin(arg.type)
726+
origin = get_origin(arg.type)
727+
assert origin
728+
arg.type = origin
729729
res = self.render(arg)
730730
elif is_literal(arg.type):
731731
res = self.literal(arg)
@@ -852,7 +852,7 @@ def literal(self, arg: SeField) -> str:
852852
return f"{arg.varname}"
853853

854854

855-
def enum_value(cls, e):
855+
def enum_value(cls: Any, e: Any) -> Any:
856856
"""
857857
Helper function to get value from enum or enum compatible value.
858858
"""

0 commit comments

Comments
 (0)