Skip to content

Commit a50d1cd

Browse files
committed
Use Adapter[Any] as type of root_tree
1 parent 98466e2 commit a50d1cd

File tree

7 files changed

+84
-38
lines changed

7 files changed

+84
-38
lines changed

tiled/_tests/test_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,8 @@ def test_empty_api_key():
305305
class Dummy:
306306
"Referenced below in test_tree_given_as_method"
307307

308-
def constructor():
308+
@classmethod
309+
def constructor(cls):
309310
return tree
310311

311312

@@ -318,7 +319,8 @@ def test_tree_given_as_method():
318319
},
319320
]
320321
}
321-
Config.model_validate(config)
322+
conf = Config.model_validate(config)
323+
assert conf.merged_trees == tree
322324

323325

324326
tree.include_routers = [APIRouter()]

tiled/config.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from datetime import timedelta
99
from functools import cached_property
1010
from pathlib import Path
11-
from typing import Annotated, Any, Iterator, Optional, Union
11+
from typing import Annotated, Any, Callable, Iterator, Optional, Union
1212

1313
from pydantic import BaseModel, Field, field_validator, model_validator
1414

15+
from tiled.adapters.core import Adapter
1516
from tiled.authenticators import ProxiedOIDCAuthenticator
1617
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator
1718
from tiled.type_aliases import AppTask, TaskMap
@@ -40,7 +41,10 @@ def sub_paths(segments: tuple[str, ...]) -> Iterator[tuple[str, ...]]:
4041

4142

4243
class TreeSpec(BaseModel):
43-
tree_type: Annotated[EntryPointString, Field(alias="tree")]
44+
tree_type: Annotated[
45+
EntryPointString[Union[Adapter[Any], Callable[..., Adapter[Any]]]],
46+
Field(alias="tree"),
47+
]
4448
path: str
4549
args: Optional[dict[str, Any]] = None
4650

@@ -69,13 +73,13 @@ def segments(self) -> tuple[str, ...]:
6973
return tuple(segment for segment in self.path.split("/") if segment)
7074

7175
@cached_property
72-
def tree(self) -> Any:
76+
def tree(self) -> Adapter[Any]:
7377
if callable(self.tree_type):
7478
return self.tree_type(**self.args or {})
7579
return self.tree_type
7680

7781
@property
78-
def tree_entry(self) -> tuple[tuple[str, ...], Any]:
82+
def tree_entry(self) -> tuple[tuple[str, ...], Adapter[Any]]:
7983
return (self.segments, self.tree)
8084

8185
@field_validator("tree_type", mode="before")
@@ -259,7 +263,7 @@ def root_path(self) -> str:
259263
return self.uvicorn.get("root_path") or ""
260264

261265
@cached_property
262-
def merged_trees(self) -> Any: # TODO: update when # 1047 is merged
266+
def merged_trees(self) -> Adapter[Any]:
263267
trees = dict(tree.tree_entry for tree in self.trees)
264268
if list(trees) == [()]:
265269
# Simple case: there is one tree, served at the root path /.
@@ -268,7 +272,8 @@ def merged_trees(self) -> Any: # TODO: update when # 1047 is merged
268272
# There are one or more tree(s) to be served at sub-paths so merge
269273
# them into one root MapAdapter with map path segments to dicts
270274
# containing Adapters at that path.
271-
root_mapping = trees.pop((), {})
275+
# As trees do not overlap, there is no '()' entry in trees so use empty dict
276+
root_mapping = {}
272277
index: dict[tuple[str, ...], dict] = {(): root_mapping}
273278
all_routers = []
274279

tiled/server/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
HTTP_500_INTERNAL_SERVER_ERROR,
3737
)
3838

39+
from tiled.adapters.core import Adapter
40+
3941
from ..access_control.protocols import AccessPolicy
4042
from ..authenticators import ProxiedOIDCAuthenticator
4143
from ..catalog.adapter import WouldDeleteData
@@ -115,7 +117,7 @@ def custom_openapi(app):
115117

116118

117119
def build_app(
118-
tree,
120+
tree: Adapter[Any],
119121
authentication: Optional[Authentication] = None,
120122
server_settings=None,
121123
query_registry: Optional[QueryRegistry] = None,
@@ -132,7 +134,7 @@ def build_app(
132134
133135
Parameters
134136
----------
135-
tree : Tree
137+
tree : Adapter[Any]
136138
authentication: dict, optional
137139
Dict of authentication configuration.
138140
server_settings: dict, optional

tiled/server/dependencies.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from typing import List, Optional
1+
from typing import Any, List, Optional
22

3-
import pydantic_settings
43
from fastapi import HTTPException, Query, Request
54
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_410_GONE
65

76
from ..access_control.protocols import AccessPolicy
8-
from ..adapters.protocols import AnyAdapter
7+
from ..adapters.core import Adapter
98
from ..structures.core import StructureFamily
109
from ..type_aliases import AccessTags, Scopes
1110
from ..utils import BrokenLink
@@ -14,7 +13,7 @@
1413
from .utils import filter_for_access, record_timing
1514

1615

17-
def get_root_tree(request: Request):
16+
def get_root_tree(request: Request) -> Adapter[Any]:
1817
return request.app.state.root_tree
1918

2019

@@ -24,12 +23,12 @@ async def get_entry(
2423
principal: Optional[Principal],
2524
authn_access_tags: Optional[AccessTags],
2625
authn_scopes: Scopes,
27-
root_tree: pydantic_settings.BaseSettings,
26+
root_tree: Adapter[Any],
2827
session_state: dict,
2928
metrics: dict,
3029
structure_families: Optional[set[StructureFamily]] = None,
3130
access_policy: Optional[AccessPolicy] = None,
32-
) -> AnyAdapter:
31+
) -> Adapter[Any]:
3332
"""
3433
Obtain a node in the tree from its path.
3534

tiled/server/router.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
from datetime import datetime, timedelta, timezone
99
from functools import cache, partial
1010
from pathlib import Path
11-
from typing import Callable, List, Optional, TypeVar, Union
11+
from typing import Any, Callable, List, Optional, TypeVar, Union
1212

1313
import anyio
1414
import packaging
15-
import pydantic_settings
1615
from fastapi import (
1716
APIRouter,
1817
Body,
@@ -41,7 +40,7 @@
4140
HTTP_422_UNPROCESSABLE_CONTENT,
4241
)
4342

44-
from tiled.adapters.protocols import AnyAdapter
43+
from tiled.adapters.core import Adapter
4544
from tiled.authenticators import ProxiedOIDCAuthenticator
4645
from tiled.media_type_registration import SerializationRegistry
4746
from tiled.query_registration import QueryRegistry
@@ -683,7 +682,7 @@ async def close_stream(
683682
request: Request,
684683
path: str,
685684
principal: Optional[schemas.Principal] = Depends(get_current_principal),
686-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
685+
root_tree: Adapter[Any] = Depends(get_root_tree),
687686
session_state: dict = Depends(get_session_state),
688687
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
689688
authn_scopes: Scopes = Depends(get_current_scopes),
@@ -862,7 +861,7 @@ async def post_table_partition(
862861
async def table_partition(
863862
request: Request,
864863
partition: int,
865-
entry: AnyAdapter,
864+
entry: Adapter[Any],
866865
column: Optional[List[str]],
867866
format: Optional[str],
868867
filename: Optional[str],
@@ -998,7 +997,7 @@ async def post_table_full(
998997

999998
async def table_full(
1000999
request: Request,
1001-
entry: AnyAdapter,
1000+
entry: Adapter[Any],
10021001
column: Optional[List[str]],
10031002
format: Optional[str],
10041003
filename: Optional[str],
@@ -2461,7 +2460,7 @@ async def get_asset_manifest(
24612460
async def validate_specs(
24622461
specs: List[Spec],
24632462
metadata: dict,
2464-
entry: Optional[AnyAdapter] = None,
2463+
entry: Optional[Adapter[Any]] = None,
24652464
structure_family: Optional[StructureFamily] = None,
24662465
structure: Optional[dict] = None,
24672466
settings: Settings = Depends(get_settings),

tiled/server/zarr.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import json
22
import re
3-
from typing import Optional, Tuple, Union
3+
from typing import Any, Optional, Tuple, Union
44

55
import numcodecs
66
import orjson
7-
import pydantic_settings
87
from fastapi import APIRouter, Depends, HTTPException, Request
98
from starlette.responses import Response
109
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
1110

11+
from tiled.adapters.core import Adapter
12+
1213
from ..structures.core import StructureFamily
1314
from ..type_aliases import AccessTags, Scopes
1415
from ..utils import ensure_awaitable
@@ -58,7 +59,7 @@ async def get_zarr_attrs(
5859
principal: Union[Principal] = Depends(get_current_principal),
5960
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
6061
authn_scopes: Scopes = Depends(get_current_scopes),
61-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
62+
root_tree: Adapter[Any] = Depends(get_root_tree),
6263
session_state: dict = Depends(get_session_state),
6364
):
6465
"Return entry.metadata as Zarr attributes metadata (.zattrs)"
@@ -94,7 +95,7 @@ async def get_zarr_group_metadata(
9495
principal: Union[Principal] = Depends(get_current_principal),
9596
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
9697
authn_scopes: Scopes = Depends(get_current_scopes),
97-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
98+
root_tree: Adapter[Any] = Depends(get_root_tree),
9899
session_state: dict = Depends(get_session_state),
99100
):
100101
await get_entry(
@@ -122,7 +123,7 @@ async def get_zarr_array_metadata(
122123
principal: Union[Principal] = Depends(get_current_principal),
123124
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
124125
authn_scopes: Scopes = Depends(get_current_scopes),
125-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
126+
root_tree: Adapter[Any] = Depends(get_root_tree),
126127
session_state: dict = Depends(get_session_state),
127128
):
128129
entry = await get_entry(
@@ -166,7 +167,7 @@ async def get_zarr_array(
166167
principal: Union[Principal] = Depends(get_current_principal),
167168
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
168169
authn_scopes: Scopes = Depends(get_current_scopes),
169-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
170+
root_tree: Adapter[Any] = Depends(get_root_tree),
170171
session_state: dict = Depends(get_session_state),
171172
):
172173
# If a zarr block is requested, e.g. http://localhost:8000/zarr/v2/array/0.1.2.3,
@@ -285,7 +286,7 @@ async def get_zarr_metadata(
285286
principal: Union[Principal] = Depends(get_current_principal),
286287
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
287288
authn_scopes: Scopes = Depends(get_current_scopes),
288-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
289+
root_tree: Adapter[Any] = Depends(get_root_tree),
289290
session_state: dict = Depends(get_session_state),
290291
):
291292
from zarr.dtype import parse_data_type
@@ -377,7 +378,7 @@ async def get_zarr_array(
377378
principal: Union[Principal] = Depends(get_current_principal),
378379
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
379380
authn_scopes: Scopes = Depends(get_current_scopes),
380-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
381+
root_tree: Adapter[Any] = Depends(get_root_tree),
381382
session_state: dict = Depends(get_session_state),
382383
):
383384
entry = await get_entry(
@@ -458,7 +459,7 @@ async def get_zarr_group(
458459
principal: Union[Principal] = Depends(get_current_principal),
459460
authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags),
460461
authn_scopes: Scopes = Depends(get_current_scopes),
461-
root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree),
462+
root_tree: Adapter[Any] = Depends(get_root_tree),
462463
session_state: dict = Depends(get_session_state),
463464
):
464465
entry = await get_entry(

tiled/type_aliases.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import sys
22

3-
from pydantic import AfterValidator
4-
53
if sys.version_info < (3, 10):
64
EllipsisType = type(Ellipsis)
75
else:
86
from types import EllipsisType
97

108
from typing import (
9+
TYPE_CHECKING,
1110
Annotated,
1211
Any,
1312
Callable,
@@ -17,6 +16,7 @@
1716
Sequence,
1817
Set,
1918
TypedDict,
19+
TypeVar,
2020
Union,
2121
)
2222

@@ -43,10 +43,48 @@ class TaskMap(TypedDict):
4343
shutdown: list[AppTask]
4444

4545

46-
EntryPointString = Annotated[
47-
str,
48-
AfterValidator(import_object),
49-
]
46+
if TYPE_CHECKING:
47+
# Let type checking treat this as just the underlying type
48+
T = TypeVar("T")
49+
EntryPointString = Annotated[T, ...]
50+
else:
51+
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
52+
from pydantic.json_schema import JsonSchemaValue
53+
from pydantic.types import AnyType
54+
from pydantic_core import CoreSchema, core_schema
55+
56+
class EntryPointString:
57+
"""
58+
Version of Pydantic's ImportString that supports importing fields of
59+
imported items, not just top level attributes
60+
61+
A string such as `path.to.module:Type.field` is equivalent to
62+
63+
```
64+
from path.to.module import type
65+
return type.field
66+
```
67+
68+
"""
69+
70+
@classmethod
71+
def __class_getitem__(cls, item: AnyType):
72+
return Annotated[item, cls()]
73+
74+
@classmethod
75+
def __get_pydantic_core_schema__(
76+
cls, source: type[Any], handler: GetCoreSchemaHandler
77+
) -> core_schema.CoreSchema:
78+
return core_schema.no_info_plain_validator_function(function=import_object)
79+
80+
@classmethod
81+
def __get_pydantic_json_schema__(
82+
cls, cs: CoreSchema, handler: GetJsonSchemaHandler
83+
) -> JsonSchemaValue:
84+
return handler(core_schema.str_schema())
85+
86+
def __repr__(self) -> str:
87+
return "EntryPointString"
5088

5189

5290
__all__ = [

0 commit comments

Comments
 (0)