Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Write the date in place of the "Unreleased" in the case a new version is release

- Enable Tiled server to accept bearer access tokens for authentication

## Unreleased
## v0.2.0 (Unreleased)

### Added

Expand All @@ -18,6 +18,10 @@ Write the date in place of the "Unreleased" in the case a new version is release
- Column names in `TableStructure` are explicitly converted to strings.
- Ensure that structural dtype arrays read with `CSVAdapter` have two dimensions, `(n, 1)`.

### Refactored

- Use common base type for all access policy types


## v0.1.6 (2025-09-29)

Expand Down
29 changes: 24 additions & 5 deletions tiled/_tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from pytest_mock import MockFixture

from ..access_control.access_policies import ALL_ACCESS
from ..access_control.protocols import AccessPolicy
from ..access_control.scopes import ALL_SCOPES
from ..adapters.awkward_directory_container import DirectoryContainer
from ..adapters.protocols import (
AccessPolicy,
ArrayAdapter,
AwkwardAdapter,
BaseAdapter,
Expand All @@ -28,7 +28,7 @@
from ..structures.core import Spec, StructureFamily
from ..structures.sparse import COOStructure
from ..structures.table import TableStructure
from ..type_aliases import JSON, Filters, Scopes
from ..type_aliases import JSON, AccessBlob, AccessTags, Filters, Scopes


class CustomArrayAdapter:
Expand Down Expand Up @@ -379,11 +379,30 @@ def __init__(self, scopes: Optional[Scopes] = None) -> None:
def _get_id(self, principal: Principal) -> None:
return None

async def init_node(
self,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
access_blob: Optional[AccessBlob] = None,
) -> Tuple[bool, Optional[AccessBlob]]:
return (False, access_blob)

async def modify_node(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
access_blob: Optional[AccessBlob] = None,
) -> Tuple[bool, Optional[AccessBlob]]:
return (False, access_blob)

async def allowed_scopes(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[Set[str]],
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
) -> Scopes:
allowed = self.scopes
Expand All @@ -394,7 +413,7 @@ async def filters(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[Set[str]],
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
scopes: Scopes,
) -> Filters:
Expand All @@ -407,7 +426,7 @@ async def accesspolicy_protocol_functions(
policy: AccessPolicy,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[Set[str]],
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
scopes: Scopes,
) -> None:
Expand Down
72 changes: 62 additions & 10 deletions tiled/access_control/access_policies.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging
import os
from typing import Optional, Tuple

from ..adapters.protocols import BaseAdapter
from ..queries import AccessBlobFilter
from ..server.schemas import Principal
from ..type_aliases import AccessBlob, AccessTags, Filters, Scopes
from ..utils import Sentinel, import_object
from .protocols import AccessPolicy
from .scopes import ALL_SCOPES, PUBLIC_SCOPES

ALL_ACCESS = Sentinel("ALL_ACCESS")
Expand All @@ -20,17 +25,42 @@
logger.setLevel(log_level.upper())


class DummyAccessPolicy:
class DummyAccessPolicy(AccessPolicy):
"Impose no access restrictions."

async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes):
async def init_node(
self,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
access_blob: Optional[AccessBlob] = None,
) -> Tuple[bool, AccessBlob]:
"Do nothing; there is no persistent state to initialize."
return (False, access_blob)

async def allowed_scopes(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
) -> Scopes:
"Always allow all scopes."
return ALL_SCOPES

async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes):
async def filters(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
scopes: Scopes,
) -> Filters:
"Always impose no filtering on results."
return []


class TagBasedAccessPolicy:
class TagBasedAccessPolicy(AccessPolicy):
def __init__(
self,
*,
Expand Down Expand Up @@ -73,8 +103,12 @@ def _is_admin(self, authn_scopes):
return False

async def init_node(
self, principal, authn_access_tags, authn_scopes, access_blob=None
):
self,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
access_blob: Optional[AccessBlob] = None,
) -> Tuple[bool, AccessBlob]:
if principal.type == "service":
identifier = str(principal.uuid)
else:
Expand Down Expand Up @@ -156,8 +190,13 @@ async def init_node(
return access_blob_modified, access_blob_from_policy

async def modify_node(
self, node, principal, authn_access_tags, authn_scopes, access_blob
):
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
access_blob: Optional[AccessBlob],
) -> Tuple[bool, AccessBlob]:
if principal.type == "service":
identifier = str(principal.uuid)
else:
Expand Down Expand Up @@ -278,7 +317,13 @@ async def modify_node(
# modified means the blob to-be-used was changed in comparison to the user input
return access_blob_modified, access_blob_from_policy

async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes):
async def allowed_scopes(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
) -> Scopes:
# If this is being called, filter_for_access has let us get this far.
# However, filters and allowed_scopes should always be implemented to
# give answers consistent with each other.
Expand Down Expand Up @@ -317,7 +362,14 @@ async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes)

return allowed

async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes):
async def filters(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
scopes: Scopes,
) -> Filters:
queries = []
query_filter = AccessBlobFilter

Expand Down
49 changes: 49 additions & 0 deletions tiled/access_control/protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple

from ..adapters.protocols import BaseAdapter
from ..server.schemas import Principal
from ..type_aliases import AccessBlob, AccessTags, Filters, Scopes


class AccessPolicy(ABC):
@abstractmethod
async def init_node(
self,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
access_blob: Optional[AccessBlob] = None,
) -> Tuple[bool, Optional[AccessBlob]]:
pass

async def modify_node(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
access_blob: Optional[AccessBlob],
) -> Tuple[bool, Optional[AccessBlob]]:
return (False, access_blob)

@abstractmethod
async def allowed_scopes(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
) -> Scopes:
pass

@abstractmethod
async def filters(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[AccessTags],
authn_scopes: Scopes,
scopes: Scopes,
) -> Filters:
pass
26 changes: 1 addition & 25 deletions tiled/adapters/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from numpy.typing import NDArray

from ..ndslice import NDSlice
from ..server.schemas import Principal
from ..storage import Storage
from ..structures.array import ArrayStructure
from ..structures.awkward import AwkwardStructure
from ..structures.core import Spec, StructureFamily
from ..structures.sparse import SparseStructure
from ..structures.table import TableStructure
from ..type_aliases import JSON, Filters, Scopes
from ..type_aliases import JSON
from .awkward_directory_container import DirectoryContainer


Expand Down Expand Up @@ -130,26 +129,3 @@ def __getitem__(self, key: str) -> ArrayAdapter:
AnyAdapter = Union[
ArrayAdapter, AwkwardAdapter, ContainerAdapter, SparseAdapter, TableAdapter
]


class AccessPolicy(Protocol):
@abstractmethod
async def allowed_scopes(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[Set[str]],
authn_scopes: Scopes,
) -> Scopes:
pass

@abstractmethod
async def filters(
self,
node: BaseAdapter,
principal: Principal,
authn_access_tags: Optional[Set[str]],
authn_scopes: Scopes,
scopes: Scopes,
) -> Filters:
pass
14 changes: 7 additions & 7 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,8 @@
HTTP_500_INTERNAL_SERVER_ERROR,
)

from tiled.authenticators import ProxiedOIDCAuthenticator
from tiled.query_registration import QueryRegistry, default_query_registry
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator
from tiled.type_aliases import AppTask, TaskMap

from ..access_control.protocols import AccessPolicy
from ..authenticators import ProxiedOIDCAuthenticator
from ..catalog.adapter import WouldDeleteData
from ..config import (
Authentication,
Expand All @@ -57,10 +54,13 @@
default_deserialization_registry,
default_serialization_registry,
)
from ..query_registration import QueryRegistry, default_query_registry
from ..type_aliases import AppTask, TaskMap
from ..utils import SHARE_TILED_PATH, Conflicts, UnsupportedQueryType
from ..validation_registration import ValidationRegistry, default_validation_registry
from .authentication import move_api_key
from .compression import CompressionMiddleware
from .protocols import ExternalAuthenticator, InternalAuthenticator
from .router import get_metrics_router, get_router
from .settings import Settings, get_settings
from .utils import API_KEY_COOKIE_NAME, CSRF_COOKIE_NAME, get_root_url, record_timing
Expand Down Expand Up @@ -125,7 +125,7 @@ def build_app(
validation_registry: Optional[ValidationRegistry] = None,
tasks: Optional[dict[str, list[AppTask]]] = None,
scalable=False,
access_policy=None,
access_policy: Optional[AccessPolicy] = None,
):
"""
Serve a Tree
Expand All @@ -137,7 +137,7 @@ def build_app(
Dict of authentication configuration.
server_settings: dict, optional
Dict of other server configuration.
access_policy:
access_policy: AccessPolicy, optional
AccessPolicy object encoding rules for which users can see which entries.
"""
authentication = authentication or Authentication()
Expand Down
9 changes: 5 additions & 4 deletions tiled/server/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Annotated, Any, Callable, List, Optional, Sequence, Set
from typing import Annotated, Any, Callable, List, Optional, Sequence

from fastapi import (
APIRouter,
Expand Down Expand Up @@ -61,6 +61,7 @@
lookup_valid_pending_session_by_user_code,
lookup_valid_session,
)
from ..type_aliases import AccessTags
from ..utils import SHARE_TILED_PATH, SingleUserPrincipal
from . import schemas
from .connection_pool import get_database_session_factory
Expand Down Expand Up @@ -240,7 +241,7 @@ async def get_session_state(decoded_access_token=Depends(get_decoded_access_toke

async def get_access_tags_from_api_key(
api_key: str, authenticated: bool, db: Optional[AsyncSession]
) -> Optional[Set[str]]:
) -> Optional[AccessTags]:
if not authenticated:
# Tiled is in a "single user" mode with only one API key.
# In this mode, there is no meaningful access tag limit.
Expand Down Expand Up @@ -268,7 +269,7 @@ async def get_current_access_tags(
db_factory: Callable[[], Optional[AsyncSession]] = Depends(
get_database_session_factory
),
) -> Optional[Set[str]]:
) -> Optional[AccessTags]:
if api_key is not None:
async with db_factory() as db:
return await get_access_tags_from_api_key(
Expand Down Expand Up @@ -302,7 +303,7 @@ async def get_current_access_tags_websocket(
db_factory: Callable[[], Optional[AsyncSession]] = Depends(
get_database_session_factory
),
) -> Optional[Set[str]]:
) -> Optional[AccessTags]:
if api_key is not None:
async with db_factory() as db:
return await get_access_tags_from_api_key(
Expand Down
Loading