Skip to content

Commit 9342618

Browse files
committed
Use common base type for all access policy types
1 parent cb7c46f commit 9342618

File tree

7 files changed

+35
-26
lines changed

7 files changed

+35
-26
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Write the date in place of the "Unreleased" in the case a new version is release
1919
of Closure Table to track ancestors and descendands of the nodes.
2020
- Shorter string representation of chunks in `ArrayClient`.
2121
- Refactored internal Zarr version detection
22+
- Use common base type for all access policy types
2223

2324
### Fixed
2425

tiled/access_policies.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from contextlib import closing
55
from functools import partial
66

7+
from tiled.adapters.protocols import AccessPolicy
8+
79
from .queries import AccessBlobFilter, In, KeysFilter
810
from .scopes import ALL_SCOPES, PUBLIC_SCOPES
911
from .utils import Sentinel, SpecialUsers, import_object
@@ -23,7 +25,7 @@
2325
logger.setLevel(log_level.upper())
2426

2527

26-
class DummyAccessPolicy:
28+
class DummyAccessPolicy(AccessPolicy):
2729
"Impose no access restrictions."
2830

2931
async def allowed_scopes(self, node, principal, authn_scopes):
@@ -33,7 +35,7 @@ async def filters(self, node, principal, authn_scopes, scopes):
3335
return []
3436

3537

36-
class SimpleAccessPolicy:
38+
class SimpleAccessPolicy(AccessPolicy):
3739
"""
3840
A mapping of user names to lists of entries they have access to.
3941
@@ -180,7 +182,7 @@ def get_tags_from_scope(self, scope, username):
180182
return user_scope_tags
181183

182184

183-
class TagBasedAccessPolicy:
185+
class TagBasedAccessPolicy(AccessPolicy):
184186
def __init__(
185187
self,
186188
*,

tiled/adapters/protocols.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from abc import abstractmethod
1+
from abc import ABC, abstractmethod
22
from collections.abc import Mapping
33
from typing import Any, Dict, List, Literal, Optional, Protocol, Set, Tuple, Union
44

@@ -132,7 +132,7 @@ def __getitem__(self, key: str) -> ArrayAdapter:
132132
]
133133

134134

135-
class AccessPolicy(Protocol):
135+
class AccessPolicy(ABC):
136136
@abstractmethod
137137
async def allowed_scopes(
138138
self,
@@ -151,3 +151,12 @@ async def filters(
151151
scopes: Scopes,
152152
) -> Filters:
153153
pass
154+
155+
async def modify_node(
156+
self,
157+
node: BaseAdapter,
158+
principal: Principal,
159+
authn_scopes: Scopes,
160+
access_blob: Optional[dict[str, Any]],
161+
) -> tuple[bool, Optional[dict[str, Any]]]:
162+
return (False, access_blob)

tiled/server/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
HTTP_500_INTERNAL_SERVER_ERROR,
3737
)
3838

39+
from tiled.adapters.protocols import AccessPolicy
3940
from tiled.query_registration import QueryRegistry, default_query_registry
4041
from tiled.server.authentication import move_api_key
4142
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator
@@ -120,7 +121,7 @@ def build_app(
120121
validation_registry: Optional[ValidationRegistry] = None,
121122
tasks: Optional[dict[str, list[AppTask]]] = None,
122123
scalable=False,
123-
access_policy=None,
124+
access_policy: Optional[AccessPolicy] = None,
124125
):
125126
"""
126127
Serve a Tree
@@ -134,7 +135,7 @@ def build_app(
134135
List of authenticator classes (one per support identity provider)
135136
server_settings: dict, optional
136137
Dict of other server configuration.
137-
access_policy:
138+
access_policy: AccessPolicy, optional
138139
AccessPolicy object encoding rules for which users can see which entries.
139140
"""
140141
authentication = authentication or {}

tiled/server/dependencies.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from fastapi import HTTPException, Query, Request
55
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_410_GONE
66

7-
from tiled.adapters.protocols import AnyAdapter
7+
from tiled.adapters.protocols import AccessPolicy, AnyAdapter
88
from tiled.server.schemas import Principal
99
from tiled.structures.core import StructureFamily
1010
from tiled.utils import SpecialUsers
@@ -28,7 +28,7 @@ async def get_entry(
2828
session_state: dict,
2929
metrics: dict,
3030
structure_families: Optional[set[StructureFamily]] = None,
31-
access_policy=None,
31+
access_policy: Optional[AccessPolicy] = None,
3232
) -> AnyAdapter:
3333
"""
3434
Obtain a node in the tree from its path.
@@ -40,7 +40,6 @@ async def get_entry(
4040
"""
4141
path_parts = [segment for segment in path.split("/") if segment]
4242
entry = root_tree
43-
# access_policy = getattr(request.app.state, "access_policy", None)
4443
# If the entry/adapter can take a session state, pass it in.
4544
# The entry/adapter may return itself or a different object.
4645
if hasattr(entry, "with_session_state") and session_state:

tiled/server/router.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,14 +1898,9 @@ async def patch_metadata(
18981898
settings=settings,
18991899
)
19001900

1901-
if request.app.state.access_policy is not None and hasattr(
1902-
request.app.state.access_policy, "modify_node"
1903-
):
1901+
if policy := request.app.state.access_policy:
19041902
try:
1905-
(
1906-
access_blob_modified,
1907-
access_blob,
1908-
) = await request.app.state.access_policy.modify_node(
1903+
(access_blob_modified, access_blob) = await policy.modify_node(
19091904
entry, principal, authn_scopes, access_blob
19101905
)
19111906
except ValueError as e:
@@ -1978,14 +1973,9 @@ async def put_metadata(
19781973
settings=settings,
19791974
)
19801975

1981-
if request.app.state.access_policy is not None and hasattr(
1982-
request.app.state.access_policy, "modify_node"
1983-
):
1976+
if policy := request.app.state.access_policy:
19841977
try:
1985-
(
1986-
access_blob_modified,
1987-
access_blob,
1988-
) = await request.app.state.access_policy.modify_node(
1978+
(access_blob_modified, access_blob) = await policy.modify_node(
19891979
entry, principal, authn_scopes, access_blob
19901980
)
19911981
except ValueError as e:

tiled/server/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import contextlib
22
import time
33
from collections.abc import Generator
4-
from typing import Any, Literal, Mapping
4+
from typing import Any, Literal, Mapping, Optional
55

66
from fastapi import Request
77
from starlette.types import Scope
88

9+
from tiled.adapters.protocols import AccessPolicy
10+
911
from ..access_policies import NO_ACCESS
1012
from ..adapters.mapping import MapAdapter
1113

@@ -75,7 +77,12 @@ def get_root_url_low_level(request_headers: Mapping[str, str], scope: Scope) ->
7577

7678

7779
async def filter_for_access(
78-
entry, access_policy, principal, authn_scopes, scopes, metrics
80+
entry,
81+
access_policy: Optional[AccessPolicy],
82+
principal,
83+
authn_scopes,
84+
scopes,
85+
metrics,
7986
):
8087
if access_policy is not None and hasattr(entry, "search"):
8188
with record_timing(metrics, "acl"):

0 commit comments

Comments
 (0)