diff --git a/pyproject.toml b/pyproject.toml index 8bc6c9f..14ae6e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "zenlib" -version = "3.1.4" +version = "3.1.5" authors = [ { name="Desultory", email="dev@pyl.onl" }, ] diff --git a/src/zenlib/types/validated_dataclass.py b/src/zenlib/types/validated_dataclass.py index f693bc2..2d7bb4f 100644 --- a/src/zenlib/types/validated_dataclass.py +++ b/src/zenlib/types/validated_dataclass.py @@ -1,11 +1,11 @@ from dataclasses import dataclass +from typing import ForwardRef, Union, get_args, get_origin, get_type_hints def validatedDataclass(cls): from zenlib.logging import loggify from zenlib.util import merge_class - cls = loggify(dataclass(cls)) base_annotations = {} for base in cls.__mro__: @@ -28,6 +28,9 @@ def _validate_attribute(self, attribute, value): expected_type = self.__class__.__annotations__.get(attribute) if not expected_type: return value # No type hint, so we can't validate it + if get_origin(expected_type) is Union and isinstance(get_args(expected_type)[0], ForwardRef): + expected_type = get_type_hints(self.__class__)[attribute] + if not isinstance(value, expected_type): try: value = expected_type(value)