Skip to content

Commit 44b7973

Browse files
committed
(imp) support aliases in get_specialized_type_var_map
1 parent 7ba5928 commit 44b7973

File tree

4 files changed

+94
-34
lines changed

4 files changed

+94
-34
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Release type: minor
2+
3+
Support aliases (TypeVar passthrough) in `get_specialized_type_var_map`.

strawberry/utils/inspect.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
from typing import (
55
Any,
66
Callable,
7+
Generic,
8+
Literal,
79
Optional,
10+
Protocol,
811
TypeVar,
12+
Union,
13+
get_args,
914
get_origin,
1015
)
11-
from typing_extensions import get_args
12-
13-
from strawberry.utils.typing import is_generic_alias
1416

1517

1618
def in_async_context() -> bool:
@@ -67,57 +69,80 @@ class IntBarFoo(IntBar, Foo[str]): ...
6769
# {}
6870
6971
get_specialized_type_var_map(Bar)
70-
# {~T: ~T}
72+
# {}
7173
7274
get_specialized_type_var_map(IntBar)
73-
# {~T: int}
75+
# {~T: int, ~K: int}
7476
7577
get_specialized_type_var_map(IntBarSubclass)
76-
# {~T: int}
78+
# {~T: int, ~K: int}
7779
7880
get_specialized_type_var_map(IntBarFoo)
7981
# {~T: int, ~K: str}
8082
```
8183
"""
8284
from strawberry.types.base import has_object_definition
8385

84-
orig_bases = getattr(cls, "__orig_bases__", None)
85-
if orig_bases is None:
86-
# Specialized generic aliases will not have __orig_bases__
87-
if get_origin(cls) is not None and is_generic_alias(cls):
88-
orig_bases = (cls,)
89-
else:
90-
# Not a specialized type
91-
return None
86+
class Unresolved:
87+
"""Sentinel class.
9288
93-
type_var_map = {}
89+
Until PEP 0661 is accepted, this class is used as a sentinel.
90+
"""
9491

95-
# only get type vars for base generics (ie. Generic[T]) and for strawberry types
92+
def __bool__(self) -> Literal[False]:
93+
return False
9694

97-
orig_bases = [b for b in orig_bases if has_object_definition(b)]
95+
def __repr__(self) -> str:
96+
return "UNRESOLVED"
9897

99-
for base in orig_bases:
100-
# Recursively get type var map from base classes
101-
if base is not cls:
102-
base_type_var_map = get_specialized_type_var_map(base)
103-
if base_type_var_map is not None:
104-
type_var_map.update(base_type_var_map)
98+
def __str__(self) -> str:
99+
return "UNRESOLVED"
105100

106-
args = get_args(base)
107-
origin = getattr(base, "__origin__", None)
101+
UNRESOLVED = Unresolved()
108102

109-
params = origin and getattr(origin, "__parameters__", None)
110-
if params is None:
111-
params = getattr(base, "__parameters__", None)
103+
param_args: dict[TypeVar, Union[Unresolved, TypeVar, type]] = {}
112104

113-
if not params:
114-
continue
105+
types: list[type] = [cls]
106+
while types:
107+
tp = types.pop(0)
108+
if (origin := get_origin(tp)) is None or origin in (Generic, Protocol):
109+
origin = tp
115110

116-
type_var_map.update(
117-
{p.__name__: a for p, a in zip(params, args) if not isinstance(a, TypeVar)}
118-
)
111+
# only get type vars for base generics (i.e. Generic[T]) and for strawberry types
112+
if not has_object_definition(origin):
113+
continue
119114

120-
return type_var_map
115+
if (type_params := getattr(origin, "__parameters__", None)) is not None:
116+
args = get_args(tp)
117+
if not args:
118+
args = [UNRESOLVED] * len(type_params)
119+
for type_param, arg in zip(type_params, args):
120+
if type_param not in param_args:
121+
param_args[type_param] = arg
122+
123+
if orig_bases := getattr(origin, "__orig_bases__", None):
124+
types.extend(orig_bases)
125+
if not param_args:
126+
return None
127+
128+
for type_param, arg in list(param_args.items()):
129+
resolved_arg = arg
130+
while isinstance(resolved_arg, TypeVar) and not isinstance(
131+
resolved_arg, Unresolved
132+
):
133+
resolved_arg = (
134+
param_args.get(resolved_arg, UNRESOLVED)
135+
if resolved_arg is not type_param
136+
else UNRESOLVED
137+
)
138+
139+
param_args[type_param] = resolved_arg
140+
141+
return {
142+
k.__name__: v
143+
for k, v in reversed(param_args.items())
144+
if not isinstance(v, Unresolved) and not isinstance(v, TypeVar)
145+
}
121146

122147

123148
__all__ = ["get_func_args", "get_specialized_type_var_map", "in_async_context"]

tests/python_312/test_inspect.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ class BinSubclass(Bin): ...
9191
assert get_specialized_type_var_map(Bin) == {"_T": int}
9292

9393

94+
def test_get_specialized_type_var_map_double_generic_passthrough():
95+
@strawberry.type
96+
class Foo[_T]: ...
97+
98+
@strawberry.type
99+
class Bar[_K](Foo[_K]): ...
100+
101+
@strawberry.type
102+
class Bin(Bar[int]): ...
103+
104+
assert get_specialized_type_var_map(Bin) == {
105+
"_T": int,
106+
"_K": int,
107+
}
108+
109+
94110
def test_get_specialized_type_var_map_multiple_inheritance():
95111
@strawberry.type
96112
class Foo[_T]: ...

tests/utils/test_inspect.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,22 @@ class BinSubclass(Bin): ...
9494
assert get_specialized_type_var_map(Bin) == {"_T": int}
9595

9696

97+
def test_get_specialized_type_var_map_double_generic_passthrough():
98+
@strawberry.type
99+
class Foo(Generic[_T]): ...
100+
101+
@strawberry.type
102+
class Bar(Foo[_K], Generic[_K]): ...
103+
104+
@strawberry.type
105+
class Bin(Bar[int]): ...
106+
107+
assert get_specialized_type_var_map(Bin) == {
108+
"_T": int,
109+
"_K": int,
110+
}
111+
112+
97113
def test_get_specialized_type_var_map_multiple_inheritance():
98114
@strawberry.type
99115
class Foo(Generic[_T]): ...

0 commit comments

Comments
 (0)