Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 13 additions & 25 deletions scripts/generate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
from collections.abc import Callable
from pathlib import Path

from puya.log import configure_stdio

sys.path.insert(0, str(Path(__file__).parent.parent / "src" / "puyapy" / "_vendor"))


import attrs
import mypy.build
import mypy.find_sources
import mypy.nodes
from mypy.visitor import NodeVisitor

from puya.log import configure_stdio
from puyapy.parse import _get_mypy_options

SCRIPTS_DIR = Path(__file__).parent
Expand Down Expand Up @@ -103,7 +98,7 @@ class ClassBases:


@attrs.define
class SymbolCollector(NodeVisitor[None]):
class SymbolCollector:
file: mypy.nodes.MypyFile
read_source: Callable[[str], list[str] | None]
all_classes: dict[str, tuple[mypy.nodes.MypyFile, mypy.nodes.ClassDef]]
Expand Down Expand Up @@ -135,7 +130,6 @@ def get_src_from_lines(
lines[0] = lines[0][columns[0] :]
return "\n".join(lines)

@typing.override
def visit_mypy_file(self, o: mypy.nodes.MypyFile) -> None:
for stmt in o.defs:
match stmt:
Expand Down Expand Up @@ -179,7 +173,6 @@ def _get_inlined_class(self, klass: ClassBases) -> str:
)
return "\n".join(src)

@typing.override
def visit_class_def(self, o: mypy.nodes.ClassDef) -> None:
self.all_classes[o.fullname] = self.file, o
class_bases = self._get_bases(o)
Expand All @@ -188,11 +181,9 @@ def visit_class_def(self, o: mypy.nodes.ClassDef) -> None:
else:
self.symbols[o.name] = self.get_src(o)

@typing.override
def visit_func_def(self, o: mypy.nodes.FuncDef) -> None:
self.symbols[o.name] = self.get_src(o)

@typing.override
def visit_overloaded_func_def(self, o: mypy.nodes.OverloadedFuncDef) -> None:
line = o.line
end_line = o.end_line or o.line
Expand All @@ -209,7 +200,6 @@ def visit_overloaded_func_def(self, o: mypy.nodes.OverloadedFuncDef) -> None:

self.symbols[o.name] = src

@typing.override
def visit_assignment_stmt(self, o: mypy.nodes.AssignmentStmt) -> None:
try:
(lvalue,) = o.lvalues
Expand All @@ -224,7 +214,6 @@ def visit_assignment_stmt(self, o: mypy.nodes.AssignmentStmt) -> None:
loc.column = lvalue.end_column
self.symbols[lvalue.name] = self.get_src(loc)

@typing.override
def visit_expression_stmt(self, o: mypy.nodes.ExpressionStmt) -> None:
if isinstance(o.expr, mypy.nodes.StrExpr) and isinstance(
self.last_stmt, mypy.nodes.AssignmentStmt
Expand Down Expand Up @@ -264,7 +253,7 @@ def _get_documented_overload(o: mypy.nodes.OverloadedFuncDef) -> mypy.nodes.Func


@attrs.define
class ImportCollector(NodeVisitor[None]):
class ImportCollector:
collected_imports: dict[str, ModuleImports]

def get_imports(self, module_id: str) -> ModuleImports:
Expand All @@ -274,7 +263,6 @@ def get_imports(self, module_id: str) -> ModuleImports:
imports = self.collected_imports[module_id] = ModuleImports()
return imports

@typing.override
def visit_mypy_file(self, o: mypy.nodes.MypyFile) -> None:
for stmt in o.defs:
match stmt:
Expand All @@ -283,13 +271,11 @@ def visit_mypy_file(self, o: mypy.nodes.MypyFile) -> None:
case mypy.nodes.Import():
self.visit_import(stmt)

@typing.override
def visit_import_from(self, o: mypy.nodes.ImportFrom) -> None:
imports = self.get_imports(o.id)
for name, name_as in o.names:
imports.from_imports[name] = name_as

@typing.override
def visit_import(self, o: mypy.nodes.Import) -> None:
for name, name_as in o.ids:
if name != (name_as or name):
Expand All @@ -300,7 +286,7 @@ def visit_import(self, o: mypy.nodes.Import) -> None:


@attrs.define
class DocStub(NodeVisitor[None]):
class DocStub:
read_source: Callable[[str], list[str] | None]
file: mypy.nodes.MypyFile
modules: dict[str, mypy.nodes.MypyFile]
Expand All @@ -319,7 +305,7 @@ def process_module(cls, manager: mypy.build.BuildManager, module_id: str) -> typ
modules = manager.modules
module: mypy.nodes.MypyFile = modules[module_id]
stub = cls(read_source=read_source, file=module, modules=modules)
module.accept(stub)
stub.visit_mypy_file(module)
stub._remove_inlined_symbols() # noqa: SLF001
return stub

Expand All @@ -334,12 +320,18 @@ def _get_module(self, module_id: str) -> SymbolCollector:
all_classes=self.all_classes,
inlined_protocols=self.inlined_protocols,
)
file.accept(collector)
collector.visit_mypy_file(file)
self._collect_imports(file)
return collector

def _collect_imports(self, o: mypy.nodes.Node) -> None:
o.accept(ImportCollector(self.collected_imports))
import_collector = ImportCollector(self.collected_imports)
if isinstance(o, mypy.nodes.MypyFile):
import_collector.visit_mypy_file(o)
elif isinstance(o, mypy.nodes.ImportFrom):
import_collector.visit_import_from(o)
elif isinstance(o, mypy.nodes.Import):
import_collector.visit_import(o)
self._remove_inlined_symbols()

def _remove_inlined_symbols(self) -> None:
Expand All @@ -358,7 +350,6 @@ def _remove_inlined_symbols(self) -> None:
else:
print(f"Symbol/import collision: from {module} import {name} as {name_as}")

@typing.override
def visit_mypy_file(self, o: mypy.nodes.MypyFile) -> None:
for stmt in o.defs:
match stmt:
Expand All @@ -370,7 +361,6 @@ def visit_mypy_file(self, o: mypy.nodes.MypyFile) -> None:
self.visit_import_from(stmt)
self._add_all_symbols(o.fullname)

@typing.override
def visit_import_from(self, o: mypy.nodes.ImportFrom) -> None:
if not _should_inline_module(o.id):
self._collect_imports(o)
Expand All @@ -386,7 +376,6 @@ def visit_import_from(self, o: mypy.nodes.ImportFrom) -> None:
raise Exception("Aliasing symbols in stubs is not supported")
self.add_symbol(module, name)

@typing.override
def visit_import_all(self, o: mypy.nodes.ImportAll) -> None:
if _should_inline_module(o.id):
self._add_all_symbols(o.id)
Expand All @@ -398,7 +387,6 @@ def _add_all_symbols(self, module_id: str) -> None:
for sym in module.symbols:
self.add_symbol(module, sym)

@typing.override
def visit_import(self, o: mypy.nodes.Import) -> None:
self._collect_imports(o)

Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
TEST_CASES_DIR = VCS_ROOT / "test_cases"
FROM_AWST_DIR = VCS_ROOT / "tests" / "from_awst"
NO_INIT_DIR = VCS_ROOT / "tests" / "no-init"
STUBS_DIR = VCS_ROOT / "stubs" / "algopy-stubs"
69 changes: 24 additions & 45 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,39 @@
import sys
from collections.abc import Mapping

import pytest

from puyapy.awst_build import pytypes
from tests import VCS_ROOT

_STUB_SUFFIX = ".pyi"


def stub_class_names_and_predefined_aliases() -> list[str]:
from mypy import build, find_sources, fscache, nodes

from puyapy.parse import _get_mypy_options

stubs_dir = (VCS_ROOT / "stubs" / "algopy-stubs").resolve()
mypy_options = _get_mypy_options()
mypy_options.python_executable = sys.executable
fs_cache = fscache.FileSystemCache()
mypy_build_sources = find_sources.create_source_list(
paths=[str(stubs_dir)], options=mypy_options, fscache=fs_cache
)
build_result = build.build(sources=mypy_build_sources, options=mypy_options, fscache=fs_cache)
result = set()

algopy_module = build_result.files["algopy"]
modules_to_visit = [algopy_module]
seen_modules = set()
while modules_to_visit:
module = modules_to_visit.pop()
if module in seen_modules:
continue
seen_modules.add(module)
for name, symbol in module.names.items():
if name.startswith("_") or symbol.module_hidden or symbol.kind != nodes.GDEF:
continue
match symbol.node:
case nodes.MypyFile() as new_module:
modules_to_visit.append(new_module)
case nodes.TypeAlias(fullname=alias_name):
result.add(alias_name)
case nodes.TypeInfo(fullname=class_name):
result.add(class_name)
return sorted(result)
from tests.utils.stubs_ast import build_stubs_classes

KNOWN_SYMBOLS_WITHOUT_PYTYPES = [
"algopy.arc4._ABIEncoded",
"algopy.arc4._ABICallProtocolType",
"algopy.arc4._StructMeta",
"algopy.arc4._UIntN",
"algopy.gtxn._GroupTransaction",
"algopy.itxn._InnerTransaction",
"algopy._template_variables._TemplateVarGeneric",
"algopy._transaction._ApplicationProtocol",
"algopy._transaction._AssetConfigProtocol",
"algopy._transaction._AssetFreezeProtocol",
"algopy._transaction._AssetTransferProtocol",
"algopy._transaction._KeyRegistrationProtocol",
"algopy._transaction._PaymentProtocol",
"algopy._transaction._TransactionBaseProtocol",
]


def _stub_class_names_and_predefined_aliases() -> list[str]:
class_nodes = build_stubs_classes()
return [c for c in class_nodes if c not in KNOWN_SYMBOLS_WITHOUT_PYTYPES]


@pytest.fixture(scope="session")
def builtins_registry() -> Mapping[str, pytypes.PyType]:
return pytypes.builtins_registry()


@pytest.mark.parametrize(
"fullname",
stub_class_names_and_predefined_aliases(),
ids=str,
)
@pytest.mark.parametrize("fullname", _stub_class_names_and_predefined_aliases(), ids=str)
def test_stub_class_names_lookup(
builtins_registry: Mapping[str, pytypes.PyType], fullname: str
) -> None:
Expand Down
Loading
Loading