Skip to content

Commit

Permalink
(imp) support aliases in get_specialized_type_var_map
Browse files Browse the repository at this point in the history
  • Loading branch information
alexey-pelykh committed Feb 3, 2025
1 parent 7ba5928 commit 025e667
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 38 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Support aliases (TypeVar passthrough) in `get_specialized_type_var_map`.
82 changes: 44 additions & 38 deletions strawberry/utils/inspect.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import asyncio
import inspect
from collections import OrderedDict
from functools import lru_cache
from itertools import zip_longest
from typing import (
Any,
Callable,
Generic,
Optional,
Protocol,
TypeVar,
Union,
get_origin,
)
from typing_extensions import get_args

from strawberry.utils.typing import is_generic_alias


def in_async_context() -> bool:
# Based on the way django checks if there's an event loop in the current thread
Expand Down Expand Up @@ -67,57 +70,60 @@ class IntBarFoo(IntBar, Foo[str]): ...
# {}
get_specialized_type_var_map(Bar)
# {~T: ~T}
# {}
get_specialized_type_var_map(IntBar)
# {~T: int}
# {~T: int, ~K: int}
get_specialized_type_var_map(IntBarSubclass)
# {~T: int}
# {~T: int, ~K: int}
get_specialized_type_var_map(IntBarFoo)
# {~T: int, ~K: str}
```
"""
from strawberry.types.base import has_object_definition

orig_bases = getattr(cls, "__orig_bases__", None)
if orig_bases is None:
# Specialized generic aliases will not have __orig_bases__
if get_origin(cls) is not None and is_generic_alias(cls):
orig_bases = (cls,)
else:
# Not a specialized type
return None

type_var_map = {}

# only get type vars for base generics (ie. Generic[T]) and for strawberry types
param_args = OrderedDict[TypeVar, Union[None, TypeVar, type]]()

orig_bases = [b for b in orig_bases if has_object_definition(b)]
types: list[type] = [cls]
while types:
tp = types.pop(0)
if (origin := get_origin(tp)) is None or origin in (Generic, Protocol):
origin = tp

for base in orig_bases:
# Recursively get type var map from base classes
if base is not cls:
base_type_var_map = get_specialized_type_var_map(base)
if base_type_var_map is not None:
type_var_map.update(base_type_var_map)

args = get_args(base)
origin = getattr(base, "__origin__", None)

params = origin and getattr(origin, "__parameters__", None)
if params is None:
params = getattr(base, "__parameters__", None)

if not params:
# only get type vars for base generics (i.e. Generic[T]) and for strawberry types
if not has_object_definition(origin):
continue

type_var_map.update(
{p.__name__: a for p, a in zip(params, args) if not isinstance(a, TypeVar)}
)

return type_var_map
if (type_params := getattr(origin, "__parameters__", None)) is not None:
args = get_args(tp)
for type_param, arg in zip_longest(type_params, args):
if type_param not in param_args:
param_args[type_param] = arg

if orig_bases := getattr(origin, "__orig_bases__", None):
types.extend(orig_bases)
if not param_args:
return None

resolve = True
while resolve:
resolve = False
for type_param, arg in list(param_args.items()):
if arg is None or not isinstance(arg, TypeVar):
continue
resolved_arg = param_args.get(arg, None) if arg is not type_param else None
param_args[type_param] = resolved_arg

if resolved_arg:
resolve = True

return {
k.__name__: v
for k, v in reversed(param_args.items())
if v is not None and not isinstance(v, TypeVar)
}


__all__ = ["get_func_args", "get_specialized_type_var_map", "in_async_context"]
16 changes: 16 additions & 0 deletions tests/python_312/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ class BinSubclass(Bin): ...
assert get_specialized_type_var_map(Bin) == {"_T": int}


def test_get_specialized_type_var_map_double_generic_passthrough():
@strawberry.type
class Foo[_T]: ...

@strawberry.type
class Bar[_K](Foo[_K]): ...

@strawberry.type
class Bin(Bar[int]): ...

assert get_specialized_type_var_map(Bin) == {
"_T": int,
"_K": int,
}


def test_get_specialized_type_var_map_multiple_inheritance():
@strawberry.type
class Foo[_T]: ...
Expand Down
16 changes: 16 additions & 0 deletions tests/utils/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ class BinSubclass(Bin): ...
assert get_specialized_type_var_map(Bin) == {"_T": int}


def test_get_specialized_type_var_map_double_generic_passthrough():
@strawberry.type
class Foo(Generic[_T]): ...

@strawberry.type
class Bar(Foo[_K], Generic[_K]): ...

@strawberry.type
class Bin(Bar[int]): ...

assert get_specialized_type_var_map(Bin) == {
"_T": int,
"_K": int,
}


def test_get_specialized_type_var_map_multiple_inheritance():
@strawberry.type
class Foo(Generic[_T]): ...
Expand Down

0 comments on commit 025e667

Please sign in to comment.