Skip to content

Commit 0c1688f

Browse files
authored
fix: mypy errors on custom models (#66)
* fix: mypy errors on custom models see the issue #63 while inheriting from models (Token, Client, AuthorizationCode), and using them in aioauth, mypy throws an error like: ``` error: Argument 1 of "create_authorization_code" is incompatible with supertype "BaseStorage"; supertype defines the argument type as "Request" ``` this PR fixes the above bug by adding additional `TToken`, `TClient` and `TAuthorizationCode` parameters to the `BaseModel` generic. usage example: ```python from dataclasses import dataclass from aioauth_fastapi.router import get_oauth2_router from aioauth.storage import BaseStorage from aioauth.requests import BaseRequest from aioauth.models import AuthorizationCode, Client, Token from aioauth.config import Settings from aioauth.server import AuthorizationServer from fastapi import FastAPI app = FastAPI() @DataClass class User: """Custom user model""" first_name: str last_name: str @DataClass class Request(BaseRequest[Query, Post, User]): """"Custom request""" class Storage(BaseStorage[Token, Client, AuthorizationCode, Request]): """ Storage methods must be implemented here. """ storage = Storage() authorization_server = AuthorizationServer[Request, Storage](storage) # NOTE: Redefinition of the default aioauth settings # INSECURE_TRANSPORT must be enabled for local development only! settings = Settings( INSECURE_TRANSPORT=True, ) # Include FastAPI router with oauth2 endpoints. app.include_router( get_oauth2_router(authorization_server, settings), prefix="/oauth2", tags=["oauth2"], ) ``` * enums were replaced to literals
1 parent da68dbf commit 0c1688f

21 files changed

+509
-448
lines changed

aioauth/collections.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class HTTPHeaderDict(UserDict):
3333
d['hElLo'] == 'world' # >>> True
3434
"""
3535

36-
def __setitem__(self, key, value):
36+
def __setitem__(self, key: str, value: str):
3737
super().__setitem__(key.lower(), value)
3838

39-
def __getitem__(self, key):
39+
def __getitem__(self, key: str):
4040
return super().__getitem__(key.lower())

aioauth/errors.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@
99
"""
1010

1111
from http import HTTPStatus
12-
from typing import Optional
12+
from typing import Generic, Optional
1313
from urllib.parse import urljoin
14+
from typing_extensions import Literal
1415

1516
from .collections import HTTPHeaderDict
1617
from .constances import default_headers
17-
from .requests import Request
18+
from .requests import TRequest
1819
from .types import ErrorType
1920

2021

21-
class OAuth2Error(Exception):
22+
class OAuth2Error(Exception, Generic[TRequest]):
2223
"""Base exception that all other exceptions inherit from."""
2324

2425
error: ErrorType
@@ -29,7 +30,7 @@ class OAuth2Error(Exception):
2930

3031
def __init__(
3132
self,
32-
request: Request,
33+
request: TRequest,
3334
description: Optional[str] = None,
3435
headers: Optional[HTTPHeaderDict] = None,
3536
):
@@ -47,28 +48,28 @@ def __init__(
4748
super().__init__(f"({self.error}) {self.description}")
4849

4950

50-
class MethodNotAllowedError(OAuth2Error):
51+
class MethodNotAllowedError(OAuth2Error[TRequest]):
5152
"""
5253
The request is valid, but the method trying to be accessed is not
5354
available to the resource owner.
5455
"""
5556

5657
description = "HTTP method is not allowed."
5758
status_code: HTTPStatus = HTTPStatus.METHOD_NOT_ALLOWED
58-
error = ErrorType.METHOD_IS_NOT_ALLOWED
59+
error: Literal["method_is_not_allowed"] = "method_is_not_allowed"
5960

6061

61-
class InvalidRequestError(OAuth2Error):
62+
class InvalidRequestError(OAuth2Error[TRequest]):
6263
"""
6364
The request is missing a required parameter, includes an invalid
6465
parameter value, includes a parameter more than once, or is
6566
otherwise malformed.
6667
"""
6768

68-
error = ErrorType.INVALID_REQUEST
69+
error: Literal["invalid_request"] = "invalid_request"
6970

7071

71-
class InvalidClientError(OAuth2Error):
72+
class InvalidClientError(OAuth2Error[TRequest]):
7273
"""
7374
Client authentication failed (e.g. unknown client, no client
7475
authentication included, or unsupported authentication method).
@@ -81,36 +82,36 @@ class InvalidClientError(OAuth2Error):
8182
client.
8283
"""
8384

84-
error = ErrorType.INVALID_CLIENT
85+
error: Literal["invalid_client"] = "invalid_client"
8586
status_code: HTTPStatus = HTTPStatus.UNAUTHORIZED
8687

8788

88-
class InsecureTransportError(OAuth2Error):
89+
class InsecureTransportError(OAuth2Error[TRequest]):
8990
"""An exception will be thrown if the current request is not secure."""
9091

9192
description = "OAuth 2 MUST utilize https."
92-
error = ErrorType.INSECURE_TRANSPORT
93+
error: Literal["insecure_transport"] = "insecure_transport"
9394

9495

95-
class UnsupportedGrantTypeError(OAuth2Error):
96+
class UnsupportedGrantTypeError(OAuth2Error[TRequest]):
9697
"""
9798
The authorization grant type is not supported by the authorization
9899
server.
99100
"""
100101

101-
error = ErrorType.UNSUPPORTED_GRANT_TYPE
102+
error: Literal["unsupported_grant_type"] = "unsupported_grant_type"
102103

103104

104-
class UnsupportedResponseTypeError(OAuth2Error):
105+
class UnsupportedResponseTypeError(OAuth2Error[TRequest]):
105106
"""
106107
The authorization server does not support obtaining an authorization
107108
code using this method.
108109
"""
109110

110-
error = ErrorType.UNSUPPORTED_RESPONSE_TYPE
111+
error: Literal["unsupported_response_type"] = "unsupported_response_type"
111112

112113

113-
class InvalidGrantError(OAuth2Error):
114+
class InvalidGrantError(OAuth2Error[TRequest]):
114115
"""
115116
The provided authorization grant (e.g. authorization code, resource
116117
owner credentials) or refresh token is invalid, expired, revoked, does
@@ -120,53 +121,53 @@ class InvalidGrantError(OAuth2Error):
120121
See `RFC6749 section 5.2 <https://tools.ietf.org/html/rfc6749#section-5.2>`_.
121122
"""
122123

123-
error = ErrorType.INVALID_GRANT
124+
error: Literal["invalid_grant"] = "invalid_grant"
124125

125126

126-
class MismatchingStateError(OAuth2Error):
127+
class MismatchingStateError(OAuth2Error[TRequest]):
127128
"""Unable to securely verify the integrity of the request and response."""
128129

129130
description = "CSRF Warning! State not equal in request and response."
130-
error = ErrorType.MISMATCHING_STATE
131+
error: Literal["mismatching_state"] = "mismatching_state"
131132

132133

133-
class UnauthorizedClientError(OAuth2Error):
134+
class UnauthorizedClientError(OAuth2Error[TRequest]):
134135
"""
135136
The authenticated client is not authorized to use this authorization
136137
grant type.
137138
"""
138139

139-
error = ErrorType.UNAUTHORIZED_CLIENT
140+
error: Literal["unauthorized_client"] = "unauthorized_client"
140141

141142

142-
class InvalidScopeError(OAuth2Error):
143+
class InvalidScopeError(OAuth2Error[TRequest]):
143144
"""
144145
The requested scope is invalid, unknown, or malformed, or
145146
exceeds the scope granted by the resource owner.
146147
147148
See `RFC6749 section 5.2 <https://tools.ietf.org/html/rfc6749#section-5.2>`_.
148149
"""
149150

150-
error = ErrorType.INVALID_SCOPE
151+
error: Literal["invalid_scope"] = "invalid_scope"
151152

152153

153-
class ServerError(OAuth2Error):
154+
class ServerError(OAuth2Error[TRequest]):
154155
"""
155156
The authorization server encountered an unexpected condition that
156157
prevented it from fulfilling the request. (This error code is needed
157158
because a ``HTTP 500`` (Internal Server Error) status code cannot be returned
158159
to the client via a HTTP redirect.)
159160
"""
160161

161-
error = ErrorType.SERVER_ERROR
162+
error: Literal["temporarily_unavailable"] = "temporarily_unavailable"
162163

163164

164-
class TemporarilyUnavailableError(OAuth2Error):
165+
class TemporarilyUnavailableError(OAuth2Error[TRequest]):
165166
"""
166167
The authorization server is currently unable to handle the request
167168
due to a temporary overloading or maintenance of the server.
168169
(This error code is needed because a ``HTTP 503`` (Service Unavailable)
169170
status code cannot be returned to the client via a HTTP redirect.)
170171
"""
171172

172-
error = ErrorType.TEMPORARILY_UNAVAILABLE
173+
error: Literal["temporarily_unavailable"] = "temporarily_unavailable"

aioauth/grant_type.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
88
----
99
"""
10+
from typing import Generic
1011
from .errors import (
1112
InvalidGrantError,
1213
InvalidRequestError,
@@ -15,22 +16,22 @@
1516
UnauthorizedClientError,
1617
)
1718
from .models import Client
18-
from .requests import Request
19+
from .requests import TRequest
1920
from .responses import TokenResponse
20-
from .storage import BaseStorage
21+
from .storage import TStorage
2122
from .utils import enforce_list, enforce_str, generate_token
2223

2324

24-
class GrantTypeBase:
25+
class GrantTypeBase(Generic[TRequest, TStorage]):
2526
"""Base grant type that all other grant types inherit from."""
2627

27-
def __init__(self, storage: BaseStorage, client_id: str, client_secret: str):
28+
def __init__(self, storage: TStorage, client_id: str, client_secret: str):
2829
self.storage = storage
2930
self.client_id = client_id
3031
self.client_secret = client_secret
3132

3233
async def create_token_response(
33-
self, request: Request, client: Client
34+
self, request: TRequest, client: Client
3435
) -> TokenResponse:
3536
"""Creates token response to reply to client."""
3637
token = await self.storage.create_token(
@@ -50,27 +51,27 @@ async def create_token_response(
5051
token_type=token.token_type,
5152
)
5253

53-
async def validate_request(self, request: Request) -> Client:
54+
async def validate_request(self, request: TRequest) -> Client:
5455
"""Validates the client request to ensure it is valid."""
5556
client = await self.storage.get_client(
5657
request, client_id=self.client_id, client_secret=self.client_secret
5758
)
5859

5960
if not client:
60-
raise InvalidRequestError(
61+
raise InvalidRequestError[TRequest](
6162
request=request, description="Invalid client_id parameter value."
6263
)
6364

6465
if not client.check_grant_type(request.post.grant_type):
65-
raise UnauthorizedClientError(request=request)
66+
raise UnauthorizedClientError[TRequest](request=request)
6667

6768
if not client.check_scope(request.post.scope):
68-
raise InvalidScopeError(request=request)
69+
raise InvalidScopeError[TRequest](request=request)
6970

7071
return client
7172

7273

73-
class AuthorizationCodeGrantType(GrantTypeBase):
74+
class AuthorizationCodeGrantType(GrantTypeBase[TRequest, TStorage]):
7475
"""
7576
The Authorization Code grant type is used by confidential and public
7677
clients to exchange an authorization code for an access token. After
@@ -86,21 +87,21 @@ class AuthorizationCodeGrantType(GrantTypeBase):
8687
See `RFC 6749 section 1.3.1 <https://tools.ietf.org/html/rfc6749#section-1.3.1>`_.
8788
"""
8889

89-
async def validate_request(self, request: Request) -> Client:
90+
async def validate_request(self, request: TRequest) -> Client:
9091
client = await super().validate_request(request)
9192

9293
if not request.post.redirect_uri:
93-
raise InvalidRequestError(
94+
raise InvalidRequestError[TRequest](
9495
request=request, description="Mismatching redirect URI."
9596
)
9697

9798
if not client.check_redirect_uri(request.post.redirect_uri):
98-
raise InvalidRequestError(
99+
raise InvalidRequestError[TRequest](
99100
request=request, description="Invalid redirect URI."
100101
)
101102

102103
if not request.post.code:
103-
raise InvalidRequestError(
104+
raise InvalidRequestError[TRequest](
104105
request=request, description="Missing code parameter."
105106
)
106107

@@ -109,30 +110,30 @@ async def validate_request(self, request: Request) -> Client:
109110
)
110111

111112
if not authorization_code:
112-
raise InvalidGrantError(request=request)
113+
raise InvalidGrantError[TRequest](request=request)
113114

114115
if (
115116
authorization_code.code_challenge
116117
and authorization_code.code_challenge_method
117118
):
118119
if not request.post.code_verifier:
119-
raise InvalidRequestError(
120+
raise InvalidRequestError[TRequest](
120121
request=request, description="Code verifier required."
121122
)
122123

123124
is_valid_code_challenge = authorization_code.check_code_challenge(
124125
request.post.code_verifier
125126
)
126127
if not is_valid_code_challenge:
127-
raise MismatchingStateError(request=request)
128+
raise MismatchingStateError[TRequest](request=request)
128129

129130
if authorization_code.is_expired:
130-
raise InvalidGrantError(request=request)
131+
raise InvalidGrantError[TRequest](request=request)
131132

132133
return client
133134

134135
async def create_token_response(
135-
self, request: Request, client: Client
136+
self, request: TRequest, client: Client
136137
) -> TokenResponse:
137138
token_response = await super().create_token_response(request, client)
138139

@@ -145,7 +146,7 @@ async def create_token_response(
145146
return token_response
146147

147148

148-
class PasswordGrantType(GrantTypeBase):
149+
class PasswordGrantType(GrantTypeBase[TRequest, TStorage]):
149150
"""
150151
The Password grant type is a way to exchange a user's credentials
151152
for an access token. Because the client application has to collect
@@ -156,25 +157,25 @@ class PasswordGrantType(GrantTypeBase):
156157
disallows the password grant entirely.
157158
"""
158159

159-
async def validate_request(self, request: Request) -> Client:
160+
async def validate_request(self, request: TRequest) -> Client:
160161
client = await super().validate_request(request)
161162

162163
if not request.post.username or not request.post.password:
163-
raise InvalidRequestError(
164+
raise InvalidRequestError[TRequest](
164165
request=request, description="Invalid credentials given."
165166
)
166167

167168
user = await self.storage.authenticate(request)
168169

169170
if not user:
170-
raise InvalidRequestError(
171+
raise InvalidRequestError[TRequest](
171172
request=request, description="Invalid credentials given."
172173
)
173174

174175
return client
175176

176177

177-
class RefreshTokenGrantType(GrantTypeBase):
178+
class RefreshTokenGrantType(GrantTypeBase[TRequest, TStorage]):
178179
"""
179180
The Refresh Token grant type is used by clients to exchange a
180181
refresh token for an access token when the access token has expired.
@@ -184,7 +185,7 @@ class RefreshTokenGrantType(GrantTypeBase):
184185
"""
185186

186187
async def create_token_response(
187-
self, request: Request, client: Client
188+
self, request: TRequest, client: Client
188189
) -> TokenResponse:
189190
"""Validate token request and create token response."""
190191
old_token = await self.storage.get_token(
@@ -194,7 +195,7 @@ async def create_token_response(
194195
)
195196

196197
if not old_token or old_token.revoked or old_token.refresh_token_expired:
197-
raise InvalidGrantError(request=request)
198+
raise InvalidGrantError[TRequest](request=request)
198199

199200
# Revoke old token
200201
await self.storage.revoke_token(
@@ -226,18 +227,18 @@ async def create_token_response(
226227
token_type=token.token_type,
227228
)
228229

229-
async def validate_request(self, request: Request) -> Client:
230+
async def validate_request(self, request: TRequest) -> Client:
230231
client = await super().validate_request(request)
231232

232233
if not request.post.refresh_token:
233-
raise InvalidRequestError(
234+
raise InvalidRequestError[TRequest](
234235
request=request, description="Missing refresh token parameter."
235236
)
236237

237238
return client
238239

239240

240-
class ClientCredentialsGrantType(GrantTypeBase):
241+
class ClientCredentialsGrantType(GrantTypeBase[TRequest, TStorage]):
241242
"""
242243
The Client Credentials grant type is used by clients to obtain an
243244
access token outside of the context of a user. This is typically

0 commit comments

Comments
 (0)