Skip to content

Commit 676086c

Browse files
committed
WIP extracting the authenticator model changes from #928
1 parent cb7c46f commit 676086c

File tree

5 files changed

+109
-188
lines changed

5 files changed

+109
-188
lines changed

tiled/_tests/test_access_control.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import json
2+
import secrets
3+
from typing import Any
24

35
import numpy
46
import pytest
57
from starlette.status import HTTP_403_FORBIDDEN
8+
from typing_extensions import TypedDict
69

7-
from tiled.authenticators import DictionaryAuthenticator
8-
from tiled.server.protocols import UserSessionState
10+
from tiled.server.protocols import InternalAuthenticator, UserSessionState
911

1012
from ..access_policies import NO_ACCESS
1113
from ..adapters.array import ArrayAdapter
@@ -531,22 +533,25 @@ def test_service_principal_access(tmpdir, sqlite_or_postgres_uri):
531533
assert list(sp_client) == ["x"]
532534

533535

534-
class CustomAttributesAuthenticator(DictionaryAuthenticator):
536+
UserAttributes = TypedDict(
537+
"UserAttributes", {"password": str, "attributes": dict[str, Any]}, total=False
538+
)
539+
540+
541+
class CustomAttributesAuthenticator(InternalAuthenticator):
535542
"""An example authenticator that enriches the stored user information."""
536543

537-
def __init__(self, users: dict, confirmation_message: str = ""):
538-
self._users = users
539-
super().__init__(
540-
{username: user["password"] for username, user in users.items()},
541-
confirmation_message,
542-
)
544+
users: dict[str, UserAttributes] = {}
543545

544546
async def authenticate(self, username, password):
545-
state = await super().authenticate(username, password)
546-
if isinstance(state, UserSessionState):
547-
# enrich the auth state
548-
state.state["attributes"] = self._users[username].get("attributes", {})
549-
return state
547+
print(f"Authenticating: {username=}, {password=}, {self.users=}")
548+
if (attrs := self.users.get(username)) and (pw := attrs.get("password")):
549+
if secrets.compare_digest(pw, password):
550+
state = UserSessionState(
551+
username, {"attributes": attrs.get("attributes", {})}
552+
)
553+
print(f"{state=}")
554+
return state
550555

551556

552557
class CustomAttributesAccessPolicy:

tiled/_tests/test_authenticators.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,22 @@
2525
])
2626
# fmt: on
2727
@pytest.mark.parametrize("use_tls,use_ssl", [(False, False)])
28+
@pytest.mark.skipif(not TILED_TEST_LDAP, reason="Requires an LDAP container and TILED_TEST_LDAP to be set")
2829
def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server_port):
2930
"""
3031
Basic test for ``LDAPAuthenticator``.
3132
3233
TODO: The test could be extended with enabled TLS or SSL, but it requires configuration
3334
of the LDAP server.
3435
"""
35-
if not TILED_TEST_LDAP:
36-
pytest.skip("Run an LDAP container and set TILED_TEST_LDAP to run")
37-
authenticator = LDAPAuthenticator(
38-
ldap_server_address,
39-
ldap_server_port,
40-
bind_dn_template="cn={username},ou=users,dc=example,dc=org",
41-
use_tls=use_tls,
42-
use_ssl=use_ssl,
43-
)
36+
37+
params = dict(server_address=ldap_server_address,
38+
bind_dn_template="cn={username},ou=users,dc=example,dc=org",
39+
use_ssl=use_ssl,
40+
use_tls=use_tls)
41+
if ldap_server_port is not None:
42+
params["server_port"] = ldap_server_port
43+
authenticator = LDAPAuthenticator(**params)
4444

4545
async def testing():
4646
assert (await authenticator.authenticate("user01", "password1")).user_name == "user01"

tiled/authenticators.py

Lines changed: 61 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import re
66
import secrets
77
from collections.abc import Iterable
8-
from typing import Any, Mapping, Optional, cast
8+
from typing import Annotated, Any, Mapping, Optional, TypeVar, Union, cast
99

1010
import httpx
1111
from fastapi import APIRouter, Request
1212
from jose import JWTError, jwt
13-
from pydantic import Secret
13+
from pydantic import BeforeValidator, ConfigDict, Field, Secret
1414
from starlette.responses import RedirectResponse
1515

1616
from .server.protocols import (
@@ -32,9 +32,6 @@ class DummyAuthenticator(InternalAuthenticator):
3232
3333
"""
3434

35-
def __init__(self, confirmation_message: str = ""):
36-
self.confirmation_message = confirmation_message
37-
3835
async def authenticate(self, username: str, password: str) -> UserSessionState:
3936
return UserSessionState(username, {})
4037

@@ -46,31 +43,12 @@ class DictionaryAuthenticator(InternalAuthenticator):
4643
Check passwords from a dictionary of usernames mapped to passwords.
4744
"""
4845

49-
configuration_schema = """
50-
$schema": http://json-schema.org/draft-07/schema#
51-
type: object
52-
additionalProperties: false
53-
properties:
54-
users_to_password:
55-
type: object
56-
description: |
57-
Mapping usernames to password. Environment variable expansion should be
58-
used to avoid placing passwords directly in configuration.
59-
confirmation_message:
60-
type: string
61-
description: May be displayed by client after successful login.
62-
"""
63-
64-
def __init__(
65-
self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""
66-
):
67-
self._users_to_passwords = users_to_passwords
68-
self.confirmation_message = confirmation_message
46+
users_to_passwords: Mapping[str, str]
6947

7048
async def authenticate(
7149
self, username: str, password: str
7250
) -> Optional[UserSessionState]:
73-
true_password = self._users_to_passwords.get(username)
51+
true_password = self.users_to_passwords.get(username)
7452
if not true_password:
7553
# Username is not valid.
7654
return
@@ -79,26 +57,13 @@ async def authenticate(
7957

8058

8159
class PAMAuthenticator(InternalAuthenticator):
82-
configuration_schema = """
83-
$schema": http://json-schema.org/draft-07/schema#
84-
type: object
85-
additionalProperties: false
86-
properties:
87-
service:
88-
type: string
89-
description: PAM service. Default is 'login'.
90-
confirmation_message:
91-
type: string
92-
description: May be displayed by client after successful login.
93-
"""
94-
95-
def __init__(self, service: str = "login", confirmation_message: str = ""):
60+
service: str = "login"
61+
62+
def model_post_init(self, __context: Any):
9663
if not modules_available("pamela"):
9764
raise ModuleNotFoundError(
9865
"This PAMAuthenticator requires the module 'pamela' to be installed."
9966
)
100-
self.service = service
101-
self.confirmation_message = confirmation_message
10267
# TODO Try to open a PAM session.
10368

10469
async def authenticate(
@@ -115,47 +80,17 @@ async def authenticate(
11580

11681

11782
class OIDCAuthenticator(ExternalAuthenticator):
118-
configuration_schema = """
119-
$schema": http://json-schema.org/draft-07/schema#
120-
type: object
121-
additionalProperties: false
122-
properties:
123-
audience:
124-
type: string
125-
client_id:
126-
type: string
127-
client_secret:
128-
type: string
129-
well_known_uri:
130-
type: string
131-
confirmation_message:
132-
type: string
133-
"""
134-
135-
def __init__(
136-
self,
137-
audience: str,
138-
client_id: str,
139-
client_secret: str,
140-
well_known_uri: str,
141-
confirmation_message: str = "",
142-
):
143-
self._audience = audience
144-
self._client_id = client_id
145-
self._client_secret = Secret(client_secret)
146-
self._well_known_url = well_known_uri
147-
self.confirmation_message = confirmation_message
83+
audience: str
84+
client_id: str
85+
client_secret: Secret[str]
86+
well_known_uri: str
14887

14988
@functools.cached_property
15089
def _config_from_oidc_url(self) -> dict[str, Any]:
151-
response: httpx.Response = httpx.get(self._well_known_url)
90+
response: httpx.Response = httpx.get(self.well_known_url)
15291
response.raise_for_status()
15392
return response.json()
15493

155-
@functools.cached_property
156-
def client_id(self) -> str:
157-
return self._client_id
158-
15994
@functools.cached_property
16095
def id_token_signing_alg_values_supported(self) -> list[str]:
16196
return cast(
@@ -190,8 +125,8 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
190125
response = await exchange_code(
191126
self.token_endpoint,
192127
code,
193-
self._client_id,
194-
self._client_secret.get_secret_value(),
128+
self.client_id,
129+
self.client_secret.get_secret_value(),
195130
redirect_uri,
196131
)
197132
response_body = response.json()
@@ -207,7 +142,7 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
207142
token=id_token,
208143
key=keys,
209144
algorithms=self.id_token_signing_alg_values_supported,
210-
audience=self._audience,
145+
audience=self.audience,
211146
access_token=access_token,
212147
)
213148
except JWTError:
@@ -338,6 +273,18 @@ async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, st
338273
return rv
339274

340275

276+
T = TypeVar("T")
277+
278+
279+
def one_or_many(value: Union[T, list[T]]) -> list[T]:
280+
if isinstance(value, str) or not isinstance(value, Iterable):
281+
return [value]
282+
return list(value)
283+
284+
285+
OneOrMany = BeforeValidator(one_or_many)
286+
287+
341288
class LDAPAuthenticator(InternalAuthenticator):
342289
"""
343290
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
@@ -396,7 +343,7 @@ class LDAPAuthenticator(InternalAuthenticator):
396343
"uid={username},ou=people,dc=wikimedia,dc=org",
397344
"uid={username},ou=Developers,dc=wikimedia,dc=org"
398345
]
399-
allowed_groups: list or None
346+
allowed_groups: list
400347
List of LDAP group DNs that users could be members of to be granted access.
401348
402349
If a user is in any one of the listed groups, then that user is granted access.
@@ -493,7 +440,10 @@ class LDAPAuthenticator(InternalAuthenticator):
493440
494441
from bluesky_httpserver.authenticators import LDAPAuthenticator
495442
authenticator = LDAPAuthenticator(
496-
"localhost", 1389, bind_dn_template="cn={username},ou=users,dc=example,dc=org", use_tls=False
443+
server_address="localhost",
444+
server_port=1389,
445+
bind_dn_template="cn={username},ou=users,dc=example,dc=org",
446+
use_tls=False
497447
)
498448
await authenticator.authenticate("user01", "password1")
499449
await authenticator.authenticate("user02", "password2")
@@ -521,79 +471,35 @@ class LDAPAuthenticator(InternalAuthenticator):
521471
id: user02
522472
"""
523473

524-
def __init__(
525-
self,
526-
server_address,
527-
server_port=None,
528-
*,
529-
use_ssl=False,
530-
use_tls=True,
531-
connect_timeout=5,
532-
receive_timeout=60,
533-
bind_dn_template=None,
534-
allowed_groups=None,
535-
valid_username_regex=r"^[a-z][.a-z0-9_-]*$",
536-
lookup_dn=False,
537-
user_search_base=None,
538-
user_attribute=None,
539-
lookup_dn_search_filter="({login_attr}={login})",
540-
lookup_dn_search_user=None,
541-
lookup_dn_search_password=None,
542-
lookup_dn_user_dn_attribute=None,
543-
escape_userdn=False,
544-
search_filter="",
545-
attributes=None,
546-
auth_state_attributes=None,
547-
use_lookup_dn_username=True,
548-
confirmation_message="",
549-
):
550-
self.use_ssl = use_ssl
551-
self.use_tls = use_tls
552-
self.connect_timeout = connect_timeout
553-
self.receive_timeout = receive_timeout
554-
self.bind_dn_template = bind_dn_template
555-
self.allowed_groups = allowed_groups
556-
self.valid_username_regex = valid_username_regex
557-
self.lookup_dn = lookup_dn
558-
self.user_search_base = user_search_base
559-
self.user_attribute = user_attribute
560-
self.lookup_dn_search_filter = lookup_dn_search_filter
561-
self.lookup_dn_search_user = lookup_dn_search_user
562-
self.lookup_dn_search_password = lookup_dn_search_password
563-
self.lookup_dn_user_dn_attribute = lookup_dn_user_dn_attribute
564-
self.escape_userdn = escape_userdn
565-
self.search_filter = search_filter
566-
self.attributes = attributes if attributes else []
567-
self.auth_state_attributes = (
568-
auth_state_attributes if auth_state_attributes else []
569-
)
570-
self.use_lookup_dn_username = use_lookup_dn_username
571-
572-
if isinstance(server_address, str):
573-
server_address_list = [server_address]
574-
elif isinstance(server_address, Iterable):
575-
server_address_list = list(server_address)
576-
else:
577-
raise TypeError(
578-
f"Unsupported type of `server_address` (list): server_address={server_address} "
579-
f"type(server_address)={type(server_address)}"
580-
)
581-
if not server_address_list:
582-
raise ValueError(
583-
"No servers are specified: 'server_address' is an empty list"
584-
)
585-
586-
self.server_address_list = server_address_list
587-
self.server_port = (
588-
server_port if server_port is not None else self._server_port_default()
589-
)
590-
self.confirmation_message = confirmation_message
591-
592-
def _server_port_default(self):
593-
if self.use_ssl:
594-
return 636 # default SSL port for LDAP
595-
else:
596-
return 389 # default plaintext port for LDAP
474+
model_config = ConfigDict(use_attribute_docstrings=True)
475+
476+
server_address_list: Annotated[list[str], OneOrMany, Field(alias="server_address")]
477+
server_port: Annotated[
478+
int,
479+
Field(
480+
default_factory=lambda data: port
481+
if (port := data["server_port"] is not None)
482+
else (636 if data["use_ssl"] else 389)
483+
),
484+
]
485+
use_ssl: bool = False
486+
use_tls: bool = True
487+
connect_timeout: float = 5.0
488+
receive_timeout: float = 60.0
489+
bind_dn_template: Annotated[list[str], OneOrMany] = []
490+
allowed_groups: list[str] = []
491+
lookup_dn: bool = False
492+
user_search_base: Optional[str] = None
493+
user_attribute: Optional[str] = None
494+
lookup_dn_search_filter: Optional[str] = "({login_attr}={login})"
495+
lookup_dn_search_user: Optional[str] = None
496+
lookup_dn_search_password: Optional[str] = None
497+
lookup_dn_user_dn_attribute: Optional[str] = None
498+
escape_userdn: bool = False
499+
search_filter: str = ""
500+
attributes: list[str] = []
501+
auth_state_attributes: list[str] = []
502+
use_lookup_dn_username: bool = True
597503

598504
async def resolve_username(self, username_supplied_by_user):
599505
import ldap3

0 commit comments

Comments
 (0)