Skip to content

Commit 3c5600c

Browse files
tpoliawdanielballannmaytan
authored andcommitted
chore(types): Use common base type for all access policy types (bluesky#1044)
* Use common base type for all access policy types * Type access policy methods, and introduce aliases. * Fix broken import, more typing updates * Use Mapping rather than Dict (deprecated) for AccessBlob type --------- Co-authored-by: Dan Allan <[email protected]> Co-authored-by: nmaytan <[email protected]>
1 parent 75b9fd2 commit 3c5600c

File tree

12 files changed

+219
-123
lines changed

12 files changed

+219
-123
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Write the date in place of the "Unreleased" in the case a new version is release
77

88
- Enable Tiled server to accept bearer access tokens for authentication
99

10-
## Unreleased
10+
## v0.2.0 (Unreleased)
1111

1212
### Added
1313

@@ -18,6 +18,10 @@ Write the date in place of the "Unreleased" in the case a new version is release
1818
- Column names in `TableStructure` are explicitly converted to strings.
1919
- Ensure that structural dtype arrays read with `CSVAdapter` have two dimensions, `(n, 1)`.
2020

21+
### Refactored
22+
23+
- Use common base type for all access policy types
24+
2125

2226
## v0.1.6 (2025-09-29)
2327

tiled/_tests/test_protocols.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from pytest_mock import MockFixture
1111

1212
from ..access_control.access_policies import ALL_ACCESS
13+
from ..access_control.protocols import AccessPolicy
1314
from ..access_control.scopes import ALL_SCOPES
1415
from ..adapters.awkward_directory_container import DirectoryContainer
1516
from ..adapters.protocols import (
16-
AccessPolicy,
1717
ArrayAdapter,
1818
AwkwardAdapter,
1919
BaseAdapter,
@@ -28,7 +28,7 @@
2828
from ..structures.core import Spec, StructureFamily
2929
from ..structures.sparse import COOStructure
3030
from ..structures.table import TableStructure
31-
from ..type_aliases import JSON, Filters, Scopes
31+
from ..type_aliases import JSON, AccessBlob, AccessTags, Filters, Scopes
3232

3333

3434
class CustomArrayAdapter:
@@ -379,11 +379,30 @@ def __init__(self, scopes: Optional[Scopes] = None) -> None:
379379
def _get_id(self, principal: Principal) -> None:
380380
return None
381381

382+
async def init_node(
383+
self,
384+
principal: Principal,
385+
authn_access_tags: Optional[AccessTags],
386+
authn_scopes: Scopes,
387+
access_blob: Optional[AccessBlob] = None,
388+
) -> Tuple[bool, Optional[AccessBlob]]:
389+
return (False, access_blob)
390+
391+
async def modify_node(
392+
self,
393+
node: BaseAdapter,
394+
principal: Principal,
395+
authn_access_tags: Optional[AccessTags],
396+
authn_scopes: Scopes,
397+
access_blob: Optional[AccessBlob] = None,
398+
) -> Tuple[bool, Optional[AccessBlob]]:
399+
return (False, access_blob)
400+
382401
async def allowed_scopes(
383402
self,
384403
node: BaseAdapter,
385404
principal: Principal,
386-
authn_access_tags: Optional[Set[str]],
405+
authn_access_tags: Optional[AccessTags],
387406
authn_scopes: Scopes,
388407
) -> Scopes:
389408
allowed = self.scopes
@@ -394,7 +413,7 @@ async def filters(
394413
self,
395414
node: BaseAdapter,
396415
principal: Principal,
397-
authn_access_tags: Optional[Set[str]],
416+
authn_access_tags: Optional[AccessTags],
398417
authn_scopes: Scopes,
399418
scopes: Scopes,
400419
) -> Filters:
@@ -407,7 +426,7 @@ async def accesspolicy_protocol_functions(
407426
policy: AccessPolicy,
408427
node: BaseAdapter,
409428
principal: Principal,
410-
authn_access_tags: Optional[Set[str]],
429+
authn_access_tags: Optional[AccessTags],
411430
authn_scopes: Scopes,
412431
scopes: Scopes,
413432
) -> None:

tiled/access_control/access_policies.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import logging
22
import os
3+
from typing import Optional, Tuple
34

5+
from ..adapters.protocols import BaseAdapter
46
from ..queries import AccessBlobFilter
7+
from ..server.schemas import Principal
8+
from ..type_aliases import AccessBlob, AccessTags, Filters, Scopes
59
from ..utils import Sentinel, import_object
10+
from .protocols import AccessPolicy
611
from .scopes import ALL_SCOPES, PUBLIC_SCOPES
712

813
ALL_ACCESS = Sentinel("ALL_ACCESS")
@@ -20,17 +25,42 @@
2025
logger.setLevel(log_level.upper())
2126

2227

23-
class DummyAccessPolicy:
28+
class DummyAccessPolicy(AccessPolicy):
2429
"Impose no access restrictions."
2530

26-
async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes):
31+
async def init_node(
32+
self,
33+
principal: Principal,
34+
authn_access_tags: Optional[AccessTags],
35+
authn_scopes: Scopes,
36+
access_blob: Optional[AccessBlob] = None,
37+
) -> Tuple[bool, AccessBlob]:
38+
"Do nothing; there is no persistent state to initialize."
39+
return (False, access_blob)
40+
41+
async def allowed_scopes(
42+
self,
43+
node: BaseAdapter,
44+
principal: Principal,
45+
authn_access_tags: Optional[AccessTags],
46+
authn_scopes: Scopes,
47+
) -> Scopes:
48+
"Always allow all scopes."
2749
return ALL_SCOPES
2850

29-
async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes):
51+
async def filters(
52+
self,
53+
node: BaseAdapter,
54+
principal: Principal,
55+
authn_access_tags: Optional[AccessTags],
56+
authn_scopes: Scopes,
57+
scopes: Scopes,
58+
) -> Filters:
59+
"Always impose no filtering on results."
3060
return []
3161

3262

33-
class TagBasedAccessPolicy:
63+
class TagBasedAccessPolicy(AccessPolicy):
3464
def __init__(
3565
self,
3666
*,
@@ -73,8 +103,12 @@ def _is_admin(self, authn_scopes):
73103
return False
74104

75105
async def init_node(
76-
self, principal, authn_access_tags, authn_scopes, access_blob=None
77-
):
106+
self,
107+
principal: Principal,
108+
authn_access_tags: Optional[AccessTags],
109+
authn_scopes: Scopes,
110+
access_blob: Optional[AccessBlob] = None,
111+
) -> Tuple[bool, AccessBlob]:
78112
if principal.type == "service":
79113
identifier = str(principal.uuid)
80114
else:
@@ -156,8 +190,13 @@ async def init_node(
156190
return access_blob_modified, access_blob_from_policy
157191

158192
async def modify_node(
159-
self, node, principal, authn_access_tags, authn_scopes, access_blob
160-
):
193+
self,
194+
node: BaseAdapter,
195+
principal: Principal,
196+
authn_access_tags: Optional[AccessTags],
197+
authn_scopes: Scopes,
198+
access_blob: Optional[AccessBlob],
199+
) -> Tuple[bool, AccessBlob]:
161200
if principal.type == "service":
162201
identifier = str(principal.uuid)
163202
else:
@@ -278,7 +317,13 @@ async def modify_node(
278317
# modified means the blob to-be-used was changed in comparison to the user input
279318
return access_blob_modified, access_blob_from_policy
280319

281-
async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes):
320+
async def allowed_scopes(
321+
self,
322+
node: BaseAdapter,
323+
principal: Principal,
324+
authn_access_tags: Optional[AccessTags],
325+
authn_scopes: Scopes,
326+
) -> Scopes:
282327
# If this is being called, filter_for_access has let us get this far.
283328
# However, filters and allowed_scopes should always be implemented to
284329
# give answers consistent with each other.
@@ -317,7 +362,14 @@ async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes)
317362

318363
return allowed
319364

320-
async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes):
365+
async def filters(
366+
self,
367+
node: BaseAdapter,
368+
principal: Principal,
369+
authn_access_tags: Optional[AccessTags],
370+
authn_scopes: Scopes,
371+
scopes: Scopes,
372+
) -> Filters:
321373
queries = []
322374
query_filter = AccessBlobFilter
323375

tiled/access_control/protocols.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Optional, Tuple
3+
4+
from ..adapters.protocols import BaseAdapter
5+
from ..server.schemas import Principal
6+
from ..type_aliases import AccessBlob, AccessTags, Filters, Scopes
7+
8+
9+
class AccessPolicy(ABC):
10+
@abstractmethod
11+
async def init_node(
12+
self,
13+
principal: Principal,
14+
authn_access_tags: Optional[AccessTags],
15+
authn_scopes: Scopes,
16+
access_blob: Optional[AccessBlob] = None,
17+
) -> Tuple[bool, Optional[AccessBlob]]:
18+
pass
19+
20+
async def modify_node(
21+
self,
22+
node: BaseAdapter,
23+
principal: Principal,
24+
authn_access_tags: Optional[AccessTags],
25+
authn_scopes: Scopes,
26+
access_blob: Optional[AccessBlob],
27+
) -> Tuple[bool, Optional[AccessBlob]]:
28+
return (False, access_blob)
29+
30+
@abstractmethod
31+
async def allowed_scopes(
32+
self,
33+
node: BaseAdapter,
34+
principal: Principal,
35+
authn_access_tags: Optional[AccessTags],
36+
authn_scopes: Scopes,
37+
) -> Scopes:
38+
pass
39+
40+
@abstractmethod
41+
async def filters(
42+
self,
43+
node: BaseAdapter,
44+
principal: Principal,
45+
authn_access_tags: Optional[AccessTags],
46+
authn_scopes: Scopes,
47+
scopes: Scopes,
48+
) -> Filters:
49+
pass

tiled/adapters/protocols.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
from numpy.typing import NDArray
99

1010
from ..ndslice import NDSlice
11-
from ..server.schemas import Principal
1211
from ..storage import Storage
1312
from ..structures.array import ArrayStructure
1413
from ..structures.awkward import AwkwardStructure
1514
from ..structures.core import Spec, StructureFamily
1615
from ..structures.sparse import SparseStructure
1716
from ..structures.table import TableStructure
18-
from ..type_aliases import JSON, Filters, Scopes
17+
from ..type_aliases import JSON
1918
from .awkward_directory_container import DirectoryContainer
2019

2120

@@ -130,26 +129,3 @@ def __getitem__(self, key: str) -> ArrayAdapter:
130129
AnyAdapter = Union[
131130
ArrayAdapter, AwkwardAdapter, ContainerAdapter, SparseAdapter, TableAdapter
132131
]
133-
134-
135-
class AccessPolicy(Protocol):
136-
@abstractmethod
137-
async def allowed_scopes(
138-
self,
139-
node: BaseAdapter,
140-
principal: Principal,
141-
authn_access_tags: Optional[Set[str]],
142-
authn_scopes: Scopes,
143-
) -> Scopes:
144-
pass
145-
146-
@abstractmethod
147-
async def filters(
148-
self,
149-
node: BaseAdapter,
150-
principal: Principal,
151-
authn_access_tags: Optional[Set[str]],
152-
authn_scopes: Scopes,
153-
scopes: Scopes,
154-
) -> Filters:
155-
pass

tiled/server/app.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,8 @@
3636
HTTP_500_INTERNAL_SERVER_ERROR,
3737
)
3838

39-
from tiled.authenticators import ProxiedOIDCAuthenticator
40-
from tiled.query_registration import QueryRegistry, default_query_registry
41-
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator
42-
from tiled.type_aliases import AppTask, TaskMap
43-
39+
from ..access_control.protocols import AccessPolicy
40+
from ..authenticators import ProxiedOIDCAuthenticator
4441
from ..catalog.adapter import WouldDeleteData
4542
from ..config import (
4643
Authentication,
@@ -57,10 +54,13 @@
5754
default_deserialization_registry,
5855
default_serialization_registry,
5956
)
57+
from ..query_registration import QueryRegistry, default_query_registry
58+
from ..type_aliases import AppTask, TaskMap
6059
from ..utils import SHARE_TILED_PATH, Conflicts, UnsupportedQueryType
6160
from ..validation_registration import ValidationRegistry, default_validation_registry
6261
from .authentication import move_api_key
6362
from .compression import CompressionMiddleware
63+
from .protocols import ExternalAuthenticator, InternalAuthenticator
6464
from .router import get_metrics_router, get_router
6565
from .settings import Settings, get_settings
6666
from .utils import API_KEY_COOKIE_NAME, CSRF_COOKIE_NAME, get_root_url, record_timing
@@ -125,7 +125,7 @@ def build_app(
125125
validation_registry: Optional[ValidationRegistry] = None,
126126
tasks: Optional[dict[str, list[AppTask]]] = None,
127127
scalable=False,
128-
access_policy=None,
128+
access_policy: Optional[AccessPolicy] = None,
129129
):
130130
"""
131131
Serve a Tree
@@ -137,7 +137,7 @@ def build_app(
137137
Dict of authentication configuration.
138138
server_settings: dict, optional
139139
Dict of other server configuration.
140-
access_policy:
140+
access_policy: AccessPolicy, optional
141141
AccessPolicy object encoding rules for which users can see which entries.
142142
"""
143143
authentication = authentication or Authentication()

tiled/server/authentication.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from datetime import datetime, timedelta, timezone
66
from pathlib import Path
7-
from typing import Annotated, Any, Callable, List, Optional, Sequence, Set
7+
from typing import Annotated, Any, Callable, List, Optional, Sequence
88

99
from fastapi import (
1010
APIRouter,
@@ -61,6 +61,7 @@
6161
lookup_valid_pending_session_by_user_code,
6262
lookup_valid_session,
6363
)
64+
from ..type_aliases import AccessTags
6465
from ..utils import SHARE_TILED_PATH, SingleUserPrincipal
6566
from . import schemas
6667
from .connection_pool import get_database_session_factory
@@ -240,7 +241,7 @@ async def get_session_state(decoded_access_token=Depends(get_decoded_access_toke
240241

241242
async def get_access_tags_from_api_key(
242243
api_key: str, authenticated: bool, db: Optional[AsyncSession]
243-
) -> Optional[Set[str]]:
244+
) -> Optional[AccessTags]:
244245
if not authenticated:
245246
# Tiled is in a "single user" mode with only one API key.
246247
# In this mode, there is no meaningful access tag limit.
@@ -268,7 +269,7 @@ async def get_current_access_tags(
268269
db_factory: Callable[[], Optional[AsyncSession]] = Depends(
269270
get_database_session_factory
270271
),
271-
) -> Optional[Set[str]]:
272+
) -> Optional[AccessTags]:
272273
if api_key is not None:
273274
async with db_factory() as db:
274275
return await get_access_tags_from_api_key(
@@ -302,7 +303,7 @@ async def get_current_access_tags_websocket(
302303
db_factory: Callable[[], Optional[AsyncSession]] = Depends(
303304
get_database_session_factory
304305
),
305-
) -> Optional[Set[str]]:
306+
) -> Optional[AccessTags]:
306307
if api_key is not None:
307308
async with db_factory() as db:
308309
return await get_access_tags_from_api_key(

0 commit comments

Comments
 (0)