diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..a6e9fd087d --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Support aliases (TypeVar passthrough) in `get_specialized_type_var_map`. diff --git a/strawberry/utils/inspect.py b/strawberry/utils/inspect.py index 650c53b498..80719427b9 100644 --- a/strawberry/utils/inspect.py +++ b/strawberry/utils/inspect.py @@ -4,13 +4,16 @@ from typing import ( Any, Callable, + Generic, Optional, + Protocol, TypeVar, + Union, + get_args, get_origin, ) -from typing_extensions import get_args -from strawberry.utils.typing import is_generic_alias +import strawberry def in_async_context() -> bool: @@ -67,13 +70,13 @@ class IntBarFoo(IntBar, Foo[str]): ... # {} get_specialized_type_var_map(Bar) - # {~T: ~T} + # {} get_specialized_type_var_map(IntBar) - # {~T: int} + # {~T: int, ~K: int} get_specialized_type_var_map(IntBarSubclass) - # {~T: int} + # {~T: int, ~K: int} get_specialized_type_var_map(IntBarFoo) # {~T: int, ~K: str} @@ -81,43 +84,52 @@ class IntBarFoo(IntBar, Foo[str]): ... """ from strawberry.types.base import has_object_definition - orig_bases = getattr(cls, "__orig_bases__", None) - if orig_bases is None: - # Specialized generic aliases will not have __orig_bases__ - if get_origin(cls) is not None and is_generic_alias(cls): - orig_bases = (cls,) - else: - # Not a specialized type - return None - - type_var_map = {} - - # only get type vars for base generics (ie. Generic[T]) and for strawberry types - - orig_bases = [b for b in orig_bases if has_object_definition(b)] + param_args: dict[TypeVar, Union[TypeVar, type]] = {} - for base in orig_bases: - # Recursively get type var map from base classes - if base is not cls: - base_type_var_map = get_specialized_type_var_map(base) - if base_type_var_map is not None: - type_var_map.update(base_type_var_map) + types: list[type] = [cls] + while types: + tp = types.pop(0) + if (origin := get_origin(tp)) is None or origin in (Generic, Protocol): + origin = tp - args = get_args(base) - origin = getattr(base, "__origin__", None) - - params = origin and getattr(origin, "__parameters__", None) - if params is None: - params = getattr(base, "__parameters__", None) - - if not params: + # only get type vars for base generics (i.e. Generic[T]) and for strawberry types + if not has_object_definition(origin): continue - type_var_map.update( - {p.__name__: a for p, a in zip(params, args) if not isinstance(a, TypeVar)} - ) - - return type_var_map + if (type_params := getattr(origin, "__parameters__", None)) is not None: + args = get_args(tp) + if args: + for type_param, arg in zip(type_params, args): + if type_param not in param_args: + param_args[type_param] = arg + else: + for type_param in type_params: + if type_param not in param_args: + param_args[type_param] = strawberry.UNSET + + if orig_bases := getattr(origin, "__orig_bases__", None): + types.extend(orig_bases) + if not param_args: + return None + + for type_param, arg in list(param_args.items()): + resolved_arg = arg + while ( + isinstance(resolved_arg, TypeVar) and resolved_arg is not strawberry.UNSET + ): + resolved_arg = ( + param_args.get(resolved_arg, strawberry.UNSET) + if resolved_arg is not type_param + else strawberry.UNSET + ) + + param_args[type_param] = resolved_arg + + return { + k.__name__: v + for k, v in reversed(param_args.items()) + if v is not strawberry.UNSET and not isinstance(v, TypeVar) + } __all__ = ["get_func_args", "get_specialized_type_var_map", "in_async_context"] diff --git a/tests/python_312/test_inspect.py b/tests/python_312/test_inspect.py index e894e36dac..ecc127d4cf 100644 --- a/tests/python_312/test_inspect.py +++ b/tests/python_312/test_inspect.py @@ -91,6 +91,22 @@ class BinSubclass(Bin): ... assert get_specialized_type_var_map(Bin) == {"_T": int} +def test_get_specialized_type_var_map_double_generic_passthrough(): + @strawberry.type + class Foo[_T]: ... + + @strawberry.type + class Bar[_K](Foo[_K]): ... + + @strawberry.type + class Bin(Bar[int]): ... + + assert get_specialized_type_var_map(Bin) == { + "_T": int, + "_K": int, + } + + def test_get_specialized_type_var_map_multiple_inheritance(): @strawberry.type class Foo[_T]: ... diff --git a/tests/utils/test_inspect.py b/tests/utils/test_inspect.py index a4cf9c9048..6fde8e7ca8 100644 --- a/tests/utils/test_inspect.py +++ b/tests/utils/test_inspect.py @@ -94,6 +94,22 @@ class BinSubclass(Bin): ... assert get_specialized_type_var_map(Bin) == {"_T": int} +def test_get_specialized_type_var_map_double_generic_passthrough(): + @strawberry.type + class Foo(Generic[_T]): ... + + @strawberry.type + class Bar(Foo[_K], Generic[_K]): ... + + @strawberry.type + class Bin(Bar[int]): ... + + assert get_specialized_type_var_map(Bin) == { + "_T": int, + "_K": int, + } + + def test_get_specialized_type_var_map_multiple_inheritance(): @strawberry.type class Foo(Generic[_T]): ...