55import re
66import secrets
77from collections .abc import Iterable
8- from typing import Any , Mapping , Optional , cast
8+ from typing import Annotated , Any , Mapping , Optional , TypeVar , cast
99
1010import httpx
1111from fastapi import APIRouter , Request
1212from jose import JWTError , jwt
13- from pydantic import Secret
13+ from pydantic import BeforeValidator , ConfigDict , Field , Secret , model_validator
1414from starlette .responses import RedirectResponse
1515
1616from .server .protocols import (
@@ -32,9 +32,6 @@ class DummyAuthenticator(InternalAuthenticator):
3232
3333 """
3434
35- def __init__ (self , confirmation_message : str = "" ):
36- self .confirmation_message = confirmation_message
37-
3835 async def authenticate (self , username : str , password : str ) -> UserSessionState :
3936 return UserSessionState (username , {})
4037
@@ -46,31 +43,12 @@ class DictionaryAuthenticator(InternalAuthenticator):
4643 Check passwords from a dictionary of usernames mapped to passwords.
4744 """
4845
49- configuration_schema = """
50- $schema": http://json-schema.org/draft-07/schema#
51- type: object
52- additionalProperties: false
53- properties:
54- users_to_password:
55- type: object
56- description: |
57- Mapping usernames to password. Environment variable expansion should be
58- used to avoid placing passwords directly in configuration.
59- confirmation_message:
60- type: string
61- description: May be displayed by client after successful login.
62- """
63-
64- def __init__ (
65- self , users_to_passwords : Mapping [str , str ], confirmation_message : str = ""
66- ):
67- self ._users_to_passwords = users_to_passwords
68- self .confirmation_message = confirmation_message
46+ users_to_passwords : Mapping [str , str ]
6947
7048 async def authenticate (
7149 self , username : str , password : str
7250 ) -> Optional [UserSessionState ]:
73- true_password = self ._users_to_passwords .get (username )
51+ true_password = self .users_to_passwords .get (username )
7452 if not true_password :
7553 # Username is not valid.
7654 return
@@ -79,26 +57,13 @@ async def authenticate(
7957
8058
8159class PAMAuthenticator (InternalAuthenticator ):
82- configuration_schema = """
83- $schema": http://json-schema.org/draft-07/schema#
84- type: object
85- additionalProperties: false
86- properties:
87- service:
88- type: string
89- description: PAM service. Default is 'login'.
90- confirmation_message:
91- type: string
92- description: May be displayed by client after successful login.
93- """
94-
95- def __init__ (self , service : str = "login" , confirmation_message : str = "" ):
60+ service : str = "login"
61+
62+ def model_post_init (self , __context : Any ):
9663 if not modules_available ("pamela" ):
9764 raise ModuleNotFoundError (
9865 "This PAMAuthenticator requires the module 'pamela' to be installed."
9966 )
100- self .service = service
101- self .confirmation_message = confirmation_message
10267 # TODO Try to open a PAM session.
10368
10469 async def authenticate (
@@ -115,47 +80,17 @@ async def authenticate(
11580
11681
11782class OIDCAuthenticator (ExternalAuthenticator ):
118- configuration_schema = """
119- $schema": http://json-schema.org/draft-07/schema#
120- type: object
121- additionalProperties: false
122- properties:
123- audience:
124- type: string
125- client_id:
126- type: string
127- client_secret:
128- type: string
129- well_known_uri:
130- type: string
131- confirmation_message:
132- type: string
133- """
134-
135- def __init__ (
136- self ,
137- audience : str ,
138- client_id : str ,
139- client_secret : str ,
140- well_known_uri : str ,
141- confirmation_message : str = "" ,
142- ):
143- self ._audience = audience
144- self ._client_id = client_id
145- self ._client_secret = Secret (client_secret )
146- self ._well_known_url = well_known_uri
147- self .confirmation_message = confirmation_message
83+ audience : str
84+ client_id : str
85+ client_secret : Secret [str ]
86+ well_known_uri : str
14887
14988 @functools .cached_property
15089 def _config_from_oidc_url (self ) -> dict [str , Any ]:
151- response : httpx .Response = httpx .get (self ._well_known_url )
90+ response : httpx .Response = httpx .get (self .well_known_url )
15291 response .raise_for_status ()
15392 return response .json ()
15493
155- @functools .cached_property
156- def client_id (self ) -> str :
157- return self ._client_id
158-
15994 @functools .cached_property
16095 def id_token_signing_alg_values_supported (self ) -> list [str ]:
16196 return cast (
@@ -190,8 +125,8 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
190125 response = await exchange_code (
191126 self .token_endpoint ,
192127 code ,
193- self ._client_id ,
194- self ._client_secret .get_secret_value (),
128+ self .client_id ,
129+ self .client_secret .get_secret_value (),
195130 redirect_uri ,
196131 )
197132 response_body = response .json ()
@@ -207,7 +142,7 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
207142 token = id_token ,
208143 key = keys ,
209144 algorithms = self .id_token_signing_alg_values_supported ,
210- audience = self ._audience ,
145+ audience = self .audience ,
211146 access_token = access_token ,
212147 )
213148 except JWTError :
@@ -331,6 +266,18 @@ async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, st
331266 return rv
332267
333268
269+ T = TypeVar ("T" )
270+
271+
272+ def one_or_many (value : Any ) -> list [Any ]:
273+ if isinstance (value , str ) or not isinstance (value , Iterable ):
274+ return [value ]
275+ return list (value )
276+
277+
278+ OneOrMany = BeforeValidator (one_or_many )
279+
280+
334281class LDAPAuthenticator (InternalAuthenticator ):
335282 """
336283 The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
@@ -357,10 +304,10 @@ class LDAPAuthenticator(InternalAuthenticator):
357304 Enable/disable TLS if ``use_ssl`` is False. By default TLS is enabled. It should not be disabled
358305 in production systems.
359306
360- connect_timeout: float
307+ connect_timeout: int
361308 Timeout used for connecting to the LDAP server. Default: 5.
362309
363- receive_timeout: float
310+ receive_timeout: int
364311 Timeout used for communication with the LDAP server, e.g. this timeout is used to wait for
365312 completion of 2FA. For smooth operation it should probably exceed timeout set at LDAP server.
366313 Default: 60.
@@ -389,7 +336,7 @@ class LDAPAuthenticator(InternalAuthenticator):
389336 "uid={username},ou=people,dc=wikimedia,dc=org",
390337 "uid={username},ou=Developers,dc=wikimedia,dc=org"
391338 ]
392- allowed_groups: list or None
339+ allowed_groups: list
393340 List of LDAP group DNs that users could be members of to be granted access.
394341
395342 If a user is in any one of the listed groups, then that user is granted access.
@@ -486,7 +433,10 @@ class LDAPAuthenticator(InternalAuthenticator):
486433
487434 from bluesky_httpserver.authenticators import LDAPAuthenticator
488435 authenticator = LDAPAuthenticator(
489- "localhost", 1389, bind_dn_template="cn={username},ou=users,dc=example,dc=org", use_tls=False
436+ server_address="localhost",
437+ server_port=1389,
438+ bind_dn_template="cn={username},ou=users,dc=example,dc=org",
439+ use_tls=False
490440 )
491441 await authenticator.authenticate("user01", "password1")
492442 await authenticator.authenticate("user02", "password2")
@@ -514,79 +464,44 @@ class LDAPAuthenticator(InternalAuthenticator):
514464 id: user02
515465 """
516466
517- def __init__ (
518- self ,
519- server_address ,
520- server_port = None ,
521- * ,
522- use_ssl = False ,
523- use_tls = True ,
524- connect_timeout = 5 ,
525- receive_timeout = 60 ,
526- bind_dn_template = None ,
527- allowed_groups = None ,
528- valid_username_regex = r"^[a-z][.a-z0-9_-]*$" ,
529- lookup_dn = False ,
530- user_search_base = None ,
531- user_attribute = None ,
532- lookup_dn_search_filter = "({login_attr}={login})" ,
533- lookup_dn_search_user = None ,
534- lookup_dn_search_password = None ,
535- lookup_dn_user_dn_attribute = None ,
536- escape_userdn = False ,
537- search_filter = "" ,
538- attributes = None ,
539- auth_state_attributes = None ,
540- use_lookup_dn_username = True ,
541- confirmation_message = "" ,
542- ):
543- self .use_ssl = use_ssl
544- self .use_tls = use_tls
545- self .connect_timeout = connect_timeout
546- self .receive_timeout = receive_timeout
547- self .bind_dn_template = bind_dn_template
548- self .allowed_groups = allowed_groups
549- self .valid_username_regex = valid_username_regex
550- self .lookup_dn = lookup_dn
551- self .user_search_base = user_search_base
552- self .user_attribute = user_attribute
553- self .lookup_dn_search_filter = lookup_dn_search_filter
554- self .lookup_dn_search_user = lookup_dn_search_user
555- self .lookup_dn_search_password = lookup_dn_search_password
556- self .lookup_dn_user_dn_attribute = lookup_dn_user_dn_attribute
557- self .escape_userdn = escape_userdn
558- self .search_filter = search_filter
559- self .attributes = attributes if attributes else []
560- self .auth_state_attributes = (
561- auth_state_attributes if auth_state_attributes else []
562- )
563- self .use_lookup_dn_username = use_lookup_dn_username
564-
565- if isinstance (server_address , str ):
566- server_address_list = [server_address ]
567- elif isinstance (server_address , Iterable ):
568- server_address_list = list (server_address )
569- else :
570- raise TypeError (
571- f"Unsupported type of `server_address` (list): server_address={ server_address } "
572- f"type(server_address)={ type (server_address )} "
573- )
574- if not server_address_list :
575- raise ValueError (
576- "No servers are specified: 'server_address' is an empty list"
577- )
578-
579- self .server_address_list = server_address_list
580- self .server_port = (
581- server_port if server_port is not None else self ._server_port_default ()
582- )
583- self .confirmation_message = confirmation_message
584-
585- def _server_port_default (self ):
586- if self .use_ssl :
587- return 636 # default SSL port for LDAP
588- else :
589- return 389 # default plaintext port for LDAP
467+ model_config = ConfigDict (use_attribute_docstrings = True )
468+
469+ server_address_list : Annotated [list [str ], OneOrMany , Field (alias = "server_address" )]
470+ server_port : Annotated [
471+ int ,
472+ Field (
473+ default_factory = lambda data : port
474+ if (port := data .get ("server_port" ) is not None )
475+ else (636 if data .get ("use_ssl" ) else 389 )
476+ ),
477+ ]
478+ use_ssl : bool = False
479+ use_tls : bool = True
480+ connect_timeout : int = 5
481+ receive_timeout : int = 60
482+ bind_dn_template : Annotated [list [str ], OneOrMany ] = []
483+ allowed_groups : list [str ] = []
484+ valid_username_regex : str = r"^[a-z][.a-z0-9_-]*$"
485+ lookup_dn : bool = False
486+ user_search_base : Optional [str ] = None
487+ user_attribute : Optional [str ] = None
488+ lookup_dn_search_filter : Optional [str ] = "({login_attr}={login})"
489+ lookup_dn_search_user : Optional [str ] = None
490+ lookup_dn_search_password : Optional [str ] = None
491+ lookup_dn_user_dn_attribute : Optional [str ] = None
492+ escape_userdn : bool = False
493+ search_filter : str = ""
494+ attributes : list [str ] = []
495+ auth_state_attributes : list [str ] = []
496+ use_lookup_dn_username : bool = True
497+
498+ @model_validator (mode = "before" )
499+ @classmethod
500+ def ignore_nones (cls , data : Any ) -> Any :
501+ if isinstance (data , dict ):
502+ if data .get ("server_port" ) is None :
503+ data .pop ("server_port" , None )
504+ return data
590505
591506 async def resolve_username (self , username_supplied_by_user ):
592507 import ldap3
0 commit comments