diff --git a/tiled/_tests/test_config.py b/tiled/_tests/test_config.py index b138580fe..4190d759a 100644 --- a/tiled/_tests/test_config.py +++ b/tiled/_tests/test_config.py @@ -305,7 +305,8 @@ def test_empty_api_key(): class Dummy: "Referenced below in test_tree_given_as_method" - def constructor(): + @classmethod + def constructor(cls): return tree @@ -318,7 +319,8 @@ def test_tree_given_as_method(): }, ] } - Config.model_validate(config) + conf = Config.model_validate(config) + assert conf.merged_trees == tree tree.include_routers = [APIRouter()] diff --git a/tiled/config.py b/tiled/config.py index 2d33b410b..ca15e19de 100644 --- a/tiled/config.py +++ b/tiled/config.py @@ -8,10 +8,11 @@ from datetime import timedelta from functools import cached_property from pathlib import Path -from typing import Annotated, Any, Iterator, Optional, Union +from typing import Annotated, Any, Callable, Iterator, Optional, Union from pydantic import BaseModel, Field, field_validator, model_validator +from tiled.adapters.core import Adapter from tiled.authenticators import ProxiedOIDCAuthenticator from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator from tiled.type_aliases import AppTask, TaskMap @@ -40,7 +41,10 @@ def sub_paths(segments: tuple[str, ...]) -> Iterator[tuple[str, ...]]: class TreeSpec(BaseModel): - tree_type: Annotated[EntryPointString, Field(alias="tree")] + tree_type: Annotated[ + EntryPointString[Union[Adapter[Any], Callable[..., Adapter[Any]]]], + Field(alias="tree"), + ] path: str args: Optional[dict[str, Any]] = None @@ -69,13 +73,13 @@ def segments(self) -> tuple[str, ...]: return tuple(segment for segment in self.path.split("/") if segment) @cached_property - def tree(self) -> Any: + def tree(self) -> Adapter[Any]: if callable(self.tree_type): return self.tree_type(**self.args or {}) return self.tree_type @property - def tree_entry(self) -> tuple[tuple[str, ...], Any]: + def tree_entry(self) -> tuple[tuple[str, ...], Adapter[Any]]: return (self.segments, self.tree) @field_validator("tree_type", mode="before") @@ -259,7 +263,7 @@ def root_path(self) -> str: return self.uvicorn.get("root_path") or "" @cached_property - def merged_trees(self) -> Any: # TODO: update when # 1047 is merged + def merged_trees(self) -> Adapter[Any]: trees = dict(tree.tree_entry for tree in self.trees) if list(trees) == [()]: # Simple case: there is one tree, served at the root path /. @@ -268,7 +272,8 @@ def merged_trees(self) -> Any: # TODO: update when # 1047 is merged # There are one or more tree(s) to be served at sub-paths so merge # them into one root MapAdapter with map path segments to dicts # containing Adapters at that path. - root_mapping = trees.pop((), {}) + # As trees do not overlap, there is no '()' entry in trees so use empty dict + root_mapping = {} index: dict[tuple[str, ...], dict] = {(): root_mapping} all_routers = [] diff --git a/tiled/server/app.py b/tiled/server/app.py index 1cb605d1f..a7f1654ec 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -36,6 +36,8 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) +from tiled.adapters.core import Adapter + from ..access_control.protocols import AccessPolicy from ..authenticators import ProxiedOIDCAuthenticator from ..catalog.adapter import WouldDeleteData @@ -115,7 +117,7 @@ def custom_openapi(app): def build_app( - tree, + tree: Adapter[Any], authentication: Optional[Authentication] = None, server_settings=None, query_registry: Optional[QueryRegistry] = None, @@ -132,7 +134,7 @@ def build_app( Parameters ---------- - tree : Tree + tree : Adapter[Any] authentication: dict, optional Dict of authentication configuration. server_settings: dict, optional diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index 81f9b8a53..968681be1 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -1,11 +1,10 @@ -from typing import List, Optional +from typing import Any, List, Optional -import pydantic_settings from fastapi import HTTPException, Query, Request from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_410_GONE from ..access_control.protocols import AccessPolicy -from ..adapters.protocols import AnyAdapter +from ..adapters.core import Adapter from ..structures.core import StructureFamily from ..type_aliases import AccessTags, Scopes from ..utils import BrokenLink @@ -14,7 +13,7 @@ from .utils import filter_for_access, record_timing -def get_root_tree(request: Request): +def get_root_tree(request: Request) -> Adapter[Any]: return request.app.state.root_tree @@ -24,12 +23,12 @@ async def get_entry( principal: Optional[Principal], authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, - root_tree: pydantic_settings.BaseSettings, + root_tree: Adapter[Any], session_state: dict, metrics: dict, structure_families: Optional[set[StructureFamily]] = None, access_policy: Optional[AccessPolicy] = None, -) -> AnyAdapter: +) -> Adapter[Any]: """ Obtain a node in the tree from its path. diff --git a/tiled/server/router.py b/tiled/server/router.py index ef0266c4c..c0775bc61 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -8,11 +8,10 @@ from datetime import datetime, timedelta, timezone from functools import cache, partial from pathlib import Path -from typing import Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar, Union import anyio import packaging -import pydantic_settings from fastapi import ( APIRouter, Body, @@ -41,7 +40,7 @@ HTTP_422_UNPROCESSABLE_CONTENT, ) -from tiled.adapters.protocols import AnyAdapter +from tiled.adapters.core import Adapter from tiled.authenticators import ProxiedOIDCAuthenticator from tiled.media_type_registration import SerializationRegistry from tiled.query_registration import QueryRegistry @@ -683,7 +682,7 @@ async def close_stream( request: Request, path: str, principal: Optional[schemas.Principal] = Depends(get_current_principal), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), @@ -862,7 +861,7 @@ async def post_table_partition( async def table_partition( request: Request, partition: int, - entry: AnyAdapter, + entry: Adapter[Any], column: Optional[List[str]], format: Optional[str], filename: Optional[str], @@ -998,7 +997,7 @@ async def post_table_full( async def table_full( request: Request, - entry: AnyAdapter, + entry: Adapter[Any], column: Optional[List[str]], format: Optional[str], filename: Optional[str], @@ -2461,7 +2460,7 @@ async def get_asset_manifest( async def validate_specs( specs: List[Spec], metadata: dict, - entry: Optional[AnyAdapter] = None, + entry: Optional[Adapter[Any]] = None, structure_family: Optional[StructureFamily] = None, structure: Optional[dict] = None, settings: Settings = Depends(get_settings), diff --git a/tiled/server/zarr.py b/tiled/server/zarr.py index 85dcd4f70..caa6b0011 100644 --- a/tiled/server/zarr.py +++ b/tiled/server/zarr.py @@ -1,14 +1,15 @@ import json import re -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import numcodecs import orjson -import pydantic_settings from fastapi import APIRouter, Depends, HTTPException, Request from starlette.responses import Response from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR +from tiled.adapters.core import Adapter + from ..structures.core import StructureFamily from ..type_aliases import AccessTags, Scopes from ..utils import ensure_awaitable @@ -58,7 +59,7 @@ async def get_zarr_attrs( principal: Union[Principal] = Depends(get_current_principal), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), ): "Return entry.metadata as Zarr attributes metadata (.zattrs)" @@ -94,7 +95,7 @@ async def get_zarr_group_metadata( principal: Union[Principal] = Depends(get_current_principal), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), ): await get_entry( @@ -122,7 +123,7 @@ async def get_zarr_array_metadata( principal: Union[Principal] = Depends(get_current_principal), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), ): entry = await get_entry( @@ -166,7 +167,7 @@ async def get_zarr_array( principal: Union[Principal] = Depends(get_current_principal), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), ): # If a zarr block is requested, e.g. http://localhost:8000/zarr/v2/array/0.1.2.3, @@ -285,7 +286,7 @@ async def get_zarr_metadata( principal: Union[Principal] = Depends(get_current_principal), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), ): from zarr.dtype import parse_data_type @@ -377,7 +378,7 @@ async def get_zarr_array( principal: Union[Principal] = Depends(get_current_principal), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), ): entry = await get_entry( @@ -458,7 +459,7 @@ async def get_zarr_group( principal: Union[Principal] = Depends(get_current_principal), authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), - root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + root_tree: Adapter[Any] = Depends(get_root_tree), session_state: dict = Depends(get_session_state), ): entry = await get_entry( diff --git a/tiled/type_aliases.py b/tiled/type_aliases.py index 0f98e764c..b8f06bcee 100644 --- a/tiled/type_aliases.py +++ b/tiled/type_aliases.py @@ -1,13 +1,12 @@ import sys -from pydantic import AfterValidator - if sys.version_info < (3, 10): EllipsisType = type(Ellipsis) else: from types import EllipsisType from typing import ( + TYPE_CHECKING, Annotated, Any, Callable, @@ -17,6 +16,7 @@ Sequence, Set, TypedDict, + TypeVar, Union, ) @@ -43,10 +43,48 @@ class TaskMap(TypedDict): shutdown: list[AppTask] -EntryPointString = Annotated[ - str, - AfterValidator(import_object), -] +if TYPE_CHECKING: + # Let type checking treat this as just the underlying type + T = TypeVar("T") + EntryPointString = Annotated[T, ...] +else: + from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler + from pydantic.json_schema import JsonSchemaValue + from pydantic.types import AnyType + from pydantic_core import CoreSchema, core_schema + + class EntryPointString: + """ + Version of Pydantic's ImportString that supports importing fields of + imported items, not just top level attributes + + A string such as `path.to.module:Type.field` is equivalent to + + ``` + from path.to.module import type + return type.field + ``` + + """ + + @classmethod + def __class_getitem__(cls, item: AnyType): + return Annotated[item, cls()] + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function(function=import_object) + + @classmethod + def __get_pydantic_json_schema__( + cls, cs: CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.str_schema()) + + def __repr__(self) -> str: + return "EntryPointString" __all__ = [