Skip to content

feat: Declare WASM modules in guppy #942

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
SumType,
TupleType,
Type,
WasmModuleType,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -314,6 +315,8 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
type_defn = tuple_type_def
case NoneType():
type_defn = none_type_def
case WasmModuleType() as ty:
type_defn = ty.defn
case _:
return assert_never(ty)

Expand Down
7 changes: 7 additions & 0 deletions guppylang/checker/errors/type_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,10 @@ class DynamicIterator(Note):
"since the number of elements yielded by this iterator is not statically "
"known"
)


@dataclass(frozen=True)
class WasmTypeConversionError(Error):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be in errors/wasm.py?

title: ClassVar[str] = "Can't convert type to WASM"
span_label: ClassVar[str] = "`{thing}` cannot be converted to WASM"
ty: Type
34 changes: 34 additions & 0 deletions guppylang/checker/errors/wasm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import ClassVar

from guppylang.diagnostic import Error
from guppylang.tys.ty import Type


class WasmError(Error):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class WasmError(Error):
@dataclass(frozen=True)
class WasmError(Error):

title: ClassVar[str] = "WASM signature error"


@dataclass(frozen=True)
class FirstArgNotModule(WasmError):
span_label: ClassVar[str] = (
"First argument to WASM function should be a reference to a WASM module"
"Instead, found {ty}"
)
ty: Type


@dataclass(frozen=True)
class UnWasmableType(WasmError):
span_label: ClassVar[str] = (
"WASM function signature contained an unsupported type: {ty}"
)
ty: Type


@dataclass(frozen=True)
class NonFunctionWasmType(WasmError):
span_label: ClassVar[str] = (
"WASM function didn't have a function type, instead found {ty}"
)
ty: Type
96 changes: 94 additions & 2 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,22 @@

import guppylang
from guppylang.ast_util import annotate_location
from guppylang.checker.core import Globals
from guppylang.compiler.core import GlobalConstId
from guppylang.definition.common import DefId
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
CustomCallChecker,
CustomFunctionDef,
CustomInoutCallCompiler,
DefaultCallChecker,
NotImplementedCallCompiler,
OpCompiler,
RawCustomFunctionDef,
WasmCallChecker,
WasmModuleCallCompiler,
WasmModuleDiscardCompiler,
WasmModuleInitCompiler,
)
from guppylang.definition.extern import RawExternDef
from guppylang.definition.function import (
Expand All @@ -38,7 +45,7 @@
)
from guppylang.definition.struct import RawStructDef
from guppylang.definition.traced import RawTracedFunctionDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef, WasmModule
from guppylang.error import MissingModuleError, pretty_errors
from guppylang.ipython_inspect import (
get_ipython_globals,
Expand All @@ -56,9 +63,16 @@
from guppylang.span import Loc, SourceMap, Span
from guppylang.tracing.object import GuppyDefinition
from guppylang.tys.arg import Argument
from guppylang.tys.builtin import option_type
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
NumericType,
)

S = TypeVar("S")
T = TypeVar("T")
Expand Down Expand Up @@ -99,6 +113,7 @@ class _Guppy:
def __init__(self) -> None:
self._modules = {}
self._sources = SourceMap()
self._next_wasm_context = 0

@overload
def __call__(self, arg: F) -> F: ...
Expand Down Expand Up @@ -582,6 +597,77 @@ def load_pytket(
mod.register_def(defn)
return GuppyDefinition(defn)

def wasm_module(
self, filename: str, filehash: int
) -> Decorator[PyClass, GuppyDefinition]:
Comment on lines +600 to +602
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that your decorator doesn't work with explicit modules.

I think that's fine as long as we merge #983 first

# N.B. Only one module per file and vice-versa
guppy_module = self.get_module()
ctx_id = guppy_module._get_next_wasm_context()
assert guppy_module._instance_func_buffer is None
guppy_module._instance_func_buffer = {}

def dec(cls: PyClass) -> GuppyDefinition:
wasm_module = WasmModule(
DefId.fresh(guppy_module),
cls.__name__,
None,
filename,
filehash,
ctx_id,
)
wasm_module_ty = wasm_module.check_instantiate([], Globals.default(), None)
guppy_module.register_def(wasm_module)
# Add a __call__ to the class
call_method = CustomFunctionDef(
DefId.fresh(guppy_module),
"__new__",
None,
FunctionType([], option_type(wasm_module_ty)),
DefaultCallChecker(),
WasmModuleInitCompiler(wasm_module),
False,
GlobalConstId.fresh(f"{cls.__name__}.__new__"),
True,
)
discard = CustomFunctionDef(
DefId.fresh(guppy_module),
"discard",
None,
FunctionType([FuncInput(wasm_module_ty, InputFlags.Owned)], NoneType()),
DefaultCallChecker(),
WasmModuleDiscardCompiler(),
False,
GlobalConstId.fresh(f"{cls.__name__}.__discard__"),
True,
)

assert guppy_module._instance_func_buffer is not None

guppy_module.register_def(wasm_module)
guppy_module._instance_func_buffer |= {
"__new__": call_method,
"discard": discard,
}

guppy_module._register_buffered_instance_funcs(wasm_module)
return GuppyDefinition(wasm_module)

return dec

def wasm(self, f: PyFunc) -> GuppyDefinition:
guppy_module = self.get_module()
func = RawCustomFunctionDef(
DefId.fresh(guppy_module),
f.__name__,
None,
f,
WasmCallChecker(),
WasmModuleCallCompiler(f.__name__),
True,
)
guppy_module.register_def(func)
return GuppyDefinition(func)


class _GuppyDummy:
"""A dummy class with the same interface as `@guppy` that is used during sphinx
Expand Down Expand Up @@ -638,6 +724,12 @@ def load(self, *args: Any, **kwargs: Any) -> None:
def get_module(self, *args: Any, **kwargs: Any) -> Any:
return GuppyModule("dummy", import_builtins=False)

def wasm_module(self, *args: Any, **kwargs: Any) -> Any:
return lambda cls: cls

def wasm(self, *args: Any, **kwargs: Any) -> Any:
return lambda cls: cls


guppy = cast(_Guppy, _GuppyDummy()) if sphinx_running() else _Guppy()

Expand Down
Loading
Loading