Skip to content

Commit

Permalink
Resolve names using annotation scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
viccie30 committed Jan 5, 2025
1 parent 4d2c744 commit 1ad204d
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 75 deletions.
85 changes: 57 additions & 28 deletions src/_griffe/agents/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,10 @@ def inspect_class(self, node: ObjectNode) -> None:
name=node.name,
docstring=self._get_docstring(node),
bases=bases,
type_parameters=TypeParameters(*_convert_type_parameters(node.obj, self.current)),
lineno=lineno,
endlineno=endlineno,
)
class_.type_parameters = TypeParameters(*_convert_type_parameters(node.obj, self.current, class_))
self.current.set_member(node.name, class_)
self.current = class_
self.extensions.call("on_instance", node=node, obj=class_, agent=self)
Expand Down Expand Up @@ -446,18 +446,10 @@ def handle_function(self, node: ObjectNode, labels: set | None = None) -> None:
except Exception: # noqa: BLE001
# so many exceptions can be raised here:
# AttributeError, NameError, RuntimeError, ValueError, TokenError, TypeError
parameters = None
returns = None
signature = None
return_annotation = None
else:
parameters = Parameters(
*[_convert_parameter(parameter, parent=self.current) for parameter in signature.parameters.values()],
)
return_annotation = signature.return_annotation
returns = (
None
if return_annotation is _empty
else _convert_object_to_annotation(return_annotation, parent=self.current)
)
return_annotation = signature.return_annotation # type: ignore[union-attr]

lineno, endlineno = self._get_linenos(node)

Expand All @@ -467,21 +459,41 @@ def handle_function(self, node: ObjectNode, labels: set | None = None) -> None:
obj = Attribute(
name=node.name,
value=None,
annotation=returns,
docstring=self._get_docstring(node),
lineno=lineno,
endlineno=endlineno,
)
if return_annotation is not None:
obj.annotation = (
None
if return_annotation is _empty
else _convert_object_to_annotation(
return_annotation,
parent=self.current,
annotation_scope=self.current if self.current.is_class else None, # type: ignore[arg-type]
)
)
else:
obj = Function(
name=node.name,
parameters=parameters,
returns=returns,
type_parameters=TypeParameters(*_convert_type_parameters(node.obj, self.current)),
docstring=self._get_docstring(node),
lineno=lineno,
endlineno=endlineno,
)
obj.type_parameters = TypeParameters(*_convert_type_parameters(node.obj, self.current, obj))
if signature is not None:
obj.parameters = Parameters(
*[
_convert_parameter(parameter, parent=self.current, annotation_scope=obj)
for parameter in signature.parameters.values()
],
)
obj.returns = (
None
if return_annotation is _empty
else _convert_object_to_annotation(return_annotation, parent=self.current, annotation_scope=obj)
)

obj.labels |= labels
self.current.set_member(node.name, obj)
self.extensions.call("on_instance", node=node, obj=obj, agent=self)
Expand All @@ -503,13 +515,13 @@ def inspect_type_alias(self, node: ObjectNode) -> None:

type_alias = TypeAlias(
name=node.name,
value=_convert_type_to_annotation(node.obj.__value__, self.current),
lineno=lineno,
endlineno=endlineno,
type_parameters=TypeParameters(*_convert_type_parameters(node.obj, self.current)),
docstring=self._get_docstring(node),
parent=self.current,
)
type_alias.value = _convert_type_to_annotation(node.obj.__value__, self.current, type_alias)
type_alias.type_parameters = TypeParameters(*_convert_type_parameters(node.obj, self.current, type_alias))
self.current.set_member(node.name, type_alias)
self.extensions.call("on_instance", node=node, obj=type_alias, agent=self)
self.extensions.call("on_type_alias_instance", node=node, type_alias=type_alias, agent=self)
Expand Down Expand Up @@ -579,10 +591,16 @@ def handle_attribute(self, node: ObjectNode, annotation: str | Expr | None = Non
}


def _convert_parameter(parameter: SignatureParameter, parent: Module | Class) -> Parameter:
def _convert_parameter(
parameter: SignatureParameter,
parent: Module | Class,
annotation_scope: Function | Class | TypeAlias,
) -> Parameter:
name = parameter.name
annotation = (
None if parameter.annotation is _empty else _convert_object_to_annotation(parameter.annotation, parent=parent)
None
if parameter.annotation is _empty
else _convert_object_to_annotation(parameter.annotation, parent=parent, annotation_scope=annotation_scope)
)
kind = _parameter_kind_map[parameter.kind]
if parameter.default is _empty:
Expand All @@ -595,7 +613,11 @@ def _convert_parameter(parameter: SignatureParameter, parent: Module | Class) ->
return Parameter(name, annotation=annotation, kind=kind, default=default)


def _convert_object_to_annotation(obj: Any, parent: Module | Class) -> str | Expr | None:
def _convert_object_to_annotation(
obj: Any,
parent: Module | Class,
annotation_scope: Function | Class | TypeAlias | None,
) -> str | Expr | None:
# even when *we* import future annotations,
# the object from which we get a signature
# can come from modules which did *not* import them,
Expand All @@ -612,7 +634,7 @@ def _convert_object_to_annotation(obj: Any, parent: Module | Class) -> str | Exp
annotation_node = compile(obj, mode="eval", filename="<>", flags=ast.PyCF_ONLY_AST, optimize=2)
except SyntaxError:
return obj
return safe_get_annotation(annotation_node.body, parent=parent) # type: ignore[attr-defined]
return safe_get_annotation(annotation_node.body, parent=parent, annotation_scope=annotation_scope) # type: ignore[attr-defined]


_type_parameter_kind_map = {
Expand All @@ -630,6 +652,7 @@ def _convert_object_to_annotation(obj: Any, parent: Module | Class) -> str | Exp
def _convert_type_parameters(
obj: Any,
parent: Module | Class,
annotation_scope: Function | Class | TypeAlias,
) -> list[TypeParameter]:
obj = unwrap(obj)

Expand All @@ -640,16 +663,17 @@ def _convert_type_parameters(
for type_parameter in obj.__type_params__:
bound = getattr(type_parameter, "__bound__", None)
if bound is not None:
bound = _convert_type_to_annotation(bound, parent=parent)
bound = _convert_type_to_annotation(bound, parent=parent, annotation_scope=annotation_scope)
constraints: list[str | Expr] = [
_convert_type_to_annotation(constraint, parent=parent) # type: ignore[misc]
_convert_type_to_annotation(constraint, parent=parent, annotation_scope=annotation_scope) # type: ignore[misc]
for constraint in getattr(type_parameter, "__constraints__", ())
]

if getattr(type_parameter, "has_default", lambda: False)():
default = _convert_type_to_annotation(
type_parameter.__default__,
parent=parent,
annotation_scope=annotation_scope,
)
else:
default = None
Expand All @@ -667,22 +691,27 @@ def _convert_type_parameters(
return type_parameters


def _convert_type_to_annotation(obj: Any, parent: Module | Class) -> str | Expr | None:
def _convert_type_to_annotation(
obj: Any,
parent: Module | Class,
annotation_scope: Function | Class | TypeAlias,
) -> str | Expr | None:
origin = typing.get_origin(obj)

if origin is None:
return _convert_object_to_annotation(obj, parent=parent)
return _convert_object_to_annotation(obj, parent=parent, annotation_scope=annotation_scope)

args: Sequence[str | Expr | None] = [
_convert_type_to_annotation(arg, parent=parent) for arg in typing.get_args(obj)
_convert_type_to_annotation(arg, parent=parent, annotation_scope=annotation_scope)
for arg in typing.get_args(obj)
]

# YORE: EOL 3.9: Replace block with lines 2-3.
if sys.version_info >= (3, 10):
if origin is types.UnionType:
return functools.reduce(lambda left, right: ExprBinOp(left, "|", right), args) # type: ignore[arg-type]

origin = _convert_type_to_annotation(origin, parent=parent)
origin = _convert_type_to_annotation(origin, parent=parent, annotation_scope=annotation_scope)
if origin is None:
return None

Expand Down
76 changes: 48 additions & 28 deletions src/_griffe/agents/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,22 @@ def _get_docstring(self, node: ast.AST, *, strict: bool = False) -> Docstring |
def _get_type_parameters(
self,
statement: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.TypeAlias,
annotation_scope: Function | Class | TypeAlias,
) -> list[TypeParameter]:
return [
TypeParameter(
type_param.name, # type: ignore[attr-defined]
kind=self._type_parameter_kind_map[type(type_param)],
bound=safe_get_annotation(getattr(type_param, "bound", None), parent=self.current),
default=safe_get_annotation(getattr(type_param, "default_value", None), parent=self.current),
bound=safe_get_annotation(
getattr(type_param, "bound", None),
parent=self.current,
annotation_scope=annotation_scope,
),
default=safe_get_annotation(
getattr(type_param, "default_value", None),
parent=self.current,
annotation_scope=annotation_scope,
),
)
for type_param in statement.type_params
]
Expand All @@ -230,6 +239,7 @@ def _get_type_parameters(
def _get_type_parameters(
self,
_statement: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
_obj: Function | Class | TypeAlias,
) -> list[TypeParameter]:
return []

Expand Down Expand Up @@ -316,20 +326,20 @@ def visit_classdef(self, node: ast.ClassDef) -> None:
else:
lineno = node.lineno

# handle base classes
bases = [safe_get_base_class(base, parent=self.current) for base in node.bases]

class_ = Class(
name=node.name,
lineno=lineno,
endlineno=node.end_lineno,
docstring=self._get_docstring(node),
decorators=decorators,
type_parameters=TypeParameters(*self._get_type_parameters(node)),
bases=bases, # type: ignore[arg-type]
runtime=not self.type_guarded,
)

# handle base classes
class_.bases = [safe_get_base_class(base, parent=self.current, annotation_scope=class_) for base in node.bases] # type: ignore[misc]
class_.type_parameters = TypeParameters(*self._get_type_parameters(node, class_))
class_.labels |= self.decorators_to_labels(decorators)

self.current.set_member(node.name, class_)
self.current = class_
self.extensions.call("on_instance", node=node, obj=class_, agent=self)
Expand Down Expand Up @@ -418,45 +428,48 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
attribute = Attribute(
name=node.name,
value=None,
annotation=safe_get_annotation(node.returns, parent=self.current),
lineno=node.lineno,
endlineno=node.end_lineno,
docstring=self._get_docstring(node),
runtime=not self.type_guarded,
)
attribute.annotation = safe_get_annotation(
node.returns,
parent=self.current,
annotation_scope=self.current if not self.current.is_module else None, # type: ignore[arg-type]
)
attribute.labels |= labels
self.current.set_member(node.name, attribute)
self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)
return

function = Function(
name=node.name,
lineno=lineno,
endlineno=node.end_lineno,
decorators=decorators,
docstring=self._get_docstring(node),
runtime=not self.type_guarded,
parent=self.current,
)

# handle parameters
parameters = Parameters(
function.parameters = Parameters(
*[
Parameter(
name,
kind=kind,
annotation=safe_get_annotation(annotation, parent=self.current),
annotation=safe_get_annotation(annotation, parent=self.current, annotation_scope=function),
default=default
if isinstance(default, str)
else safe_get_expression(default, parent=self.current, parse_strings=False),
)
for name, annotation, kind, default in get_parameters(node.args)
],
)

function = Function(
name=node.name,
lineno=lineno,
endlineno=node.end_lineno,
parameters=parameters,
returns=safe_get_annotation(node.returns, parent=self.current),
decorators=decorators,
type_parameters=TypeParameters(*self._get_type_parameters(node)),
docstring=self._get_docstring(node),
runtime=not self.type_guarded,
parent=self.current,
)
function.returns = safe_get_annotation(node.returns, parent=self.current, annotation_scope=function)
function.type_parameters = TypeParameters(*self._get_type_parameters(node, function))

property_function = self.get_base_property(decorators, function)

Expand Down Expand Up @@ -519,22 +532,22 @@ def visit_typealias(self, node: ast.TypeAlias) -> None:

name = node.name.id

value = safe_get_expression(node.value, parent=self.current)

try:
docstring = self._get_docstring(ast_next(node), strict=True)
except (LastNodeError, AttributeError):
docstring = None

type_alias = TypeAlias(
name=name,
value=value,
lineno=node.lineno,
endlineno=node.end_lineno,
type_parameters=TypeParameters(*self._get_type_parameters(node)),
docstring=docstring,
parent=self.current,
)

type_alias.value = safe_get_annotation(node.value, parent=self.current, annotation_scope=type_alias)
type_alias.type_parameters = TypeParameters(*self._get_type_parameters(node, type_alias))

self.current.set_member(name, type_alias)
self.extensions.call("on_instance", node=node, obj=type_alias, agent=self)
self.extensions.call("on_type_alias_instance", node=node, type_alias=type_alias, agent=self)
Expand Down Expand Up @@ -711,7 +724,14 @@ def visit_annassign(self, node: ast.AnnAssign) -> None:
Parameters:
node: The node to visit.
"""
self.handle_attribute(node, safe_get_annotation(node.annotation, parent=self.current))
self.handle_attribute(
node,
safe_get_annotation(
node.annotation,
parent=self.current,
annotation_scope=self.current if not self.current.is_module else None, # type: ignore[arg-type]
),
)

def visit_augassign(self, node: ast.AugAssign) -> None:
"""Visit an augmented assignment node.
Expand Down
6 changes: 6 additions & 0 deletions src/_griffe/docstrings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,15 @@ def parse_docstring_annotation(
):
code = compile(annotation, mode="eval", filename="", flags=PyCF_ONLY_AST, optimize=2)
if code.body: # type: ignore[attr-defined]
annotation_scope = docstring.parent
if annotation_scope is not None and annotation_scope.is_attribute:
annotation_scope = annotation_scope.parent
if annotation_scope is not None and annotation_scope.is_module:
annotation_scope = None
name_or_expr = safe_get_annotation(
code.body, # type: ignore[attr-defined]
parent=docstring.parent, # type: ignore[arg-type]
annotation_scope=annotation_scope, # type: ignore[arg-type]
log_level=log_level,
)
return name_or_expr or annotation
Expand Down
Loading

0 comments on commit 1ad204d

Please sign in to comment.