Skip to content

Commit 3a77152

Browse files
authored
Merge pull request #612 from yukinarit/fix-flatten-default
Fix flatten with default
2 parents f2d048c + d0f924c commit 3a77152

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

serde/de.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from beartype.door import is_bearable
1717
from beartype.roar import BeartypeCallHintParamViolation
1818
from dataclasses import dataclass, is_dataclass
19-
from typing import overload, TypeVar, Generic, Any, Optional, Union, Literal
19+
from typing import overload, TypeVar, Generic, Any, Optional, Union, Literal, Iterator
2020
from typing_extensions import dataclass_transform
2121

2222
from .compat import (
@@ -985,11 +985,31 @@ def literal(self, arg: DeField[Any]) -> str:
985985
)
986986

987987
def default(self, arg: DeField[Any], code: str) -> str:
988-
if arg.alias:
989-
aliases = (f'"{s}"' for s in [arg.name, *arg.alias])
990-
exists = f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])'
988+
"""
989+
Renders supplying default value during deserialization.
990+
"""
991+
992+
def get_aliased_fields(arg: Field[Any]) -> Iterator[str]:
993+
return (f'"{s}"' for s in [arg.name, *arg.alias])
994+
995+
if arg.flatten:
996+
# When a field has the `flatten` attribute, iterate over its dataclass fields.
997+
# This ensures that the code checks keys in the data while considering aliases.
998+
flattened = []
999+
for subarg in defields(arg.type):
1000+
if subarg.alias:
1001+
aliases = get_aliased_fields(subarg)
1002+
flattened.append(f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])')
1003+
else:
1004+
flattened.append(f'"{subarg.name}" in {arg.datavar}')
1005+
exists = " and ".join(flattened)
9911006
else:
992-
exists = f'"{arg.conv_name()}" in {arg.datavar}'
1007+
if arg.alias:
1008+
aliases = get_aliased_fields(arg)
1009+
exists = f'_exists_by_aliases({arg.datavar}, [{",".join(aliases)}])'
1010+
else:
1011+
exists = f'"{arg.conv_name()}" in {arg.datavar}'
1012+
9931013
if has_default(arg):
9941014
return f'({code}) if {exists} else serde_scope.defaults["{arg.name}"]'
9951015
elif has_default_factory(arg):

tests/test_flatten.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,37 @@ class Bar:
8181
@serde
8282
class Foo:
8383
bar: list[Bar] = field(flatten=True)
84+
85+
86+
def test_flatten_default() -> None:
87+
@serde
88+
class Bar:
89+
c: float = field(default=0.0)
90+
d: bool = field(default=False)
91+
92+
@serde
93+
class Foo:
94+
a: int
95+
b: str = field(default="foo")
96+
bar: Bar = field(flatten=True, default_factory=Bar)
97+
98+
f = Foo(a=10, b="b", bar=Bar(c=100.0, d=True))
99+
assert from_json(Foo, to_json(f)) == f
100+
101+
assert from_json(Foo, '{"a": 20}') == Foo(20, "foo", Bar())
102+
103+
104+
def test_flatten_default_alias() -> None:
105+
@serde
106+
class Bar:
107+
a: float = field(default=0.0, alias=["aa"]) # type: ignore
108+
b: bool = field(default=False, alias=["bb"]) # type: ignore
109+
110+
@serde
111+
class Foo:
112+
bar: Bar = field(flatten=True, default_factory=Bar)
113+
114+
f = Foo(bar=Bar(100.0, True))
115+
assert from_json(Foo, to_json(f)) == f
116+
117+
assert from_json(Foo, '{"aa": 20.0, "bb": false}') == Foo(Bar(20.0, False))

0 commit comments

Comments
 (0)