|
17 | 17 | from abc import ABC, abstractmethod |
18 | 18 | from collections import OrderedDict |
19 | 19 | from functools import lru_cache |
20 | | -from types import GenericAlias |
| 20 | +from types import GenericAlias, UnionType |
21 | 21 | from typing import Any, Dict, List, NamedTuple, Optional, Type, cast |
22 | 22 |
|
23 | 23 | import msgpack |
@@ -565,7 +565,14 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): |
565 | 565 | original_type = type(v) |
566 | 566 | if UnionTransformer.is_optional_type(expected_type): |
567 | 567 | expected_type = UnionTransformer.get_sub_type_in_optional(expected_type) |
568 | | - if original_type != expected_type: |
| 568 | + |
| 569 | + if ( |
| 570 | + UnionTransformer.is_union(expected_type) and |
| 571 | + UnionTransformer.in_union(original_type, expected_type) |
| 572 | + ): |
| 573 | + pass |
| 574 | + |
| 575 | + elif original_type != expected_type: |
569 | 576 | raise TypeTransformerFailedError( |
570 | 577 | f"Type of Val '{original_type}' is not an instance of {expected_type}" |
571 | 578 | ) |
@@ -1836,6 +1843,16 @@ class UnionTransformer(AsyncTypeTransformer[T]): |
1836 | 1843 | def __init__(self): |
1837 | 1844 | super().__init__("Typed Union", typing.Union) |
1838 | 1845 |
|
| 1846 | + @staticmethod |
| 1847 | + def is_union(t: Type[Any] | UnionType) -> bool: |
| 1848 | + |
| 1849 | + return _is_union_type(t) |
| 1850 | + |
| 1851 | + @staticmethod |
| 1852 | + def in_union(t: Type[Any], union: types.UnionType) -> bool: |
| 1853 | + |
| 1854 | + return t in typing.get_args(union) |
| 1855 | + |
1839 | 1856 | @staticmethod |
1840 | 1857 | def is_optional_type(t: Type) -> bool: |
1841 | 1858 | return _is_union_type(t) and type(None) in get_args(t) |
|
0 commit comments