diff --git a/CHANGELOG.md b/CHANGELOG.md index 0633ff624..8a4ecc568 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ Write the date in place of the "Unreleased" in the case a new version is release - Enable Tiled server to accept bearer access tokens for authentication -## Unreleased +## v0.2.0 (Unreleased) ### Added @@ -18,6 +18,10 @@ Write the date in place of the "Unreleased" in the case a new version is release - Column names in `TableStructure` are explicitly converted to strings. - Ensure that structural dtype arrays read with `CSVAdapter` have two dimensions, `(n, 1)`. +### Refactored + +- Use common base type for all access policy types + ## v0.1.6 (2025-09-29) diff --git a/tiled/_tests/test_protocols.py b/tiled/_tests/test_protocols.py index 6eea5204a..23aeca560 100644 --- a/tiled/_tests/test_protocols.py +++ b/tiled/_tests/test_protocols.py @@ -10,10 +10,10 @@ from pytest_mock import MockFixture from ..access_control.access_policies import ALL_ACCESS +from ..access_control.protocols import AccessPolicy from ..access_control.scopes import ALL_SCOPES from ..adapters.awkward_directory_container import DirectoryContainer from ..adapters.protocols import ( - AccessPolicy, ArrayAdapter, AwkwardAdapter, BaseAdapter, @@ -28,7 +28,7 @@ from ..structures.core import Spec, StructureFamily from ..structures.sparse import COOStructure from ..structures.table import TableStructure -from ..type_aliases import JSON, Filters, Scopes +from ..type_aliases import JSON, AccessBlob, AccessTags, Filters, Scopes class CustomArrayAdapter: @@ -379,11 +379,30 @@ def __init__(self, scopes: Optional[Scopes] = None) -> None: def _get_id(self, principal: Principal) -> None: return None + async def init_node( + self, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + access_blob: Optional[AccessBlob] = None, + ) -> Tuple[bool, Optional[AccessBlob]]: + return (False, access_blob) + + async def modify_node( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + access_blob: Optional[AccessBlob] = None, + ) -> Tuple[bool, Optional[AccessBlob]]: + return (False, access_blob) + async def allowed_scopes( self, node: BaseAdapter, principal: Principal, - authn_access_tags: Optional[Set[str]], + authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, ) -> Scopes: allowed = self.scopes @@ -394,7 +413,7 @@ async def filters( self, node: BaseAdapter, principal: Principal, - authn_access_tags: Optional[Set[str]], + authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, scopes: Scopes, ) -> Filters: @@ -407,7 +426,7 @@ async def accesspolicy_protocol_functions( policy: AccessPolicy, node: BaseAdapter, principal: Principal, - authn_access_tags: Optional[Set[str]], + authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, scopes: Scopes, ) -> None: diff --git a/tiled/access_control/access_policies.py b/tiled/access_control/access_policies.py index af8432f9f..9350a6e70 100644 --- a/tiled/access_control/access_policies.py +++ b/tiled/access_control/access_policies.py @@ -1,8 +1,13 @@ import logging import os +from typing import Optional, Tuple +from ..adapters.protocols import BaseAdapter from ..queries import AccessBlobFilter +from ..server.schemas import Principal +from ..type_aliases import AccessBlob, AccessTags, Filters, Scopes from ..utils import Sentinel, import_object +from .protocols import AccessPolicy from .scopes import ALL_SCOPES, PUBLIC_SCOPES ALL_ACCESS = Sentinel("ALL_ACCESS") @@ -20,17 +25,42 @@ logger.setLevel(log_level.upper()) -class DummyAccessPolicy: +class DummyAccessPolicy(AccessPolicy): "Impose no access restrictions." - async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes): + async def init_node( + self, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + access_blob: Optional[AccessBlob] = None, + ) -> Tuple[bool, AccessBlob]: + "Do nothing; there is no persistent state to initialize." + return (False, access_blob) + + async def allowed_scopes( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + ) -> Scopes: + "Always allow all scopes." return ALL_SCOPES - async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes): + async def filters( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + scopes: Scopes, + ) -> Filters: + "Always impose no filtering on results." return [] -class TagBasedAccessPolicy: +class TagBasedAccessPolicy(AccessPolicy): def __init__( self, *, @@ -73,8 +103,12 @@ def _is_admin(self, authn_scopes): return False async def init_node( - self, principal, authn_access_tags, authn_scopes, access_blob=None - ): + self, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + access_blob: Optional[AccessBlob] = None, + ) -> Tuple[bool, AccessBlob]: if principal.type == "service": identifier = str(principal.uuid) else: @@ -156,8 +190,13 @@ async def init_node( return access_blob_modified, access_blob_from_policy async def modify_node( - self, node, principal, authn_access_tags, authn_scopes, access_blob - ): + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + access_blob: Optional[AccessBlob], + ) -> Tuple[bool, AccessBlob]: if principal.type == "service": identifier = str(principal.uuid) else: @@ -278,7 +317,13 @@ async def modify_node( # modified means the blob to-be-used was changed in comparison to the user input return access_blob_modified, access_blob_from_policy - async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes): + async def allowed_scopes( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + ) -> Scopes: # If this is being called, filter_for_access has let us get this far. # However, filters and allowed_scopes should always be implemented to # give answers consistent with each other. @@ -317,7 +362,14 @@ async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes) return allowed - async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes): + async def filters( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + scopes: Scopes, + ) -> Filters: queries = [] query_filter = AccessBlobFilter diff --git a/tiled/access_control/protocols.py b/tiled/access_control/protocols.py new file mode 100644 index 000000000..3733779dc --- /dev/null +++ b/tiled/access_control/protocols.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +from ..adapters.protocols import BaseAdapter +from ..server.schemas import Principal +from ..type_aliases import AccessBlob, AccessTags, Filters, Scopes + + +class AccessPolicy(ABC): + @abstractmethod + async def init_node( + self, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + access_blob: Optional[AccessBlob] = None, + ) -> Tuple[bool, Optional[AccessBlob]]: + pass + + async def modify_node( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + access_blob: Optional[AccessBlob], + ) -> Tuple[bool, Optional[AccessBlob]]: + return (False, access_blob) + + @abstractmethod + async def allowed_scopes( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + ) -> Scopes: + pass + + @abstractmethod + async def filters( + self, + node: BaseAdapter, + principal: Principal, + authn_access_tags: Optional[AccessTags], + authn_scopes: Scopes, + scopes: Scopes, + ) -> Filters: + pass diff --git a/tiled/adapters/protocols.py b/tiled/adapters/protocols.py index 22b50e4c4..fd7028f2c 100644 --- a/tiled/adapters/protocols.py +++ b/tiled/adapters/protocols.py @@ -8,14 +8,13 @@ from numpy.typing import NDArray from ..ndslice import NDSlice -from ..server.schemas import Principal from ..storage import Storage from ..structures.array import ArrayStructure from ..structures.awkward import AwkwardStructure from ..structures.core import Spec, StructureFamily from ..structures.sparse import SparseStructure from ..structures.table import TableStructure -from ..type_aliases import JSON, Filters, Scopes +from ..type_aliases import JSON from .awkward_directory_container import DirectoryContainer @@ -130,26 +129,3 @@ def __getitem__(self, key: str) -> ArrayAdapter: AnyAdapter = Union[ ArrayAdapter, AwkwardAdapter, ContainerAdapter, SparseAdapter, TableAdapter ] - - -class AccessPolicy(Protocol): - @abstractmethod - async def allowed_scopes( - self, - node: BaseAdapter, - principal: Principal, - authn_access_tags: Optional[Set[str]], - authn_scopes: Scopes, - ) -> Scopes: - pass - - @abstractmethod - async def filters( - self, - node: BaseAdapter, - principal: Principal, - authn_access_tags: Optional[Set[str]], - authn_scopes: Scopes, - scopes: Scopes, - ) -> Filters: - pass diff --git a/tiled/server/app.py b/tiled/server/app.py index 64335cf78..1cb605d1f 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -36,11 +36,8 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) -from tiled.authenticators import ProxiedOIDCAuthenticator -from tiled.query_registration import QueryRegistry, default_query_registry -from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator -from tiled.type_aliases import AppTask, TaskMap - +from ..access_control.protocols import AccessPolicy +from ..authenticators import ProxiedOIDCAuthenticator from ..catalog.adapter import WouldDeleteData from ..config import ( Authentication, @@ -57,10 +54,13 @@ default_deserialization_registry, default_serialization_registry, ) +from ..query_registration import QueryRegistry, default_query_registry +from ..type_aliases import AppTask, TaskMap from ..utils import SHARE_TILED_PATH, Conflicts, UnsupportedQueryType from ..validation_registration import ValidationRegistry, default_validation_registry from .authentication import move_api_key from .compression import CompressionMiddleware +from .protocols import ExternalAuthenticator, InternalAuthenticator from .router import get_metrics_router, get_router from .settings import Settings, get_settings from .utils import API_KEY_COOKIE_NAME, CSRF_COOKIE_NAME, get_root_url, record_timing @@ -125,7 +125,7 @@ def build_app( validation_registry: Optional[ValidationRegistry] = None, tasks: Optional[dict[str, list[AppTask]]] = None, scalable=False, - access_policy=None, + access_policy: Optional[AccessPolicy] = None, ): """ Serve a Tree @@ -137,7 +137,7 @@ def build_app( Dict of authentication configuration. server_settings: dict, optional Dict of other server configuration. - access_policy: + access_policy: AccessPolicy, optional AccessPolicy object encoding rules for which users can see which entries. """ authentication = authentication or Authentication() diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index ac8acc4c2..fb992ef04 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -4,7 +4,7 @@ import warnings from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Annotated, Any, Callable, List, Optional, Sequence, Set +from typing import Annotated, Any, Callable, List, Optional, Sequence from fastapi import ( APIRouter, @@ -61,6 +61,7 @@ lookup_valid_pending_session_by_user_code, lookup_valid_session, ) +from ..type_aliases import AccessTags from ..utils import SHARE_TILED_PATH, SingleUserPrincipal from . import schemas from .connection_pool import get_database_session_factory @@ -240,7 +241,7 @@ async def get_session_state(decoded_access_token=Depends(get_decoded_access_toke async def get_access_tags_from_api_key( api_key: str, authenticated: bool, db: Optional[AsyncSession] -) -> Optional[Set[str]]: +) -> Optional[AccessTags]: if not authenticated: # Tiled is in a "single user" mode with only one API key. # In this mode, there is no meaningful access tag limit. @@ -268,7 +269,7 @@ async def get_current_access_tags( db_factory: Callable[[], Optional[AsyncSession]] = Depends( get_database_session_factory ), -) -> Optional[Set[str]]: +) -> Optional[AccessTags]: if api_key is not None: async with db_factory() as db: return await get_access_tags_from_api_key( @@ -302,7 +303,7 @@ async def get_current_access_tags_websocket( db_factory: Callable[[], Optional[AsyncSession]] = Depends( get_database_session_factory ), -) -> Optional[Set[str]]: +) -> Optional[AccessTags]: if api_key is not None: async with db_factory() as db: return await get_access_tags_from_api_key( diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index b0772effa..2a37cc66a 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -1,16 +1,16 @@ -from typing import List, Optional, Set +from typing import List, Optional import pydantic_settings from fastapi import HTTPException, Query, Request from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_410_GONE -from tiled.adapters.protocols import AnyAdapter -from tiled.server.schemas import Principal -from tiled.structures.core import StructureFamily - -from ..type_aliases import Scopes +from ..access_control.protocols import AccessPolicy +from ..adapters.protocols import AnyAdapter +from ..structures.core import StructureFamily +from ..type_aliases import AccessTags, Scopes from ..utils import BrokenLink from .core import NoEntry +from .schemas import Principal from .utils import filter_for_access, record_timing @@ -22,13 +22,13 @@ async def get_entry( path: str, security_scopes: List[str], principal: Optional[Principal], - authn_access_tags: Optional[Set[str]], + authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, root_tree: pydantic_settings.BaseSettings, session_state: dict, metrics: dict, structure_families: Optional[set[StructureFamily]] = None, - access_policy=None, + access_policy: Optional[AccessPolicy] = None, ) -> AnyAdapter: """ Obtain a node in the tree from its path. @@ -40,7 +40,6 @@ async def get_entry( """ path_parts = [segment for segment in path.split("/") if segment] entry = root_tree - # access_policy = getattr(request.app.state, "access_policy", None) # If the entry/adapter can take a session state, pass it in. # The entry/adapter may return itself or a different object. if hasattr(entry, "with_session_state") and session_state: diff --git a/tiled/server/router.py b/tiled/server/router.py index 108511fc5..2d4c05684 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone from functools import cache, partial from pathlib import Path -from typing import Callable, List, Optional, Set, TypeVar, Union +from typing import Callable, List, Optional, TypeVar, Union import anyio import packaging @@ -51,7 +51,7 @@ from .. import __version__ from ..ndslice import NDSlice from ..structures.core import Spec, StructureFamily -from ..type_aliases import Scopes +from ..type_aliases import AccessTags, Scopes from ..utils import BrokenLink, ensure_awaitable, patch_mimetypes, path_from_uri from ..validation_registration import ValidationError, ValidationRegistry from . import schemas @@ -299,7 +299,7 @@ async def search( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), settings: Settings = Depends(get_settings), _=Security(check_scopes, scopes=["read:metadata"]), @@ -385,7 +385,7 @@ async def distinct( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:metadata"]), **filters, @@ -436,7 +436,7 @@ async def metadata( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), settings: Settings = Depends(get_settings), _=Security(check_scopes, scopes=["read:metadata"]), @@ -500,7 +500,7 @@ async def array_block( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -591,7 +591,7 @@ async def array_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -660,7 +660,7 @@ async def close_stream( principal: Optional[schemas.Principal] = Depends(get_current_principal), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -687,7 +687,7 @@ async def websocket_endpoint( principal: Optional[schemas.Principal] = Depends( get_current_principal_websocket ), - authn_access_tags: Optional[Set[str]] = Depends( + authn_access_tags: Optional[AccessTags] = Depends( get_current_access_tags_websocket ), authn_scopes: Scopes = Depends(get_current_scopes_websocket), @@ -740,7 +740,7 @@ async def get_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -804,7 +804,7 @@ async def post_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -899,7 +899,7 @@ async def get_table_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -942,7 +942,7 @@ async def post_table_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -1024,7 +1024,7 @@ async def get_container_full( request: Request, path: str, principal: Optional[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -1068,7 +1068,7 @@ async def post_container_full( request: Request, path: str, principal: Optional[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), field: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, @@ -1107,7 +1107,7 @@ async def container_full( request: Request, entry, principal: Optional[Principal], - authn_access_tags: Optional[Set[str]], + authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, field: Optional[List[str]], format: Optional[str], @@ -1161,7 +1161,7 @@ async def node_full( request: Request, path: str, principal: Optional[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -1250,7 +1250,7 @@ async def get_awkward_buffers( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -1300,7 +1300,7 @@ async def post_awkward_buffers( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -1391,7 +1391,7 @@ async def awkward_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -1451,7 +1451,7 @@ async def post_metadata( body: schemas.PostMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -1498,7 +1498,7 @@ async def post_register( body: schemas.PostMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -1534,7 +1534,7 @@ async def _create_node( settings: Settings, entry, principal: Optional[Principal], - authn_access_tags: Optional[Set[str]], + authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, ): metadata, structure_family, specs, access_blob = ( @@ -1611,7 +1611,7 @@ async def put_data_source( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:metadata", "register"]), ): @@ -1646,7 +1646,7 @@ async def delete( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data", "write:metadata"]), ): @@ -1678,7 +1678,7 @@ async def put_array_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1718,7 +1718,7 @@ async def put_array_block( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1760,7 +1760,7 @@ async def patch_array_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1798,7 +1798,7 @@ async def put_node_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1836,7 +1836,7 @@ async def put_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1874,7 +1874,7 @@ async def patch_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1911,7 +1911,7 @@ async def put_awkward_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1951,7 +1951,7 @@ async def patch_metadata( body: schemas.PatchMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), drop_revision: bool = False, root_tree=Depends(get_root_tree), @@ -2021,14 +2021,11 @@ async def patch_metadata( settings=settings, ) - if request.app.state.access_policy is not None and hasattr( - request.app.state.access_policy, "modify_node" + if (policy := request.app.state.access_policy) and hasattr( + policy, "modify_node" ): try: - ( - access_blob_modified, - access_blob, - ) = await request.app.state.access_policy.modify_node( + (access_blob_modified, access_blob) = await policy.modify_node( entry, principal, authn_access_tags, authn_scopes, access_blob ) except ValueError as e: @@ -2062,7 +2059,7 @@ async def put_metadata( body: schemas.PutMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), drop_revision: bool = False, root_tree=Depends(get_root_tree), @@ -2100,14 +2097,11 @@ async def put_metadata( settings=settings, ) - if request.app.state.access_policy is not None and hasattr( - request.app.state.access_policy, "modify_node" + if (policy := request.app.state.access_policy) and hasattr( + policy, "modify_node" ): try: - ( - access_blob_modified, - access_blob, - ) = await request.app.state.access_policy.modify_node( + (access_blob_modified, access_blob) = await policy.modify_node( entry, principal, authn_access_tags, authn_scopes, access_blob ) except ValueError as e: @@ -2145,7 +2139,7 @@ async def get_revisions( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:metadata"]), ): @@ -2187,7 +2181,7 @@ async def delete_revision( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:metadata"]), ): @@ -2227,7 +2221,7 @@ async def get_asset( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -2341,7 +2335,7 @@ async def get_asset_manifest( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): diff --git a/tiled/server/utils.py b/tiled/server/utils.py index 16b18fe32..564c44a3e 100644 --- a/tiled/server/utils.py +++ b/tiled/server/utils.py @@ -7,10 +7,10 @@ from starlette.types import Scope from ..access_control.access_policies import NO_ACCESS +from ..access_control.protocols import AccessPolicy from ..adapters.mapping import MapAdapter -from ..adapters.protocols import AccessPolicy from ..server.schemas import Principal -from ..type_aliases import Scopes +from ..type_aliases import AccessTags, Scopes EMPTY_NODE = MapAdapter({}) API_KEY_COOKIE_NAME = "tiled_api_key" @@ -89,7 +89,7 @@ async def filter_for_access( entry, access_policy: Optional[AccessPolicy], principal: Principal, - authn_access_tags, + authn_access_tags: Optional[AccessTags], authn_scopes: Scopes, scopes: Sequence[str], metrics: dict[str, Any], diff --git a/tiled/server/zarr.py b/tiled/server/zarr.py index 914af7415..85dcd4f70 100644 --- a/tiled/server/zarr.py +++ b/tiled/server/zarr.py @@ -1,6 +1,6 @@ import json import re -from typing import Optional, Set, Tuple, Union +from typing import Optional, Tuple, Union import numcodecs import orjson @@ -10,7 +10,7 @@ from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from ..structures.core import StructureFamily -from ..type_aliases import Scopes +from ..type_aliases import AccessTags, Scopes from ..utils import ensure_awaitable from .authentication import ( get_current_access_tags, @@ -56,7 +56,7 @@ async def get_zarr_attrs( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -92,7 +92,7 @@ async def get_zarr_group_metadata( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -120,7 +120,7 @@ async def get_zarr_array_metadata( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -164,7 +164,7 @@ async def get_zarr_array( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -283,7 +283,7 @@ async def get_zarr_metadata( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -375,7 +375,7 @@ async def get_zarr_array( path: str, block: str, principal: Union[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -456,7 +456,7 @@ async def get_zarr_group( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), - authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_access_tags: Optional[AccessTags] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), diff --git a/tiled/type_aliases.py b/tiled/type_aliases.py index 4d65855dd..0f98e764c 100644 --- a/tiled/type_aliases.py +++ b/tiled/type_aliases.py @@ -30,6 +30,8 @@ Scopes = Set[str] Query = Any # for now... Filters = List[Query] +AccessBlob = Mapping[str, Any] +AccessTags = Set[str] AppTask = Callable[[], Coroutine[None, None, Any]] """Async function to be run as part of the app's lifecycle"""