Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Support aliases (TypeVar passthrough) in `get_specialized_type_var_map`.
88 changes: 50 additions & 38 deletions strawberry/utils/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from typing import (
Any,
Callable,
Generic,
Optional,
Protocol,
TypeVar,
Union,
get_args,
get_origin,
)
from typing_extensions import get_args
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_args lives in typing since Python 3.8


from strawberry.utils.typing import is_generic_alias
import strawberry


def in_async_context() -> bool:
Expand Down Expand Up @@ -67,57 +70,66 @@ class IntBarFoo(IntBar, Foo[str]): ...
# {}

get_specialized_type_var_map(Bar)
# {~T: ~T}
# {}

get_specialized_type_var_map(IntBar)
# {~T: int}
# {~T: int, ~K: int}

get_specialized_type_var_map(IntBarSubclass)
# {~T: int}
# {~T: int, ~K: int}

get_specialized_type_var_map(IntBarFoo)
# {~T: int, ~K: str}
```
"""
from strawberry.types.base import has_object_definition

orig_bases = getattr(cls, "__orig_bases__", None)
if orig_bases is None:
# Specialized generic aliases will not have __orig_bases__
if get_origin(cls) is not None and is_generic_alias(cls):
orig_bases = (cls,)
else:
# Not a specialized type
return None

type_var_map = {}

# only get type vars for base generics (ie. Generic[T]) and for strawberry types

orig_bases = [b for b in orig_bases if has_object_definition(b)]
param_args: dict[TypeVar, Union[TypeVar, type]] = {}

for base in orig_bases:
# Recursively get type var map from base classes
if base is not cls:
base_type_var_map = get_specialized_type_var_map(base)
if base_type_var_map is not None:
type_var_map.update(base_type_var_map)
types: list[type] = [cls]
while types:
tp = types.pop(0)
if (origin := get_origin(tp)) is None or origin in (Generic, Protocol):
origin = tp

args = get_args(base)
origin = getattr(base, "__origin__", None)

params = origin and getattr(origin, "__parameters__", None)
if params is None:
params = getattr(base, "__parameters__", None)

if not params:
# only get type vars for base generics (i.e. Generic[T]) and for strawberry types
if not has_object_definition(origin):
continue

type_var_map.update(
{p.__name__: a for p, a in zip(params, args) if not isinstance(a, TypeVar)}
)

return type_var_map
if (type_params := getattr(origin, "__parameters__", None)) is not None:
args = get_args(tp)
if args:
for type_param, arg in zip(type_params, args):
if type_param not in param_args:
param_args[type_param] = arg
else:
for type_param in type_params:
if type_param not in param_args:
param_args[type_param] = strawberry.UNSET

if orig_bases := getattr(origin, "__orig_bases__", None):
types.extend(orig_bases)
if not param_args:
return None

for type_param, arg in list(param_args.items()):
resolved_arg = arg
while (
isinstance(resolved_arg, TypeVar) and resolved_arg is not strawberry.UNSET
):
resolved_arg = (
param_args.get(resolved_arg, strawberry.UNSET)
if resolved_arg is not type_param
else strawberry.UNSET
)

param_args[type_param] = resolved_arg

return {
k.__name__: v
for k, v in reversed(param_args.items())
if v is not strawberry.UNSET and not isinstance(v, TypeVar)
}


__all__ = ["get_func_args", "get_specialized_type_var_map", "in_async_context"]
16 changes: 16 additions & 0 deletions tests/python_312/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ class BinSubclass(Bin): ...
assert get_specialized_type_var_map(Bin) == {"_T": int}


def test_get_specialized_type_var_map_double_generic_passthrough():
@strawberry.type
class Foo[_T]: ...

@strawberry.type
class Bar[_K](Foo[_K]): ...

@strawberry.type
class Bin(Bar[int]): ...

assert get_specialized_type_var_map(Bin) == {
"_T": int,
"_K": int,
}


def test_get_specialized_type_var_map_multiple_inheritance():
@strawberry.type
class Foo[_T]: ...
Expand Down
16 changes: 16 additions & 0 deletions tests/utils/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ class BinSubclass(Bin): ...
assert get_specialized_type_var_map(Bin) == {"_T": int}


def test_get_specialized_type_var_map_double_generic_passthrough():
@strawberry.type
class Foo(Generic[_T]): ...

@strawberry.type
class Bar(Foo[_K], Generic[_K]): ...

@strawberry.type
class Bin(Bar[int]): ...

assert get_specialized_type_var_map(Bin) == {
"_T": int,
"_K": int,
}


def test_get_specialized_type_var_map_multiple_inheritance():
@strawberry.type
class Foo(Generic[_T]): ...
Expand Down
Loading