Skip to content

Commit 7d6c686

Browse files
committed
WIP extracting the authenticator model changes from #928
1 parent 59cd577 commit 7d6c686

File tree

3 files changed

+85
-176
lines changed

3 files changed

+85
-176
lines changed

tiled/_tests/test_access_control.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import json
2+
import secrets
3+
from typing import Any, TypedDict
24

35
import numpy
46
import pytest
57
from starlette.status import HTTP_403_FORBIDDEN
68

79
from tiled.authenticators import DictionaryAuthenticator
8-
from tiled.server.protocols import UserSessionState
10+
from tiled.server.protocols import Authenticator, InternalAuthenticator, UserSessionState
911

1012
from ..access_policies import NO_ACCESS
1113
from ..adapters.array import ArrayAdapter
@@ -530,24 +532,19 @@ def test_service_principal_access(tmpdir):
530532
sp_client = from_context(context)
531533
assert list(sp_client) == ["x"]
532534

535+
UserAttributes = TypedDict("UserAttributes", {"password": str, "attributes": Any}, total=False)
533536

534-
class CustomAttributesAuthenticator(DictionaryAuthenticator):
537+
class CustomAttributesAuthenticator(InternalAuthenticator):
535538
"""An example authenticator that enriches the stored user information."""
536-
537-
def __init__(self, users: dict, confirmation_message: str = ""):
538-
self._users = users
539-
super().__init__(
540-
{username: user["password"] for username, user in users.items()},
541-
confirmation_message,
542-
)
539+
users: dict[str, UserAttributes] = {}
543540

544541
async def authenticate(self, username, password):
545-
state = await super().authenticate(username, password)
546-
if isinstance(state, UserSessionState):
547-
# enrich the auth state
548-
state.state["attributes"] = self._users[username].get("attributes", {})
549-
return state
550-
542+
print(f'Authenticating: {username=}, {password=}, {self.users=}')
543+
if (attrs := self.users.get(username)) and (pw := attrs.get("password")):
544+
if secrets.compare_digest(pw, password):
545+
state = UserSessionState(username, attrs.get("attributes", {}))
546+
print(f"{state=}")
547+
return state
551548

552549
class CustomAttributesAccessPolicy:
553550
"""

tiled/authenticators.py

Lines changed: 57 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import re
66
import secrets
77
from collections.abc import Iterable
8-
from typing import Any, Mapping, Optional, cast
8+
from typing import Annotated, Any, Mapping, Optional, TypeVar, Union, cast
99

1010
import httpx
1111
from fastapi import APIRouter, Request
1212
from jose import JWTError, jwt
13-
from pydantic import Secret
13+
from pydantic import BeforeValidator, Field, Secret, ConfigDict
1414
from starlette.responses import RedirectResponse
1515

1616
from .server.protocols import (
@@ -32,9 +32,6 @@ class DummyAuthenticator(InternalAuthenticator):
3232
3333
"""
3434

35-
def __init__(self, confirmation_message: str = ""):
36-
self.confirmation_message = confirmation_message
37-
3835
async def authenticate(self, username: str, password: str) -> UserSessionState:
3936
return UserSessionState(username, {})
4037

@@ -46,31 +43,12 @@ class DictionaryAuthenticator(InternalAuthenticator):
4643
Check passwords from a dictionary of usernames mapped to passwords.
4744
"""
4845

49-
configuration_schema = """
50-
$schema": http://json-schema.org/draft-07/schema#
51-
type: object
52-
additionalProperties: false
53-
properties:
54-
users_to_password:
55-
type: object
56-
description: |
57-
Mapping usernames to password. Environment variable expansion should be
58-
used to avoid placing passwords directly in configuration.
59-
confirmation_message:
60-
type: string
61-
description: May be displayed by client after successful login.
62-
"""
63-
64-
def __init__(
65-
self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""
66-
):
67-
self._users_to_passwords = users_to_passwords
68-
self.confirmation_message = confirmation_message
46+
users_to_passwords: Mapping[str, str]
6947

7048
async def authenticate(
7149
self, username: str, password: str
7250
) -> Optional[UserSessionState]:
73-
true_password = self._users_to_passwords.get(username)
51+
true_password = self.users_to_passwords.get(username)
7452
if not true_password:
7553
# Username is not valid.
7654
return
@@ -79,26 +57,13 @@ async def authenticate(
7957

8058

8159
class PAMAuthenticator(InternalAuthenticator):
82-
configuration_schema = """
83-
$schema": http://json-schema.org/draft-07/schema#
84-
type: object
85-
additionalProperties: false
86-
properties:
87-
service:
88-
type: string
89-
description: PAM service. Default is 'login'.
90-
confirmation_message:
91-
type: string
92-
description: May be displayed by client after successful login.
93-
"""
94-
95-
def __init__(self, service: str = "login", confirmation_message: str = ""):
60+
service: str = "login"
61+
62+
def model_post_init(self, __context: Any):
9663
if not modules_available("pamela"):
9764
raise ModuleNotFoundError(
9865
"This PAMAuthenticator requires the module 'pamela' to be installed."
9966
)
100-
self.service = service
101-
self.confirmation_message = confirmation_message
10267
# TODO Try to open a PAM session.
10368

10469
async def authenticate(
@@ -115,47 +80,17 @@ async def authenticate(
11580

11681

11782
class OIDCAuthenticator(ExternalAuthenticator):
118-
configuration_schema = """
119-
$schema": http://json-schema.org/draft-07/schema#
120-
type: object
121-
additionalProperties: false
122-
properties:
123-
audience:
124-
type: string
125-
client_id:
126-
type: string
127-
client_secret:
128-
type: string
129-
well_known_uri:
130-
type: string
131-
confirmation_message:
132-
type: string
133-
"""
134-
135-
def __init__(
136-
self,
137-
audience: str,
138-
client_id: str,
139-
client_secret: str,
140-
well_known_uri: str,
141-
confirmation_message: str = "",
142-
):
143-
self._audience = audience
144-
self._client_id = client_id
145-
self._client_secret = Secret(client_secret)
146-
self._well_known_url = well_known_uri
147-
self.confirmation_message = confirmation_message
83+
audience: str
84+
client_id: str
85+
client_secret: Secret[str]
86+
well_known_uri: str
14887

14988
@functools.cached_property
15089
def _config_from_oidc_url(self) -> dict[str, Any]:
151-
response: httpx.Response = httpx.get(self._well_known_url)
90+
response: httpx.Response = httpx.get(self.well_known_url)
15291
response.raise_for_status()
15392
return response.json()
15493

155-
@functools.cached_property
156-
def client_id(self) -> str:
157-
return self._client_id
158-
15994
@functools.cached_property
16095
def id_token_signing_alg_values_supported(self) -> list[str]:
16196
return cast(
@@ -190,8 +125,8 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
190125
response = await exchange_code(
191126
self.token_endpoint,
192127
code,
193-
self._client_id,
194-
self._client_secret.get_secret_value(),
128+
self.client_id,
129+
self.client_secret.get_secret_value(),
195130
redirect_uri,
196131
)
197132
response_body = response.json()
@@ -207,7 +142,7 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
207142
token=id_token,
208143
key=keys,
209144
algorithms=self.id_token_signing_alg_values_supported,
210-
audience=self._audience,
145+
audience=self.audience,
211146
access_token=access_token,
212147
)
213148
except JWTError:
@@ -338,6 +273,18 @@ async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, st
338273
return rv
339274

340275

276+
T = TypeVar("T")
277+
278+
279+
def one_or_many(value: Union[T, list[T]]) -> list[T]:
280+
if isinstance(value, str) or not isinstance(value, Iterable):
281+
return [value]
282+
return list(value)
283+
284+
285+
OneOrMany = BeforeValidator(one_or_many)
286+
287+
341288
class LDAPAuthenticator(InternalAuthenticator):
342289
"""
343290
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
@@ -396,7 +343,7 @@ class LDAPAuthenticator(InternalAuthenticator):
396343
"uid={username},ou=people,dc=wikimedia,dc=org",
397344
"uid={username},ou=Developers,dc=wikimedia,dc=org"
398345
]
399-
allowed_groups: list or None
346+
allowed_groups: list
400347
List of LDAP group DNs that users could be members of to be granted access.
401348
402349
If a user is in any one of the listed groups, then that user is granted access.
@@ -521,79 +468,35 @@ class LDAPAuthenticator(InternalAuthenticator):
521468
id: user02
522469
"""
523470

524-
def __init__(
525-
self,
526-
server_address,
527-
server_port=None,
528-
*,
529-
use_ssl=False,
530-
use_tls=True,
531-
connect_timeout=5,
532-
receive_timeout=60,
533-
bind_dn_template=None,
534-
allowed_groups=None,
535-
valid_username_regex=r"^[a-z][.a-z0-9_-]*$",
536-
lookup_dn=False,
537-
user_search_base=None,
538-
user_attribute=None,
539-
lookup_dn_search_filter="({login_attr}={login})",
540-
lookup_dn_search_user=None,
541-
lookup_dn_search_password=None,
542-
lookup_dn_user_dn_attribute=None,
543-
escape_userdn=False,
544-
search_filter="",
545-
attributes=None,
546-
auth_state_attributes=None,
547-
use_lookup_dn_username=True,
548-
confirmation_message="",
549-
):
550-
self.use_ssl = use_ssl
551-
self.use_tls = use_tls
552-
self.connect_timeout = connect_timeout
553-
self.receive_timeout = receive_timeout
554-
self.bind_dn_template = bind_dn_template
555-
self.allowed_groups = allowed_groups
556-
self.valid_username_regex = valid_username_regex
557-
self.lookup_dn = lookup_dn
558-
self.user_search_base = user_search_base
559-
self.user_attribute = user_attribute
560-
self.lookup_dn_search_filter = lookup_dn_search_filter
561-
self.lookup_dn_search_user = lookup_dn_search_user
562-
self.lookup_dn_search_password = lookup_dn_search_password
563-
self.lookup_dn_user_dn_attribute = lookup_dn_user_dn_attribute
564-
self.escape_userdn = escape_userdn
565-
self.search_filter = search_filter
566-
self.attributes = attributes if attributes else []
567-
self.auth_state_attributes = (
568-
auth_state_attributes if auth_state_attributes else []
569-
)
570-
self.use_lookup_dn_username = use_lookup_dn_username
571-
572-
if isinstance(server_address, str):
573-
server_address_list = [server_address]
574-
elif isinstance(server_address, Iterable):
575-
server_address_list = list(server_address)
576-
else:
577-
raise TypeError(
578-
f"Unsupported type of `server_address` (list): server_address={server_address} "
579-
f"type(server_address)={type(server_address)}"
580-
)
581-
if not server_address_list:
582-
raise ValueError(
583-
"No servers are specified: 'server_address' is an empty list"
584-
)
585-
586-
self.server_address_list = server_address_list
587-
self.server_port = (
588-
server_port if server_port is not None else self._server_port_default()
589-
)
590-
self.confirmation_message = confirmation_message
591-
592-
def _server_port_default(self):
593-
if self.use_ssl:
594-
return 636 # default SSL port for LDAP
595-
else:
596-
return 389 # default plaintext port for LDAP
471+
model_config = ConfigDict(use_attribute_docstrings=True)
472+
473+
server_address_list: Annotated[list[str], OneOrMany]
474+
server_port: Annotated[
475+
int,
476+
Field(
477+
default_factory=lambda data: port
478+
if (port := data["server_port"] is not None)
479+
else (636 if data["use_ssl"] else 389)
480+
),
481+
]
482+
use_ssl: bool = False
483+
use_tls: bool = True
484+
connect_timeout: float = 5.0
485+
receive_timeout: float = 60.0
486+
bind_dn_template: Annotated[list[str], OneOrMany] = []
487+
allowed_groups: list[str] = []
488+
lookup_dn: bool
489+
user_search_base: str
490+
user_attribute: str
491+
lookup_dn_search_filter: Optional[str] = "({login_attr}={login})"
492+
lookup_dn_search_user: Optional[str]
493+
lookup_dn_search_password: Optional[str]
494+
lookup_dn_user_dn_attribute: Optional[str]
495+
escape_userdn: bool
496+
search_filter: str
497+
attributes: list[str] = []
498+
auth_state_attributes: list[str] = []
499+
use_lookup_dn_username: bool
597500

598501
async def resolve_username(self, username_supplied_by_user):
599502
import ldap3

tiled/server/protocols.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from abc import ABC
1+
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
33
from typing import Optional
44

55
from fastapi import Request
6+
from pydantic import BaseModel
67

78

89
@dataclass
@@ -13,11 +14,19 @@ class UserSessionState:
1314
state: dict = None
1415

1516

16-
class InternalAuthenticator(ABC):
17-
def authenticate(self, username: str, password: str) -> Optional[UserSessionState]:
18-
raise NotImplementedError
17+
class Authenticator(BaseModel, ABC):
18+
confirmation_message: str = ""
1919

2020

21-
class ExternalAuthenticator(ABC):
22-
def authenticate(self, request: Request) -> Optional[UserSessionState]:
23-
raise NotImplementedError
21+
class InternalAuthenticator(Authenticator):
22+
@abstractmethod
23+
async def authenticate(
24+
self, username: str, password: str
25+
) -> Optional[UserSessionState]:
26+
...
27+
28+
29+
class ExternalAuthenticator(Authenticator):
30+
@abstractmethod
31+
async def authenticate(self, request: Request) -> Optional[UserSessionState]:
32+
...

0 commit comments

Comments
 (0)