Skip to content

Commit 80f7d2b

Browse files
committed
Use common base type for all access policy types
1 parent 59cd577 commit 80f7d2b

File tree

6 files changed

+34
-26
lines changed

6 files changed

+34
-26
lines changed

tiled/access_policies.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from contextlib import closing
55
from functools import partial
66

7+
from tiled.adapters.protocols import AccessPolicy
8+
79
from .queries import AccessBlobFilter, In, KeysFilter
810
from .scopes import ALL_SCOPES, PUBLIC_SCOPES
911
from .utils import Sentinel, SpecialUsers, import_object
@@ -23,7 +25,7 @@
2325
logger.setLevel(log_level.upper())
2426

2527

26-
class DummyAccessPolicy:
28+
class DummyAccessPolicy(AccessPolicy):
2729
"Impose no access restrictions."
2830

2931
async def allowed_scopes(self, node, principal, authn_scopes):
@@ -33,7 +35,7 @@ async def filters(self, node, principal, authn_scopes, scopes):
3335
return []
3436

3537

36-
class SimpleAccessPolicy:
38+
class SimpleAccessPolicy(AccessPolicy):
3739
"""
3840
A mapping of user names to lists of entries they have access to.
3941
@@ -180,7 +182,7 @@ def get_tags_from_scope(self, scope, username):
180182
return user_scope_tags
181183

182184

183-
class TagBasedAccessPolicy:
185+
class TagBasedAccessPolicy(AccessPolicy):
184186
def __init__(
185187
self,
186188
*,

tiled/adapters/protocols.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from abc import abstractmethod
1+
from abc import ABC, abstractmethod
22
from collections.abc import Mapping
33
from typing import Any, Dict, List, Literal, Optional, Protocol, Set, Tuple, Union
44

@@ -132,7 +132,7 @@ def __getitem__(self, key: str) -> ArrayAdapter:
132132
]
133133

134134

135-
class AccessPolicy(Protocol):
135+
class AccessPolicy(ABC):
136136
@abstractmethod
137137
async def allowed_scopes(
138138
self,
@@ -151,3 +151,12 @@ async def filters(
151151
scopes: Scopes,
152152
) -> Filters:
153153
pass
154+
155+
async def modify_node(
156+
self,
157+
node: BaseAdapter,
158+
principal: Principal,
159+
authn_scopes: Scopes,
160+
access_blob: Optional[dict[str, Any]],
161+
) -> tuple[bool, BaseAdapter]:
162+
return (False, node)

tiled/server/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
HTTP_500_INTERNAL_SERVER_ERROR,
3636
)
3737

38+
from tiled.adapters.protocols import AccessPolicy
3839
from tiled.query_registration import QueryRegistry, default_query_registry
3940
from tiled.server.authentication import move_api_key
4041
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator
@@ -118,7 +119,7 @@ def build_app(
118119
validation_registry: Optional[ValidationRegistry] = None,
119120
tasks: Optional[dict[str, list[AppTask]]] = None,
120121
scalable=False,
121-
access_policy=None,
122+
access_policy: Optional[AccessPolicy] = None,
122123
):
123124
"""
124125
Serve a Tree
@@ -132,7 +133,7 @@ def build_app(
132133
List of authenticator classes (one per support identity provider)
133134
server_settings: dict, optional
134135
Dict of other server configuration.
135-
access_policy:
136+
access_policy: AccessPolicy, optional
136137
AccessPolicy object encoding rules for which users can see which entries.
137138
"""
138139
authentication = authentication or {}

tiled/server/dependencies.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from fastapi import HTTPException, Query
55
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_410_GONE
66

7-
from tiled.adapters.protocols import AnyAdapter
7+
from tiled.adapters.protocols import AccessPolicy, AnyAdapter
88
from tiled.server.schemas import Principal
99
from tiled.structures.core import StructureFamily
1010
from tiled.utils import SpecialUsers
@@ -31,7 +31,7 @@ async def get_entry(
3131
session_state: dict,
3232
metrics: dict,
3333
structure_families: Optional[set[StructureFamily]] = None,
34-
access_policy=None,
34+
access_policy: Optional[AccessPolicy] = None,
3535
) -> AnyAdapter:
3636
"""
3737
Obtain a node in the tree from its path.
@@ -43,7 +43,6 @@ async def get_entry(
4343
"""
4444
path_parts = [segment for segment in path.split("/") if segment]
4545
entry = root_tree
46-
# access_policy = getattr(request.app.state, "access_policy", None)
4746
# If the entry/adapter can take a session state, pass it in.
4847
# The entry/adapter may return itself or a different object.
4948
if hasattr(entry, "with_session_state") and session_state:

tiled/server/router.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,14 +1931,9 @@ async def patch_metadata(
19311931
settings=settings,
19321932
)
19331933

1934-
if request.app.state.access_policy is not None and hasattr(
1935-
request.app.state.access_policy, "modify_node"
1936-
):
1934+
if policy := request.app.state.access_policy:
19371935
try:
1938-
(
1939-
access_blob_modified,
1940-
access_blob,
1941-
) = await request.app.state.access_policy.modify_node(
1936+
(access_blob_modified, access_blob) = await policy.modify_node(
19421937
entry, principal, authn_scopes, access_blob
19431938
)
19441939
except ValueError as e:
@@ -2013,14 +2008,9 @@ async def put_metadata(
20132008
settings=settings,
20142009
)
20152010

2016-
if request.app.state.access_policy is not None and hasattr(
2017-
request.app.state.access_policy, "modify_node"
2018-
):
2011+
if policy := request.app.state.access_policy:
20192012
try:
2020-
(
2021-
access_blob_modified,
2022-
access_blob,
2023-
) = await request.app.state.access_policy.modify_node(
2013+
(access_blob_modified, access_blob) = await policy.modify_node(
20242014
entry, principal, authn_scopes, access_blob
20252015
)
20262016
except ValueError as e:

tiled/server/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import contextlib
22
import time
33
from collections.abc import Generator
4-
from typing import Any, Mapping
4+
from typing import Any, Mapping, Optional
55

66
from fastapi import Request
77
from starlette.types import Scope
88

9+
from tiled.adapters.protocols import AccessPolicy
10+
911
from ..access_policies import NO_ACCESS
1012
from ..adapters.mapping import MapAdapter
1113

@@ -68,7 +70,12 @@ def get_root_url_low_level(request_headers: Mapping[str, str], scope: Scope) ->
6870

6971

7072
async def filter_for_access(
71-
entry, access_policy, principal, authn_scopes, scopes, metrics
73+
entry,
74+
access_policy: Optional[AccessPolicy],
75+
principal,
76+
authn_scopes,
77+
scopes,
78+
metrics,
7279
):
7380
if access_policy is not None and hasattr(entry, "search"):
7481
with record_timing(metrics, "acl"):

0 commit comments

Comments
 (0)