Skip to content

Commit b35e3dc

Browse files
tpoliawdanielballan
authored andcommitted
Convert authenticators to pydantic models
1 parent 6b47006 commit b35e3dc

File tree

5 files changed

+114
-172
lines changed

5 files changed

+114
-172
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ tiled catalog upgrade-database [postgresql://.. | sqlite:///...]
120120
- Refactored internal Zarr version detection
121121
- For compatibility with older clients, do not require metadata updates to include
122122
an `access_blob` in the body of the request.
123+
- Convert authenticators into pydantic models
123124

124125
### Fixed
125126

tiled/_tests/test_authenticators.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@
2525
])
2626
# fmt: on
2727
@pytest.mark.parametrize("use_tls,use_ssl", [(False, False)])
28+
@pytest.mark.skipif(not TILED_TEST_LDAP, reason="Requires an LDAP container and TILED_TEST_LDAP to be set")
2829
def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server_port):
2930
"""
3031
Basic test for ``LDAPAuthenticator``.
3132
3233
TODO: The test could be extended with enabled TLS or SSL, but it requires configuration
3334
of the LDAP server.
3435
"""
35-
if not TILED_TEST_LDAP:
36-
pytest.skip("Run an LDAP container and set TILED_TEST_LDAP to run")
36+
3737
authenticator = LDAPAuthenticator(
38-
ldap_server_address,
39-
ldap_server_port,
38+
server_address=ldap_server_address,
4039
bind_dn_template="cn={username},ou=users,dc=example,dc=org",
41-
use_tls=use_tls,
4240
use_ssl=use_ssl,
41+
use_tls=use_tls,
42+
server_port=ldap_server_port
4343
)
4444

4545
async def testing():
@@ -49,3 +49,19 @@ async def testing():
4949
assert (await authenticator.authenticate("user02", "password2a")) is None
5050

5151
asyncio.run(testing())
52+
53+
54+
def test_ldap_port_validation():
55+
# given port can be none but will be replaced with a default
56+
auth = LDAPAuthenticator(server_address="http://ldap.example.com", server_port=None)
57+
assert auth.server_port is not None
58+
59+
60+
def test_auth_server_list_wrapping():
61+
auth = LDAPAuthenticator(server_address="http://ldap.example.com", server_port=None)
62+
assert auth.server_address_list == ["http://ldap.example.com"]
63+
64+
65+
def test_list_of_addresses_not_nested_into_extra_list():
66+
auth = LDAPAuthenticator(server_address=["http://ldap.example.com"])
67+
assert auth.server_address_list == ["http://ldap.example.com"]

tiled/authenticators.py

Lines changed: 72 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import re
66
import secrets
77
from collections.abc import Iterable
8-
from typing import Any, Mapping, Optional, cast
8+
from typing import Annotated, Any, Mapping, Optional, TypeVar, cast
99

1010
import httpx
1111
from fastapi import APIRouter, Request
1212
from jose import JWTError, jwt
13-
from pydantic import Secret
13+
from pydantic import BeforeValidator, ConfigDict, Field, Secret, model_validator
1414
from starlette.responses import RedirectResponse
1515

1616
from .server.protocols import (
@@ -32,9 +32,6 @@ class DummyAuthenticator(InternalAuthenticator):
3232
3333
"""
3434

35-
def __init__(self, confirmation_message: str = ""):
36-
self.confirmation_message = confirmation_message
37-
3835
async def authenticate(self, username: str, password: str) -> UserSessionState:
3936
return UserSessionState(username, {})
4037

@@ -46,31 +43,12 @@ class DictionaryAuthenticator(InternalAuthenticator):
4643
Check passwords from a dictionary of usernames mapped to passwords.
4744
"""
4845

49-
configuration_schema = """
50-
$schema": http://json-schema.org/draft-07/schema#
51-
type: object
52-
additionalProperties: false
53-
properties:
54-
users_to_password:
55-
type: object
56-
description: |
57-
Mapping usernames to password. Environment variable expansion should be
58-
used to avoid placing passwords directly in configuration.
59-
confirmation_message:
60-
type: string
61-
description: May be displayed by client after successful login.
62-
"""
63-
64-
def __init__(
65-
self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""
66-
):
67-
self._users_to_passwords = users_to_passwords
68-
self.confirmation_message = confirmation_message
46+
users_to_passwords: Mapping[str, str]
6947

7048
async def authenticate(
7149
self, username: str, password: str
7250
) -> Optional[UserSessionState]:
73-
true_password = self._users_to_passwords.get(username)
51+
true_password = self.users_to_passwords.get(username)
7452
if not true_password:
7553
# Username is not valid.
7654
return
@@ -79,26 +57,13 @@ async def authenticate(
7957

8058

8159
class PAMAuthenticator(InternalAuthenticator):
82-
configuration_schema = """
83-
$schema": http://json-schema.org/draft-07/schema#
84-
type: object
85-
additionalProperties: false
86-
properties:
87-
service:
88-
type: string
89-
description: PAM service. Default is 'login'.
90-
confirmation_message:
91-
type: string
92-
description: May be displayed by client after successful login.
93-
"""
94-
95-
def __init__(self, service: str = "login", confirmation_message: str = ""):
60+
service: str = "login"
61+
62+
def model_post_init(self, __context: Any):
9663
if not modules_available("pamela"):
9764
raise ModuleNotFoundError(
9865
"This PAMAuthenticator requires the module 'pamela' to be installed."
9966
)
100-
self.service = service
101-
self.confirmation_message = confirmation_message
10267
# TODO Try to open a PAM session.
10368

10469
async def authenticate(
@@ -115,47 +80,17 @@ async def authenticate(
11580

11681

11782
class OIDCAuthenticator(ExternalAuthenticator):
118-
configuration_schema = """
119-
$schema": http://json-schema.org/draft-07/schema#
120-
type: object
121-
additionalProperties: false
122-
properties:
123-
audience:
124-
type: string
125-
client_id:
126-
type: string
127-
client_secret:
128-
type: string
129-
well_known_uri:
130-
type: string
131-
confirmation_message:
132-
type: string
133-
"""
134-
135-
def __init__(
136-
self,
137-
audience: str,
138-
client_id: str,
139-
client_secret: str,
140-
well_known_uri: str,
141-
confirmation_message: str = "",
142-
):
143-
self._audience = audience
144-
self._client_id = client_id
145-
self._client_secret = Secret(client_secret)
146-
self._well_known_url = well_known_uri
147-
self.confirmation_message = confirmation_message
83+
audience: str
84+
client_id: str
85+
client_secret: Secret[str]
86+
well_known_uri: str
14887

14988
@functools.cached_property
15089
def _config_from_oidc_url(self) -> dict[str, Any]:
151-
response: httpx.Response = httpx.get(self._well_known_url)
90+
response: httpx.Response = httpx.get(self.well_known_url)
15291
response.raise_for_status()
15392
return response.json()
15493

155-
@functools.cached_property
156-
def client_id(self) -> str:
157-
return self._client_id
158-
15994
@functools.cached_property
16095
def id_token_signing_alg_values_supported(self) -> list[str]:
16196
return cast(
@@ -190,8 +125,8 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
190125
response = await exchange_code(
191126
self.token_endpoint,
192127
code,
193-
self._client_id,
194-
self._client_secret.get_secret_value(),
128+
self.client_id,
129+
self.client_secret.get_secret_value(),
195130
redirect_uri,
196131
)
197132
response_body = response.json()
@@ -207,7 +142,7 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
207142
token=id_token,
208143
key=keys,
209144
algorithms=self.id_token_signing_alg_values_supported,
210-
audience=self._audience,
145+
audience=self.audience,
211146
access_token=access_token,
212147
)
213148
except JWTError:
@@ -331,6 +266,18 @@ async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, st
331266
return rv
332267

333268

269+
T = TypeVar("T")
270+
271+
272+
def one_or_many(value: Any) -> list[Any]:
273+
if isinstance(value, str) or not isinstance(value, Iterable):
274+
return [value]
275+
return list(value)
276+
277+
278+
OneOrMany = BeforeValidator(one_or_many)
279+
280+
334281
class LDAPAuthenticator(InternalAuthenticator):
335282
"""
336283
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
@@ -357,10 +304,10 @@ class LDAPAuthenticator(InternalAuthenticator):
357304
Enable/disable TLS if ``use_ssl`` is False. By default TLS is enabled. It should not be disabled
358305
in production systems.
359306
360-
connect_timeout: float
307+
connect_timeout: int
361308
Timeout used for connecting to the LDAP server. Default: 5.
362309
363-
receive_timeout: float
310+
receive_timeout: int
364311
Timeout used for communication with the LDAP server, e.g. this timeout is used to wait for
365312
completion of 2FA. For smooth operation it should probably exceed timeout set at LDAP server.
366313
Default: 60.
@@ -389,7 +336,7 @@ class LDAPAuthenticator(InternalAuthenticator):
389336
"uid={username},ou=people,dc=wikimedia,dc=org",
390337
"uid={username},ou=Developers,dc=wikimedia,dc=org"
391338
]
392-
allowed_groups: list or None
339+
allowed_groups: list
393340
List of LDAP group DNs that users could be members of to be granted access.
394341
395342
If a user is in any one of the listed groups, then that user is granted access.
@@ -486,7 +433,10 @@ class LDAPAuthenticator(InternalAuthenticator):
486433
487434
from bluesky_httpserver.authenticators import LDAPAuthenticator
488435
authenticator = LDAPAuthenticator(
489-
"localhost", 1389, bind_dn_template="cn={username},ou=users,dc=example,dc=org", use_tls=False
436+
server_address="localhost",
437+
server_port=1389,
438+
bind_dn_template="cn={username},ou=users,dc=example,dc=org",
439+
use_tls=False
490440
)
491441
await authenticator.authenticate("user01", "password1")
492442
await authenticator.authenticate("user02", "password2")
@@ -514,79 +464,44 @@ class LDAPAuthenticator(InternalAuthenticator):
514464
id: user02
515465
"""
516466

517-
def __init__(
518-
self,
519-
server_address,
520-
server_port=None,
521-
*,
522-
use_ssl=False,
523-
use_tls=True,
524-
connect_timeout=5,
525-
receive_timeout=60,
526-
bind_dn_template=None,
527-
allowed_groups=None,
528-
valid_username_regex=r"^[a-z][.a-z0-9_-]*$",
529-
lookup_dn=False,
530-
user_search_base=None,
531-
user_attribute=None,
532-
lookup_dn_search_filter="({login_attr}={login})",
533-
lookup_dn_search_user=None,
534-
lookup_dn_search_password=None,
535-
lookup_dn_user_dn_attribute=None,
536-
escape_userdn=False,
537-
search_filter="",
538-
attributes=None,
539-
auth_state_attributes=None,
540-
use_lookup_dn_username=True,
541-
confirmation_message="",
542-
):
543-
self.use_ssl = use_ssl
544-
self.use_tls = use_tls
545-
self.connect_timeout = connect_timeout
546-
self.receive_timeout = receive_timeout
547-
self.bind_dn_template = bind_dn_template
548-
self.allowed_groups = allowed_groups
549-
self.valid_username_regex = valid_username_regex
550-
self.lookup_dn = lookup_dn
551-
self.user_search_base = user_search_base
552-
self.user_attribute = user_attribute
553-
self.lookup_dn_search_filter = lookup_dn_search_filter
554-
self.lookup_dn_search_user = lookup_dn_search_user
555-
self.lookup_dn_search_password = lookup_dn_search_password
556-
self.lookup_dn_user_dn_attribute = lookup_dn_user_dn_attribute
557-
self.escape_userdn = escape_userdn
558-
self.search_filter = search_filter
559-
self.attributes = attributes if attributes else []
560-
self.auth_state_attributes = (
561-
auth_state_attributes if auth_state_attributes else []
562-
)
563-
self.use_lookup_dn_username = use_lookup_dn_username
564-
565-
if isinstance(server_address, str):
566-
server_address_list = [server_address]
567-
elif isinstance(server_address, Iterable):
568-
server_address_list = list(server_address)
569-
else:
570-
raise TypeError(
571-
f"Unsupported type of `server_address` (list): server_address={server_address} "
572-
f"type(server_address)={type(server_address)}"
573-
)
574-
if not server_address_list:
575-
raise ValueError(
576-
"No servers are specified: 'server_address' is an empty list"
577-
)
578-
579-
self.server_address_list = server_address_list
580-
self.server_port = (
581-
server_port if server_port is not None else self._server_port_default()
582-
)
583-
self.confirmation_message = confirmation_message
584-
585-
def _server_port_default(self):
586-
if self.use_ssl:
587-
return 636 # default SSL port for LDAP
588-
else:
589-
return 389 # default plaintext port for LDAP
467+
model_config = ConfigDict(use_attribute_docstrings=True)
468+
469+
server_address_list: Annotated[list[str], OneOrMany, Field(alias="server_address")]
470+
server_port: Annotated[
471+
int,
472+
Field(
473+
default_factory=lambda data: port
474+
if (port := data.get("server_port") is not None)
475+
else (636 if data.get("use_ssl") else 389)
476+
),
477+
]
478+
use_ssl: bool = False
479+
use_tls: bool = True
480+
connect_timeout: int = 5
481+
receive_timeout: int = 60
482+
bind_dn_template: Annotated[list[str], OneOrMany] = []
483+
allowed_groups: list[str] = []
484+
valid_username_regex: str = r"^[a-z][.a-z0-9_-]*$"
485+
lookup_dn: bool = False
486+
user_search_base: Optional[str] = None
487+
user_attribute: Optional[str] = None
488+
lookup_dn_search_filter: Optional[str] = "({login_attr}={login})"
489+
lookup_dn_search_user: Optional[str] = None
490+
lookup_dn_search_password: Optional[str] = None
491+
lookup_dn_user_dn_attribute: Optional[str] = None
492+
escape_userdn: bool = False
493+
search_filter: str = ""
494+
attributes: list[str] = []
495+
auth_state_attributes: list[str] = []
496+
use_lookup_dn_username: bool = True
497+
498+
@model_validator(mode="before")
499+
@classmethod
500+
def ignore_nones(cls, data: Any) -> Any:
501+
if isinstance(data, dict):
502+
if data.get("server_port") is None:
503+
data.pop("server_port", None)
504+
return data
590505

591506
async def resolve_username(self, username_supplied_by_user):
592507
import ldap3

0 commit comments

Comments
 (0)