Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tiled/_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def test_empty_api_key():
class Dummy:
"Referenced below in test_tree_given_as_method"

def constructor():
@classmethod
def constructor(cls):
return tree


Expand All @@ -318,7 +319,8 @@ def test_tree_given_as_method():
},
]
}
Config.model_validate(config)
conf = Config.model_validate(config)
assert conf.merged_trees == tree


tree.include_routers = [APIRouter()]
Expand Down
17 changes: 11 additions & 6 deletions tiled/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from datetime import timedelta
from functools import cached_property
from pathlib import Path
from typing import Annotated, Any, Iterator, Optional, Union
from typing import Annotated, Any, Callable, Iterator, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator

from tiled.adapters.core import Adapter
from tiled.authenticators import ProxiedOIDCAuthenticator
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator
from tiled.type_aliases import AppTask, TaskMap
Expand Down Expand Up @@ -40,7 +41,10 @@ def sub_paths(segments: tuple[str, ...]) -> Iterator[tuple[str, ...]]:


class TreeSpec(BaseModel):
tree_type: Annotated[EntryPointString, Field(alias="tree")]
tree_type: Annotated[
EntryPointString[Union[Adapter[Any], Callable[..., Adapter[Any]]]],
Field(alias="tree"),
]
path: str
args: Optional[dict[str, Any]] = None

Expand Down Expand Up @@ -69,13 +73,13 @@ def segments(self) -> tuple[str, ...]:
return tuple(segment for segment in self.path.split("/") if segment)

@cached_property
def tree(self) -> Any:
def tree(self) -> Adapter[Any]:
if callable(self.tree_type):
return self.tree_type(**self.args or {})
return self.tree_type

@property
def tree_entry(self) -> tuple[tuple[str, ...], Any]:
def tree_entry(self) -> tuple[tuple[str, ...], Adapter[Any]]:
return (self.segments, self.tree)

@field_validator("tree_type", mode="before")
Expand Down Expand Up @@ -259,7 +263,7 @@ def root_path(self) -> str:
return self.uvicorn.get("root_path") or ""

@cached_property
def merged_trees(self) -> Any: # TODO: update when # 1047 is merged
def merged_trees(self) -> Adapter[Any]:
trees = dict(tree.tree_entry for tree in self.trees)
if list(trees) == [()]:
# Simple case: there is one tree, served at the root path /.
Expand All @@ -268,7 +272,8 @@ def merged_trees(self) -> Any: # TODO: update when # 1047 is merged
# There are one or more tree(s) to be served at sub-paths so merge
# them into one root MapAdapter with map path segments to dicts
# containing Adapters at that path.
root_mapping = trees.pop((), {})
# As trees do not overlap, there is no '()' entry in trees so use empty dict
root_mapping = {}
index: dict[tuple[str, ...], dict] = {(): root_mapping}
all_routers = []

Expand Down
6 changes: 4 additions & 2 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
HTTP_500_INTERNAL_SERVER_ERROR,
)

from tiled.adapters.core import Adapter

from ..access_control.protocols import AccessPolicy
from ..authenticators import ProxiedOIDCAuthenticator
from ..catalog.adapter import WouldDeleteData
Expand Down Expand Up @@ -115,7 +117,7 @@ def custom_openapi(app):


def build_app(
tree,
tree: Adapter[Any],
authentication: Optional[Authentication] = None,
server_settings=None,
query_registry: Optional[QueryRegistry] = None,
Expand All @@ -132,7 +134,7 @@ def build_app(
Parameters
----------
tree : Tree
tree : Adapter[Any]
authentication: dict, optional
Dict of authentication configuration.
server_settings: dict, optional
Expand Down
11 changes: 5 additions & 6 deletions tiled/server/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import List, Optional
from typing import Any, List, Optional

import pydantic_settings
from fastapi import HTTPException, Query, Request
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_410_GONE

from ..access_control.protocols import AccessPolicy
from ..adapters.protocols import AnyAdapter
from ..adapters.core import Adapter
from ..structures.core import StructureFamily
from ..type_aliases import AccessTags, Scopes
from ..utils import BrokenLink
Expand All @@ -14,7 +13,7 @@
from .utils import filter_for_access, record_timing


def get_root_tree(request: Request):
def get_root_tree(request: Request) -> Adapter[Any]:
return request.app.state.root_tree


Expand All @@ -24,12 +23,12 @@ async def get_entry(
principal: Optional[Principal],
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
root_tree: pydantic_settings.BaseSettings,
root_tree: Adapter[Any],
session_state: dict,
metrics: dict,
structure_families: Optional[set[StructureFamily]] = None,
access_policy: Optional[AccessPolicy] = None,
) -> AnyAdapter:
) -> Adapter[Any]:
"""
Obtain a node in the tree from its path.
Expand Down
13 changes: 6 additions & 7 deletions tiled/server/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from datetime import datetime, timedelta, timezone
from functools import cache, partial
from pathlib import Path
from typing import Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, List, Optional, TypeVar, Union

import anyio
import packaging
import pydantic_settings
from fastapi import (
APIRouter,
Body,
Expand Down Expand Up @@ -41,7 +40,7 @@
HTTP_422_UNPROCESSABLE_CONTENT,
)

from tiled.adapters.protocols import AnyAdapter
from tiled.adapters.core import Adapter
from tiled.authenticators import ProxiedOIDCAuthenticator
from tiled.media_type_registration import SerializationRegistry
from tiled.query_registration import QueryRegistry
Expand Down Expand Up @@ -683,7 +682,7 @@ async def close_stream(
request: Request,
path: str,
principal: Optional[schemas.Principal] = Depends(get_current_principal),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
Expand Down Expand Up @@ -862,7 +861,7 @@ async def post_table_partition(
async def table_partition(
request: Request,
partition: int,
entry: AnyAdapter,
entry: Adapter[Any],
column: Optional[List[str]],
format: Optional[str],
filename: Optional[str],
Expand Down Expand Up @@ -998,7 +997,7 @@ async def post_table_full(

async def table_full(
request: Request,
entry: AnyAdapter,
entry: Adapter[Any],
column: Optional[List[str]],
format: Optional[str],
filename: Optional[str],
Expand Down Expand Up @@ -2461,7 +2460,7 @@ async def get_asset_manifest(
async def validate_specs(
specs: List[Spec],
metadata: dict,
entry: Optional[AnyAdapter] = None,
entry: Optional[Adapter[Any]] = None,
structure_family: Optional[StructureFamily] = None,
structure: Optional[dict] = None,
settings: Settings = Depends(get_settings),
Expand Down
19 changes: 10 additions & 9 deletions tiled/server/zarr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
import re
from typing import Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

import numcodecs
import orjson
import pydantic_settings
from fastapi import APIRouter, Depends, HTTPException, Request
from starlette.responses import Response
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR

from tiled.adapters.core import Adapter

from ..structures.core import StructureFamily
from ..type_aliases import AccessTags, Scopes
from ..utils import ensure_awaitable
Expand Down Expand Up @@ -58,7 +59,7 @@ async def get_zarr_attrs(
principal: Union[Principal] = Depends(get_current_principal),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
):
"Return entry.metadata as Zarr attributes metadata (.zattrs)"
Expand Down Expand Up @@ -94,7 +95,7 @@ async def get_zarr_group_metadata(
principal: Union[Principal] = Depends(get_current_principal),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
):
await get_entry(
Expand Down Expand Up @@ -122,7 +123,7 @@ async def get_zarr_array_metadata(
principal: Union[Principal] = Depends(get_current_principal),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
):
entry = await get_entry(
Expand Down Expand Up @@ -166,7 +167,7 @@ async def get_zarr_array(
principal: Union[Principal] = Depends(get_current_principal),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
):
# If a zarr block is requested, e.g. http://localhost:8000/zarr/v2/array/0.1.2.3,
Expand Down Expand Up @@ -285,7 +286,7 @@ async def get_zarr_metadata(
principal: Union[Principal] = Depends(get_current_principal),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
):
from zarr.dtype import parse_data_type
Expand Down Expand Up @@ -377,7 +378,7 @@ async def get_zarr_array(
principal: Union[Principal] = Depends(get_current_principal),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
):
entry = await get_entry(
Expand Down Expand Up @@ -458,7 +459,7 @@ async def get_zarr_group(
principal: Union[Principal] = Depends(get_current_principal),
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
authn_scopes: Scopes = Depends(get_current_scopes),
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
root_tree: Adapter[Any] = Depends(get_root_tree),
session_state: dict = Depends(get_session_state),
):
entry = await get_entry(
Expand Down
50 changes: 44 additions & 6 deletions tiled/type_aliases.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import sys

from pydantic import AfterValidator

if sys.version_info < (3, 10):
EllipsisType = type(Ellipsis)
else:
from types import EllipsisType

from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Expand All @@ -17,6 +16,7 @@
Sequence,
Set,
TypedDict,
TypeVar,
Union,
)

Expand All @@ -43,10 +43,48 @@ class TaskMap(TypedDict):
shutdown: list[AppTask]


EntryPointString = Annotated[
str,
AfterValidator(import_object),
]
if TYPE_CHECKING:
# Let type checking treat this as just the underlying type
T = TypeVar("T")
EntryPointString = Annotated[T, ...]
else:
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic.types import AnyType
from pydantic_core import CoreSchema, core_schema

class EntryPointString:
"""
Version of Pydantic's ImportString that supports importing fields of
imported items, not just top level attributes
A string such as `path.to.module:Type.field` is equivalent to
```
from path.to.module import type
return type.field
```
"""

@classmethod
def __class_getitem__(cls, item: AnyType):
return Annotated[item, cls()]

@classmethod
def __get_pydantic_core_schema__(
cls, source: type[Any], handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(function=import_object)

@classmethod
def __get_pydantic_json_schema__(
cls, cs: CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.str_schema())

def __repr__(self) -> str:
return "EntryPointString"


__all__ = [
Expand Down