|
1 | 1 | import asyncio |
2 | 2 | import atexit |
3 | 3 | import base64 |
| 4 | +import collections.abc |
4 | 5 | import gzip |
5 | 6 | import hashlib |
6 | 7 | import inspect |
|
12 | 13 | import sys |
13 | 14 | import tempfile |
14 | 15 | import traceback |
| 16 | +import types |
15 | 17 | from collections import UserDict |
16 | 18 | from enum import IntEnum |
17 | 19 | from functools import wraps |
|
32 | 34 | Sequence, |
33 | 35 | Tuple, |
34 | 36 | Type, |
35 | | - TypedDict, |
36 | 37 | Union, |
37 | 38 | ) |
38 | 39 |
|
|
68 | 69 | "PathLock", |
69 | 70 | "i2b", |
70 | 71 | "b2i", |
| 72 | + "get_hash_int", |
| 73 | + "iter_weights", |
| 74 | + "get_size", |
71 | 75 | ] |
72 | 76 |
|
73 | 77 |
|
@@ -743,7 +747,7 @@ def iter_weights( |
743 | 747 | yield item |
744 | 748 |
|
745 | 749 |
|
746 | | -def default_dict(typeddict_class: dict, **kwargs): |
| 750 | +def default_dict(typeddict_class: Type[dict], **kwargs): |
747 | 751 | """Initializes a dictionary with default zero values based on a subclass of TypedDict. |
748 | 752 |
|
749 | 753 | >>> class Demo(dict): |
@@ -1704,6 +1708,64 @@ async def __aexit__(self, *_): |
1704 | 1708 | self.lock.release() |
1705 | 1709 |
|
1706 | 1710 |
|
def get_size(obj, seen=None, iterate_unsafe=False) -> int:
    """Recursively estimate the memory footprint of an object in bytes.

    The result is ``sys.getsizeof(obj)`` plus the sizes of the objects reachable
    from it: instance attributes stored in ``__dict__``, dict keys and values,
    and the items of containers. Every object is counted at most once (tracked
    by ``id``), so shared references and reference cycles are handled.

    Attributes stored only in ``__slots__`` are not followed.

    Args:
        obj: object of any type
        seen (set): set of ids of objects already seen; it is updated in place.
            Leave as None for a top-level call.
        iterate_unsafe (bool, optional): whether to iterate through
            generators/iterators (this consumes them). Defaults to False.

    Returns:
        int: size of object in bytes, or 0 if ``obj`` is already in ``seen``

    Examples:
        >>> get_size("") > 0
        True
        >>> get_size([]) > 0
        True
        >>> def gen():
        ...     for i in range(10):
        ...         yield i
        >>> g = gen()
        >>> get_size(g) > 0
        True
        >>> next(g)
        0
        >>> get_size(g, iterate_unsafe=True) > 0
        True
        >>> try:
        ...     next(g)
        ... except StopIteration:
        ...     "StopIteration"
        'StopIteration'
        >>> class Bag(list):
        ...     pass
        >>> get_size(Bag(["x" * 1000])) > 1000
        True
    """
    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    seen.add(obj_id)

    size = sys.getsizeof(obj)
    # Leaf types: text/binary data has no nested objects worth following, and
    # a range produces its items lazily, so expanding it would both overstate
    # its size and take time proportional to its length.
    if isinstance(obj, (str, bytes, bytearray, range)):
        return size

    # Instance attributes are counted in addition to container contents, so a
    # container subclass (which has a __dict__) still has its items measured.
    has_attr_dict = hasattr(obj, "__dict__")
    if has_attr_dict:
        size += get_size(obj.__dict__, seen, iterate_unsafe)

    if isinstance(obj, dict):
        size += sum(get_size(v, seen, iterate_unsafe) for v in obj.values())
        size += sum(get_size(k, seen, iterate_unsafe) for k in obj)
    elif isinstance(obj, collections.abc.Iterator):
        # Covers generators too. Iterating is destructive, hence opt-in only.
        if iterate_unsafe:
            # Warning: this will consume the generator/iterator
            size += sum(get_size(i, seen, iterate_unsafe) for i in obj)
    elif isinstance(obj, (list, tuple, set, frozenset)) or (
        not has_attr_dict and hasattr(obj, "__iter__")
    ):
        # Built-in containers (and their subclasses) are safe to iterate, as
        # are other iterables without a __dict__ (deque, dict views, ...).
        # Arbitrary user classes that merely define __iter__ are not iterated:
        # their iteration may be lazy or have side effects.
        size += sum(get_size(i, seen, iterate_unsafe) for i in obj)
    return size
| 1767 | + |
| 1768 | + |
1707 | 1769 | if __name__ == "__main__": |
1708 | 1770 | __name__ = "morebuiltins.utils" |
1709 | 1771 | import doctest |
|
0 commit comments