Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor storage methods to use keyword-only arguments for clarity #114

Merged
merged 1 commit into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 55 additions & 84 deletions aioauth/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,92 +10,25 @@
----
"""

import sys
from typing import TYPE_CHECKING, Optional, Generic
from typing import Optional, Generic

from .models import AuthorizationCode, Client, Token
from .types import CodeChallengeMethod, TokenType

from .requests import Request
from .types import UserType

if sys.version_info >= (3, 11):
from typing import NotRequired, Unpack
else:
from typing_extensions import NotRequired, Unpack

from typing import TypedDict as _TypedDict

# NOTE: workaround for generic TypedDict support
# https://github.com/python/cpython/issues/89026
if TYPE_CHECKING:

class TypedDict(Generic[UserType], _TypedDict): ...

else:

class TypedDict(Generic[UserType]): ...


class GetAuthorizationCodeArgs(TypedDict[UserType]):
request: Request[UserType]
client_id: str
code: str


class GetClientArgs(TypedDict[UserType]):
request: Request[UserType]
client_id: str
client_secret: NotRequired[Optional[str]]


class GetIdTokenArgs(TypedDict[UserType]):
request: Request[UserType]
client_id: str
scope: str
response_type: Optional[str]
redirect_uri: str
nonce: Optional[str]


class CreateAuthorizationCodeArgs(TypedDict[UserType]):
request: Request[UserType]
client_id: str
scope: str
response_type: str
redirect_uri: str
code_challenge_method: Optional[CodeChallengeMethod]
code_challenge: Optional[str]
code: str
nonce: NotRequired[Optional[str]]


class CreateTokenArgs(TypedDict[UserType]):
request: Request[UserType]
client_id: str
scope: str
access_token: str
refresh_token: Optional[str]


class GetTokenArgs(TypedDict[UserType]):
request: Request[UserType]
client_id: str
token_type: Optional[TokenType] # default is "refresh_token"
access_token: Optional[str] # default is None
refresh_token: Optional[str] # default is None


class RevokeTokenArgs(TypedDict[UserType]):
request: Request[UserType]
client_id: str
refresh_token: Optional[str]
token_type: Optional[TokenType]
access_token: Optional[str]


class TokenStorage(Generic[UserType]):
async def create_token(self, **kwargs: Unpack[CreateTokenArgs[UserType]]) -> Token:
async def create_token(
self,
*,
request: Request[UserType],
client_id: str,
scope: str,
access_token: str,
refresh_token: Optional[str] = None,
) -> Token:
"""Generates a user token and stores it in the database.

Used by:
Expand All @@ -120,7 +53,13 @@ async def create_token(self, **kwargs: Unpack[CreateTokenArgs[UserType]]) -> Tok
raise NotImplementedError("Method create_token must be implemented")

async def get_token(
self, **kwargs: Unpack[GetTokenArgs[UserType]]
self,
*,
request: Request[UserType],
client_id: str,
token_type: Optional[TokenType] = None,
access_token: Optional[str] = None,
refresh_token: Optional[str] = None,
) -> Optional[Token]:
"""Gets existing token from the database.

Expand All @@ -138,15 +77,32 @@ async def get_token(
"""
raise NotImplementedError("Method get_token must be implemented")

async def revoke_token(self, **kwargs: Unpack[RevokeTokenArgs[UserType]]) -> None:
async def revoke_token(
self,
*,
request: Request[UserType],
client_id: str,
refresh_token: Optional[str] = None,
token_type: Optional[TokenType] = None,
access_token: Optional[str] = None,
) -> None:
"""Revokes a token from the database."""
raise NotImplementedError


class AuthorizationCodeStorage(Generic[UserType]):
async def create_authorization_code(
self,
**kwargs: Unpack[CreateAuthorizationCodeArgs[UserType]],
*,
request: Request[UserType],
client_id: str,
scope: str,
response_type: str,
redirect_uri: str,
code: str,
code_challenge_method: Optional[CodeChallengeMethod] = None,
code_challenge: Optional[str] = None,
nonce: Optional[str] = None,
) -> AuthorizationCode:
"""Generates an authorization token and stores it in the database.

Expand All @@ -172,7 +128,10 @@ async def create_authorization_code(

async def get_authorization_code(
self,
**kwargs: Unpack[GetAuthorizationCodeArgs[UserType]],
*,
request: Request[UserType],
client_id: str,
code: str,
) -> Optional[AuthorizationCode]:
"""Gets existing authorization code from the database if it exists.

Expand All @@ -196,7 +155,10 @@ async def get_authorization_code(

async def delete_authorization_code(
self,
**kwargs: Unpack[GetAuthorizationCodeArgs[UserType]],
*,
request: Request[UserType],
client_id: str,
code: str,
) -> None:
"""Deletes authorization code from database.

Expand All @@ -216,7 +178,10 @@ async def delete_authorization_code(
class ClientStorage(Generic[UserType]):
async def get_client(
self,
**kwargs: Unpack[GetClientArgs[UserType]],
*,
request: Request[UserType],
client_id: str,
client_secret: Optional[str] = None,
) -> Optional[Client[UserType]]:
"""Gets existing client from the database if it exists.

Expand Down Expand Up @@ -256,7 +221,13 @@ async def get_user(self, request: Request[UserType]) -> Optional[UserType]:
class IDTokenStorage(Generic[UserType]):
async def get_id_token(
self,
**kwargs: Unpack[GetIdTokenArgs[UserType]],
*,
request: Request[UserType],
client_id: str,
scope: str,
redirect_uri: str,
response_type: Optional[str] = None,
nonce: Optional[str] = None,
) -> str:
"""Returns an id_token.
For more information see `OpenID Connect Core 1.0 incorporating errata set 1 section 2 <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>`_.
Expand Down
25 changes: 18 additions & 7 deletions examples/shared/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, session: AsyncSession):

async def get_client(
self,
*,
request: Request[User],
client_id: str,
client_secret: Optional[str] = None,
Expand Down Expand Up @@ -61,15 +62,16 @@ def __init__(self, session: AsyncSession):

async def create_authorization_code(
self,
*,
request: Request[User],
client_id: str,
scope: str,
response_type: str,
redirect_uri: str,
code_challenge_method: Optional[CodeChallengeMethod],
code_challenge: Optional[str],
code: str,
**kwargs,
code_challenge_method: Optional[CodeChallengeMethod] = None,
code_challenge: Optional[str] = None,
nonce: Optional[str] = None,
) -> AuthorizationCode:
""""""
auth_code = AuthorizationCode(
Expand All @@ -83,7 +85,6 @@ async def create_authorization_code(
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
user=request.user,
**kwargs,
)
record = AuthCodeTable(
code=auth_code.code,
Expand All @@ -104,7 +105,11 @@ async def create_authorization_code(
return auth_code

async def get_authorization_code(
self, request: Request[User], client_id: str, code: str
self,
*,
request: Request[User],
client_id: str,
code: str,
) -> Optional[AuthorizationCode]:
""" """
async with self.session:
Expand All @@ -125,7 +130,11 @@ async def get_authorization_code(
)

async def delete_authorization_code(
self, request: Request[User], client_id: str, code: str
self,
*,
request: Request[User],
client_id: str,
code: str,
) -> None:
""" """
async with self.session:
Expand All @@ -147,7 +156,7 @@ async def create_token(
client_id: str,
scope: str,
access_token: str,
refresh_token: Optional[str],
refresh_token: Optional[str] = None,
) -> Token:
""" """
token = Token(
Expand Down Expand Up @@ -179,6 +188,7 @@ async def create_token(

async def get_token(
self,
*,
request: Request[User],
client_id: str,
token_type: Optional[TokenType] = "refresh_token",
Expand Down Expand Up @@ -208,6 +218,7 @@ async def get_token(

async def revoke_token(
self,
*,
request: Request[User],
client_id: str,
token_type: Optional[TokenType] = "refresh_token",
Expand Down
Loading
Loading