Skip to content

Commit 761aca7

Browse files
authored
Add collections.Counter support (#692)
1 parent c22ac7b commit 761aca7

File tree

10 files changed

+326
-8
lines changed

10 files changed

+326
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Happy coding with pyserde! 🚀
6868
- `list`, `collections.abc.Sequence`, `collections.abc.MutableSequence`, `tuple`
6969
- `set`, `collections.abc.Set`, `collections.abc.MutableSet`
7070
- `dict`, `collections.abc.Mapping`, `collections.abc.MutableMapping`
71-
- [`frozenset`](https://docs.python.org/3/library/stdtypes.html#frozenset), [`defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict), [`deque`](https://docs.python.org/3/library/collections.html#collections.deque)
71+
- [`frozenset`](https://docs.python.org/3/library/stdtypes.html#frozenset), [`defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict), [`deque`](https://docs.python.org/3/library/collections.html#collections.deque), [`Counter`](https://docs.python.org/3/library/collections.html#collections.Counter)
7272
- [`typing.Optional`](https://docs.python.org/3/library/typing.html#typing.Optional)
7373
- [`typing.Union`](https://docs.python.org/3/library/typing.html#typing.Union)
7474
- User defined class with [`@dataclass`](https://docs.python.org/3/library/dataclasses.html)

docs/en/types.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Here is the list of the supported types. See the simple example for each type in
1010
* [`frozenset`](https://docs.python.org/3/library/stdtypes.html#frozenset), [^3]
1111
* [`defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) [^4]
1212
* [`deque`](https://docs.python.org/3/library/collections.html#collections.deque) [^25]
13+
* [`Counter`](https://docs.python.org/3/library/collections.html#collections.Counter) [^26]
1314
* [`typing.Optional`](https://docs.python.org/3/library/typing.html#typing.Optional) [^5]
1415
* [`typing.Union`](https://docs.python.org/3/library/typing.html#typing.Union) [^6] [^7] [^8]
1516
* User defined class with [`@dataclass`](https://docs.python.org/3/library/dataclasses.html) [^9] [^10]
@@ -151,3 +152,5 @@ If you need to use a type which is currently not supported in the standard libra
151152
[^24]: See [examples/type_sqlalchemy.py](https://github.com/yukinarit/pyserde/blob/main/examples/type_sqlalchemy.py)
152153

153154
[^25]: See [examples/deque.py](https://github.com/yukinarit/pyserde/blob/main/examples/deque.py)
155+
156+
[^26]: See [examples/counter.py](https://github.com/yukinarit/pyserde/blob/main/examples/counter.py)

docs/ja/types.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
* [`frozenset`](https://docs.python.org/3/library/stdtypes.html#frozenset) [^3]
1111
* [`defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) [^4]
1212
* [`deque`](https://docs.python.org/3/library/collections.html#collections.deque) [^25]
13+
* [`Counter`](https://docs.python.org/3/library/collections.html#collections.Counter) [^26]
1314
* [`typing.Optional`](https://docs.python.org/3/library/typing.html#typing.Optional)[^5]
1415
* [`typing.Union`](https://docs.python.org/3/library/typing.html#typing.Union) [^6] [^7] [^8]
1516
* [`@dataclass`](https://docs.python.org/3/library/dataclasses.html) を用いたユーザ定義クラス [^9] [^10]
@@ -156,3 +157,5 @@ SQLAlchemy宣言的データクラスマッピング統合の実験的サポー
156157
[^24]: [examples/type_sqlalchemy.py](https://github.com/yukinarit/pyserde/blob/main/examples/type_sqlalchemy.py) を参照
157158

158159
[^25]: [examples/deque.py](https://github.com/yukinarit/pyserde/blob/main/examples/deque.py) を参照
160+
161+
[^26]: [examples/counter.py](https://github.com/yukinarit/pyserde/blob/main/examples/counter.py) を参照

examples/counter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from collections import Counter
2+
3+
from serde import serde
4+
from serde.json import from_json, to_json
5+
6+
7+
@serde
8+
class WordCount:
9+
counts: Counter[str]
10+
11+
12+
def main() -> None:
13+
# Create a Counter from a list of words
14+
wc = WordCount(counts=Counter(["apple", "banana", "apple", "cherry", "banana", "apple"]))
15+
print(f"Original: {wc}")
16+
print(f"Into Json: {to_json(wc)}")
17+
18+
# Deserialize from JSON
19+
s = '{"counts": {"apple": 3, "banana": 2, "cherry": 1}}'
20+
print(f"From Json: {from_json(WordCount, s)}")
21+
22+
# Counter methods still work after deserialization
23+
wc2 = from_json(WordCount, s)
24+
print(f"Most common: {wc2.counts.most_common(2)}")
25+
26+
27+
if __name__ == "__main__":
28+
main()

serde/compat.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import uuid
1515
import typing
1616
import typing_extensions
17-
from collections import defaultdict, deque
17+
from collections import defaultdict, deque, Counter
1818
from collections.abc import Iterator, Sequence, MutableSequence
1919
from collections.abc import Mapping, MutableMapping, Set, MutableSet
2020
from dataclasses import is_dataclass
@@ -218,6 +218,13 @@ def typename(typ: Any, with_typing_module: bool = False) -> str:
218218
return f"deque[{et}]"
219219
else:
220220
return "deque"
221+
elif is_counter(typ):
222+
args = type_args(typ)
223+
if args:
224+
et = thisfunc(args[0])
225+
return f"Counter[{et}]"
226+
else:
227+
return "Counter"
221228
elif is_tuple(typ):
222229
args = type_args(typ)
223230
if args:
@@ -373,6 +380,11 @@ def recursive(cls: Union[type[Any], Any]) -> None:
373380
args = type_args(cls)
374381
if args:
375382
recursive(args[0])
383+
elif is_counter(cls):
384+
lst.add(Counter)
385+
args = type_args(cls)
386+
if args:
387+
recursive(args[0])
376388
elif is_tuple(cls):
377389
lst.add(tuple)
378390
for arg in type_args(cls):
@@ -419,7 +431,7 @@ def recursive(cls: TypeLike) -> None:
419431
args = type_args(cls)
420432
if args:
421433
recursive(args[0])
422-
elif is_list(cls) or is_set(cls) or is_deque(cls):
434+
elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
423435
args = type_args(cls)
424436
if args:
425437
recursive(args[0])
@@ -462,7 +474,7 @@ def recursive(cls: Union[type[Any], Any]) -> None:
462474
args = type_args(cls)
463475
if args:
464476
recursive(args[0])
465-
elif is_list(cls) or is_set(cls) or is_deque(cls):
477+
elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls):
466478
args = type_args(cls)
467479
if args:
468480
recursive(args[0])
@@ -830,6 +842,40 @@ def is_bare_deque(typ: type[Any]) -> bool:
830842
return typ is deque
831843

832844

845+
@cache
846+
def is_counter(typ: type[Any]) -> bool:
847+
"""
848+
Test if the type is `collections.Counter`.
849+
850+
>>> is_counter(Counter[str])
851+
True
852+
>>> is_counter(Counter)
853+
True
854+
>>> is_counter(dict[str, int])
855+
False
856+
"""
857+
try:
858+
return issubclass(get_origin(typ), Counter) # type: ignore
859+
except TypeError:
860+
return typ is Counter
861+
862+
863+
@cache
864+
def is_bare_counter(typ: type[Any]) -> bool:
865+
"""
866+
Test if the type is `collections.Counter` without type args.
867+
868+
>>> is_bare_counter(Counter[str])
869+
False
870+
>>> is_bare_counter(Counter)
871+
True
872+
"""
873+
origin = get_origin(typ)
874+
if origin is Counter:
875+
return not type_args(typ)
876+
return typ is Counter
877+
878+
833879
@cache
834880
def is_none(typ: type[Any]) -> bool:
835881
"""

serde/core.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dataclasses import dataclass
1313

1414
from beartype.door import is_bearable
15-
from collections import deque
15+
from collections import deque, Counter
1616
from collections.abc import Mapping, Sequence, MutableSequence, Set, Callable, Hashable
1717
from typing import (
1818
overload,
@@ -29,12 +29,14 @@
2929
SerdeError,
3030
dataclass_fields,
3131
get_origin,
32+
is_bare_counter,
3233
is_bare_deque,
3334
is_bare_dict,
3435
is_bare_list,
3536
is_bare_set,
3637
is_bare_tuple,
3738
is_class_var,
39+
is_counter,
3840
is_deque,
3941
is_dict,
4042
is_generic,
@@ -374,6 +376,9 @@ def is_instance(obj: Any, typ: Any) -> bool:
374376
return is_set_instance(obj, typ)
375377
elif is_tuple(typ):
376378
return is_tuple_instance(obj, typ)
379+
elif is_counter(typ):
380+
# Counter must be checked before dict since Counter is a subclass of dict
381+
return is_counter_instance(obj, typ)
377382
elif is_dict(typ):
378383
return is_dict_instance(obj, typ)
379384
elif is_deque(typ):
@@ -500,6 +505,16 @@ def is_deque_instance(obj: Any, typ: type[Any]) -> bool:
500505
return is_instance(obj[0], deque_arg)
501506

502507

508+
def is_counter_instance(obj: Any, typ: type[Any]) -> bool:
509+
if not isinstance(obj, Counter):
510+
return False
511+
if len(obj) == 0 or is_bare_counter(typ):
512+
return True
513+
counter_arg = type_args(typ)[0]
514+
# for speed reasons we just check the type of the 1st key
515+
return is_instance(next(iter(obj.keys())), counter_arg)
516+
517+
503518
def is_generic_instance(obj: Any, typ: type[Any]) -> bool:
504519
return is_instance(obj, get_origin(typ))
505520

serde/de.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
get_origin,
3232
get_type_var_names,
3333
is_any,
34+
is_bare_counter,
3435
is_bare_deque,
3536
is_bare_dict,
3637
is_bare_list,
3738
is_bare_set,
3839
is_bare_tuple,
40+
is_counter,
3941
is_datetime,
4042
is_default_dict,
4143
is_deque,
@@ -516,6 +518,11 @@ def deserializable_to_obj(cls: type[T]) -> T:
516518
res = collections.deque(o)
517519
else:
518520
res = collections.deque(thisfunc(type_args(c)[0], e) for e in o)
521+
elif is_counter(c):
522+
if is_bare_counter(c):
523+
res = collections.Counter(o)
524+
else:
525+
res = collections.Counter({thisfunc(type_args(c)[0], k): v for k, v in o.items()})
519526
elif is_tuple(c):
520527
if is_bare_tuple(c) or is_variable_tuple(c):
521528
res = tuple(e for e in o)
@@ -709,7 +716,13 @@ def __getitem__(self, n: int) -> DeField[Any] | InnerField[Any]:
709716
"flatten": self.flatten,
710717
"parent": self.parent,
711718
}
712-
if is_list(self.type) or is_set(self.type) or is_dict(self.type) or is_deque(self.type):
719+
if (
720+
is_list(self.type)
721+
or is_set(self.type)
722+
or is_dict(self.type)
723+
or is_deque(self.type)
724+
or is_counter(self.type)
725+
):
713726
return InnerField(typ, "v", datavar="v", **opts)
714727
elif is_tuple(self.type):
715728
return InnerField(typ, f"{self.data}[{n}]", datavar=f"{self.data}[{n}]", **opts)
@@ -849,6 +862,8 @@ def render(self, arg: DeField[Any]) -> str:
849862
res = self.set(arg)
850863
elif is_deque(arg.type):
851864
res = self.deque(arg)
865+
elif is_counter(arg.type):
866+
res = self.counter(arg)
852867
elif is_dict(arg.type):
853868
res = self.dict(arg)
854869
elif is_tuple(arg.type):
@@ -1020,6 +1035,18 @@ def deque(self, arg: DeField[Any]) -> str:
10201035
else:
10211036
return f"collections.deque({self.render(arg[0])} for v in {arg.data})"
10221037

1038+
def counter(self, arg: DeField[Any]) -> str:
1039+
"""
1040+
Render rvalue for Counter.
1041+
"""
1042+
if is_bare_counter(arg.type):
1043+
return f"collections.Counter({arg.data})"
1044+
else:
1045+
k = arg[0]
1046+
k.name = "k"
1047+
k.datavar = "k"
1048+
return f"collections.Counter({{{self.render(k)}: v for k, v in {arg.data}.items()}})"
1049+
10231050
def tuple(self, arg: DeField[Any]) -> str:
10241051
"""
10251052
Render rvalue for tuple.

serde/se.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import jinja2
1414
from dataclasses import dataclass, is_dataclass
1515
from typing import Any, Generic, Literal, TypeVar
16-
from collections import deque
16+
from collections import deque, Counter
1717
from collections.abc import (
1818
Callable,
1919
Iterable,
@@ -32,13 +32,15 @@
3232
T,
3333
get_origin,
3434
is_any,
35+
is_bare_counter,
3536
is_bare_deque,
3637
is_bare_dict,
3738
is_bare_list,
3839
is_bare_opt,
3940
is_bare_set,
4041
is_bare_tuple,
4142
is_class_var,
43+
is_counter,
4244
is_datetime,
4345
is_datetime_instance,
4446
is_deque,
@@ -419,6 +421,8 @@ def serializable_to_obj(object: Any) -> Any:
419421
return [thisfunc(e) for e in o]
420422
elif isinstance(o, deque):
421423
return [thisfunc(e) for e in o]
424+
elif isinstance(o, Counter):
425+
return dict(o)
422426
elif is_str_serializable_instance(o) or is_datetime_instance(o):
423427
se_cls = o.__class__ if not c or c is Any else c
424428
return CACHE.serialize(
@@ -860,6 +864,8 @@ def render(self, arg: SeField[Any]) -> str:
860864
res = self.set(arg)
861865
elif is_deque(arg.type):
862866
res = self.deque(arg)
867+
elif is_counter(arg.type):
868+
res = self.counter(arg)
863869
elif is_dict(arg.type):
864870
res = self.dict(arg)
865871
elif is_tuple(arg.type):
@@ -994,6 +1000,17 @@ def deque(self, arg: SeField[Any]) -> str:
9941000
earg.name = "v"
9951001
return f"[{self.render(earg)} for v in {arg.varname}]"
9961002

1003+
def counter(self, arg: SeField[Any]) -> str:
1004+
"""
1005+
Render rvalue for Counter.
1006+
"""
1007+
if is_bare_counter(arg.type):
1008+
return f"dict({arg.varname})"
1009+
else:
1010+
karg = arg[0]
1011+
karg.name = "k"
1012+
return f"{{{self.render(karg)}: v for k, v in {arg.varname}.items()}}"
1013+
9971014
def tuple(self, arg: SeField[Any]) -> str:
9981015
"""
9991016
Render rvalue for tuple.

tests/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77
import uuid
88
from collections.abc import MutableSequence, MutableSet, Sequence, Set
9-
from collections import defaultdict, deque
9+
from collections import defaultdict, deque, Counter
1010
from typing import (
1111
Any,
1212
Generic,
@@ -136,6 +136,9 @@ def yaml_not_supported(se: Any, de: Any, opt: Any) -> bool:
136136
param(deque(["a", "b"]), deque[str]),
137137
param(deque(), deque[int]),
138138
param(deque([1, "a", 3.0]), deque),
139+
param(Counter({"a": 1, "b": 2}), Counter[str]), # Counter
140+
param(Counter(), Counter[str]),
141+
param(Counter({"a": 1}), Counter),
139142
param(data.Pri(10, "foo", 100.0, True), data.Pri), # dataclass
140143
param(data.Pri(10, "foo", 100.0, True), Optional[data.Pri]),
141144
param(None, Optional[data.Pri], toml_not_supported),

0 commit comments

Comments
 (0)