Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic typed linkedobj #9187

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ class UserCodeStatusCollectionV1(SyncableSyftObject):

# this is empty in the case of l0
status_dict: dict[ServerIdentity, tuple[UserCodeStatus, str]] = {}

user_code_link: LinkedObject
user_code_link: LinkedObject[UserCode]


@serializable()
Expand Down Expand Up @@ -434,7 +433,7 @@ class UserCodeV1(SyncableSyftObject):
user_unique_func_name: str
code_hash: str
signature: inspect.Signature
status_link: LinkedObject | None = None
status_link: LinkedObject[UserCodeStatusCollection] | None = None
input_kwargs: list[str]
submit_time: DateTime | None = None
# tracks if the code calls datasite.something, variable is set during parsing
Expand Down
11 changes: 9 additions & 2 deletions packages/syft/src/syft/service/output/output_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
from typing import ClassVar
from typing import TYPE_CHECKING

# third party
from pydantic import model_validator
Expand All @@ -26,15 +27,21 @@
from ..user.user_roles import GUEST_ROLE_LEVEL


if TYPE_CHECKING:
# relative
from ..code.user_code import UserCode
from ..job.job_stash import Job


@serializable()
class ExecutionOutput(SyncableSyftObject):
__canonical_name__ = "ExecutionOutput"
__version__ = SYFT_OBJECT_VERSION_1

executing_user_verify_key: SyftVerifyKey
user_code_link: LinkedObject
user_code_link: "LinkedObject[UserCode]"
output_ids: list[UID] | dict[str, UID] | None = None
job_link: LinkedObject | None = None
job_link: "LinkedObject[Job] | None" = None
created_at: DateTime = DateTime.now()
input_ids: dict[str, UID] | None = None

Expand Down
10 changes: 6 additions & 4 deletions packages/syft/src/syft/service/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,14 @@ class ProjectRequest(ProjectEventAddObject):
__canonical_name__ = "ProjectRequest"
__version__ = SYFT_OBJECT_VERSION_1

linked_request: LinkedObject
linked_request: LinkedObject[Request]
allowed_sub_types: list[type] = [ProjectRequestResponse]

@field_validator("linked_request", mode="before")
@classmethod
def _validate_linked_request(cls, v: Any) -> LinkedObject:
def _validate_linked_request(cls, v: Any) -> LinkedObject[Request]:
if isinstance(v, Request):
linked_request = LinkedObject.from_obj(v, server_uid=v.server_uid)
linked_request = LinkedObject[Request].from_obj(v, server_uid=v.server_uid)
linked_request.syft_server_location = v.syft_server_location
return linked_request
elif isinstance(v, LinkedObject):
Expand Down Expand Up @@ -1028,7 +1028,9 @@ def add_request(
self,
request: Request,
) -> SyftSuccess:
linked_request = LinkedObject.from_obj(request, server_uid=request.server_uid)
linked_request = LinkedObject[Request].from_obj(
request, server_uid=request.server_uid
)
request_event = ProjectRequest(linked_request=linked_request)
self.add_event(request_event)
return SyftSuccess(message="Request created successfully")
Expand Down
3 changes: 2 additions & 1 deletion packages/syft/src/syft/service/queue/queue_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ...types.transforms import TransformContext
from ...types.uid import UID
from ..action.action_permissions import ActionObjectPermission
from ..worker.worker_pool import WorkerPool

__all__ = ["QueueItem"]

Expand Down Expand Up @@ -77,7 +78,7 @@ class QueueItem(SyftObject):
job_id: UID | None = None
worker_settings: WorkerSettings | None = None
has_execute_permissions: bool = False
worker_pool: LinkedObject
worker_pool: LinkedObject[WorkerPool]

def __repr__(self) -> str:
return f"<QueueItem: {self.id}>: {self.status}"
Expand Down
8 changes: 4 additions & 4 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class ActionStoreChange(Change):
__canonical_name__ = "ActionStoreChange"
__version__ = SYFT_OBJECT_VERSION_1

linked_obj: LinkedObject
linked_obj: LinkedObject[ActionObject]
apply_permission_type: ActionPermission

__repr_attrs__ = ["linked_obj", "apply_permission_type"]
Expand Down Expand Up @@ -1370,8 +1370,8 @@ class UserCodeStatusChange(Change):
__version__ = SYFT_OBJECT_VERSION_1

value: UserCodeStatus
linked_obj: LinkedObject
linked_user_code: LinkedObject
linked_obj: LinkedObject[UserCodeStatusCollection]
linked_user_code: LinkedObject[UserCode]
nested_solved: bool = False
match_type: bool = True
__repr_attrs__ = [
Expand Down Expand Up @@ -1523,7 +1523,7 @@ def link(self) -> SyftObject | None:
class SyncedUserCodeStatusChange(UserCodeStatusChange):
__canonical_name__ = "SyncedUserCodeStatusChange"
__version__ = SYFT_OBJECT_VERSION_1
linked_obj: LinkedObject | None = None # type: ignore
linked_obj: LinkedObject[UserCodeStatusCollection] | None = None # type: ignore

@property
def approved(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/sync/sync_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class SyncState(SyftObject):
objects: dict[UID, SyncableSyftObject] = {}
dependencies: dict[UID, list[UID]] = {}
created_at: DateTime = Field(default_factory=DateTime.now)
previous_state_link: LinkedObject | None = None
previous_state_link: "LinkedObject[SyncState] | None" = None
permissions: dict[UID, set[str]] = {}
storage_permissions: dict[UID, set[UID]] = {}
ignored_batches: dict[UID, int] = {}
Expand Down
109 changes: 98 additions & 11 deletions packages/syft/src/syft/store/linked_obj.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# stdlib
import logging
from typing import Any
from typing import Generic
from typing import TypeVar
from typing import Union
from typing import get_args

# third party
from typing_extensions import Self
Expand All @@ -15,19 +19,21 @@
from ..types.result import as_result
from ..types.syft_object import SYFT_OBJECT_VERSION_1
from ..types.syft_object import SyftObject
from ..types.syft_object import SyftObjectVersioned
from ..types.uid import UID

T = TypeVar("T", bound=SyftObject)
logger = logging.getLogger(__name__)


@serializable()
class LinkedObject(SyftObject):
class LinkedObject(SyftObjectVersioned, Generic[T]):
__canonical_name__ = "LinkedObject"
__version__ = SYFT_OBJECT_VERSION_1

server_uid: UID
service_type: type[Any]
object_type: type[SyftObject]
object_type: type[T]
object_uid: UID

_resolve_cache: SyftObject | None = None
Expand All @@ -40,6 +46,15 @@ def __str__(self) -> str:
)
return f"{resolved_obj_type.__name__}: {self.object_uid} @ Server {self.server_uid}"

@classmethod
def get_generic_type(cls: type[Self]) -> type[T]:
args = cls.__pydantic_generic_metadata__["args"]
if len(args) != 1:
raise ValueError(
"Cannot infer LinkedObject type, generic argument not provided"
)
return args[0] # type: ignore

@property
def resolve(self) -> SyftObject:
return self._resolve()
Expand Down Expand Up @@ -105,10 +120,10 @@ def update_with_context(
@classmethod
def from_obj(
cls,
obj: SyftObject | type[SyftObject],
obj: T | type[T],
service_type: type[Any] | None = None,
server_uid: UID | None = None,
) -> Self:
) -> "LinkedObject[T]": # type: ignore
if service_type is None:
# relative
from ..service.action.action_object import ActionObject
Expand All @@ -129,7 +144,7 @@ def from_obj(
if server_uid is None:
raise Exception(f"{cls} Requires an object UID")

return LinkedObject(
return LinkedObject[type(obj)]( # type: ignore
server_uid=server_uid,
service_type=service_type,
object_type=type(obj),
Expand All @@ -140,11 +155,11 @@ def from_obj(
@classmethod
def with_context(
cls,
obj: SyftObject,
obj: T,
context: ServerServiceContext,
object_uid: UID | None = None,
service_type: type[Any] | None = None,
) -> Self:
) -> "LinkedObject[T]":
if service_type is None:
# relative
from ..service.service import TYPE_TO_SERVICE
Expand All @@ -160,7 +175,7 @@ def with_context(
raise ValueError(f"context {context}'s server is None")
server_uid = context.server.id

return LinkedObject(
return LinkedObject[type(obj)]( # type: ignore
server_uid=server_uid,
service_type=service_type,
object_type=type(obj),
Expand All @@ -171,13 +186,85 @@ def with_context(
def from_uid(
cls,
object_uid: UID,
object_type: type[SyftObject],
object_type: type[T],
service_type: type[Any],
server_uid: UID,
) -> Self:
return cls(
) -> "LinkedObject[T]":
return cls[object_type]( # type: ignore
server_uid=server_uid,
service_type=service_type,
object_type=object_type,
object_uid=object_uid,
)


def _unwrap_optional(type_: Any) -> Any:
try:
if type_ | None == type_:
args = get_args(type_)
return Union[tuple(arg for arg in args if arg != type(None))] # noqa
return type_
except Exception:
return type_


def _annotation_issubclass(type_: Any, cls: type) -> bool:
try:
return issubclass(type_, cls)
except Exception:
return False


def _resolve_syftobject_forward_refs(raise_errors: bool = False) -> None:
# relative
from ..types.syft_object_registry import SyftObjectRegistry

type_names = [
t.__name__ for t in SyftObjectRegistry.__type_to_canonical_name__.keys()
]
if len(type_names) != len(set(type_names)):
raise ValueError(
"Duplicate names in SyftObjectRegistry, cannot resolve forward references"
)

types_namespace = {
k.__name__: k for k in SyftObjectRegistry.__type_to_canonical_name__.keys()
}
syft_objects = [v for v in types_namespace.values() if issubclass(v, SyftObject)]

for so in syft_objects:
so.model_rebuild(raise_errors=raise_errors, _types_namespace=types_namespace)


def find_unannotated_linked_objects() -> None:
# Utility method to find LinkedObjects that are not annotated with a generic type

# relative
from ..types.syft_object_registry import SyftObjectRegistry

# Need to resolve forward references to find LinkedObjects
_resolve_syftobject_forward_refs()

annotated = []
unannotated = []

for cls in SyftObjectRegistry.__type_to_canonical_name__.keys():
if not issubclass(cls, SyftObject):
continue

for name, field in cls.model_fields.items():
type_ = _unwrap_optional(field.annotation)
if _annotation_issubclass(type_, LinkedObject):
try:
type_.get_generic_type()
annotated.append((cls, name))
except Exception:
unannotated.append((cls, name))

print("Annotated LinkedObjects:")
for cls, name in annotated:
print(f"{cls.__name__}.{name}")

print("\n\nUnannotated LinkedObjects:")
for cls, name in unannotated:
print(f"{cls.__name__}.{name}")
Loading