Skip to content

Commit 5355473

Browse files
rglaucoGiuseppe De Marco
and
Giuseppe De Marco
authored
feat: added method to show correct RP organization_name in OP pages (#305)
* feat: added get_client_organisation_name method to retrieve the correct RP name * chore: fix CIE organization_name * fix: updated cryptography rsa import to 42.0.2 * chore: bump to 1.3.1 * fix: corrected proposed change * fix: scope issue * Update spid_cie_oidc/provider/views/consent_page_view.py Co-authored-by: Giuseppe De Marco <[email protected]> * Update spid_cie_oidc/provider/views/__init__.py Co-authored-by: Giuseppe De Marco <[email protected]> * Update spid_cie_oidc/provider/views/authz_request_view.py Co-authored-by: Giuseppe De Marco <[email protected]> * fix: reinstated method name --------- Co-authored-by: Giuseppe De Marco <[email protected]>
1 parent 1faa95e commit 5355473

File tree

6 files changed

+56
-42
lines changed

6 files changed

+56
-42
lines changed

examples/provider/dumps/example.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
"metadata": {
148148
"federation_entity": {
149149
"federation_resolve_endpoint": "http://127.0.0.1:8002/oidc/op/resolve",
150-
"organization_name": "SPID OIDC identity provider",
150+
"organization_name": "CIE OIDC identity provider",
151151
"homepage_uri": "http://127.0.0.1:8002",
152152
"policy_uri": "http://127.0.0.1:8002/oidc/op/en/website/legal-information",
153153
"logo_uri": "http://127.0.0.1:8002/static/svg/logo-cie.svg",

spid_cie_oidc/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.3.0"
1+
__version__ = "1.3.1"

spid_cie_oidc/entity/jwks.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from cryptojwt.jwk.rsa import new_rsa_key
33
from cryptography.hazmat.primitives import serialization
44
from cryptojwt.jwk.rsa import RSAKey
5-
5+
from cryptography.hazmat.primitives.asymmetric import rsa
66

77
import cryptography
88
from django.conf import settings
@@ -64,9 +64,9 @@ def serialize_rsa_key(rsa_key, kind="public", hash_func="SHA-256"):
6464
cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey
6565
"""
6666
data = {}
67-
if isinstance(rsa_key, cryptography.hazmat.backends.openssl.rsa._RSAPublicKey):
67+
if isinstance(rsa_key, rsa.RSAPublicKey):
6868
data = {"pub_key": rsa_key}
69-
elif isinstance(rsa_key, cryptography.hazmat.backends.openssl.rsa._RSAPrivateKey):
69+
elif isinstance(rsa_key, rsa.RSAPrivateKey):
7070
data = {"priv_key": rsa_key}
7171
elif isinstance(rsa_key, (str, bytes)): # pragma: no cover
7272
if kind == "private":

spid_cie_oidc/provider/views/__init__.py

+49-31
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
OIDCFED_PROVIDER_PROFILES_ACR_4_REFRESH,
3333
OIDCFED_PROVIDER_PROFILES_ID_TOKEN_CLAIMS
3434
)
35+
3536
logger = logging.getLogger(__name__)
3637

3738

@@ -40,7 +41,7 @@ class OpBase:
4041
Baseclass with common methods for OPs
4142
"""
4243

43-
def redirect_response_data(self, redirect_uri:str, **kwargs) -> HttpResponseRedirect:
44+
def redirect_response_data(self, redirect_uri: str, **kwargs) -> HttpResponseRedirect:
4445
if "?" in redirect_uri:
4546
qstring = "&"
4647
else:
@@ -114,7 +115,7 @@ def validate_authz_request_object(self, req) -> TrustChain:
114115

115116
jwks = get_jwks(
116117
rp_trust_chain.metadata['openid_relying_party'],
117-
federation_jwks = rp_trust_chain.jwks
118+
federation_jwks=rp_trust_chain.jwks
118119
)
119120
jwk = self.find_jwk(header, jwks)
120121
if not jwk:
@@ -178,7 +179,7 @@ def check_session(self, request) -> OidcSession:
178179
)
179180

180181
session_not_after = session.created + timezone.timedelta(
181-
minutes = OIDCFED_PROVIDER_AUTH_CODE_MAX_AGE
182+
minutes=OIDCFED_PROVIDER_AUTH_CODE_MAX_AGE
182183
)
183184
if session_not_after < timezone.localtime():
184185
raise ExpiredAuthCode(
@@ -199,12 +200,12 @@ def check_client_assertion(self, client_id: str, client_assertion: str) -> bool:
199200
_op = self.get_issuer()
200201
_op_eid = _op.sub
201202
_op_eid_authz_endpoint = [_op.metadata['openid_provider']['authorization_endpoint']]
202-
203+
203204
try:
204205
ClientAssertion(**payload)
205206
except Exception as e:
206207
raise Exception(f"Client Assertion: json schema validation error: {e}")
207-
208+
208209
if isinstance(_aud, str):
209210
_aud = [_aud]
210211
_allowed_auds = _aud + _op_eid_authz_endpoint
@@ -250,9 +251,9 @@ def get_jwt_common_data(self):
250251
}
251252

252253
def get_access_token(
253-
self, iss_sub:str, sub:str, authz: OidcSession, commons:dict
254+
self, iss_sub: str, sub: str, authz: OidcSession, commons: dict
254255
) -> dict:
255-
256+
256257
access_token = {
257258
"iss": iss_sub,
258259
"sub": sub,
@@ -266,8 +267,8 @@ def get_access_token(
266267
return access_token
267268

268269
def get_id_token_claims(
269-
self,
270-
authz:OidcSession
270+
self,
271+
authz: OidcSession
271272
) -> dict:
272273
_provider_profile = getattr(settings, 'OIDCFED_DEFAULT_PROVIDER_PROFILE', OIDCFED_DEFAULT_PROVIDER_PROFILE)
273274
claims = {}
@@ -276,21 +277,21 @@ def get_id_token_claims(
276277
return claims
277278

278279
for claim in (
279-
authz.authz_request.get(
280-
"claims", {}
281-
).get("id_token", {}).keys()
280+
authz.authz_request.get(
281+
"claims", {}
282+
).get("id_token", {}).keys()
282283
):
283284
if claim in allowed_id_token_claims and authz.user.attributes.get(claim, None):
284285
claims[claim] = authz.user.attributes[claim]
285286
return claims
286287

287288
def get_id_token(
288-
self,
289-
iss_sub:str,
290-
sub:str,
291-
authz:OidcSession,
292-
jwt_at:str,
293-
commons:dict
289+
self,
290+
iss_sub: str,
291+
sub: str,
292+
authz: OidcSession,
293+
jwt_at: str,
294+
commons: dict
294295
) -> dict:
295296

296297
id_token = {
@@ -312,19 +313,19 @@ def get_id_token(
312313

313314
def get_refresh_token(
314315
self,
315-
iss_sub:str,
316-
sub:str,
317-
authz:OidcSession,
318-
jwt_at:str,
319-
commons:dict
316+
iss_sub: str,
317+
sub: str,
318+
authz: OidcSession,
319+
jwt_at: str,
320+
commons: dict
320321
) -> dict:
321322
# refresh token is scope offline_access and prompt == consent
322323
refresh_acrs = OIDCFED_PROVIDER_PROFILES_ACR_4_REFRESH[OIDCFED_DEFAULT_PROVIDER_PROFILE]
323324
acrs = authz.authz_request.get('acr_values', [])
324325
if (
325-
"offline_access" in authz.authz_request['scope'] and
326-
'consent' in authz.authz_request['prompt'] and
327-
set(refresh_acrs).intersection(set(acrs))
326+
"offline_access" in authz.authz_request['scope'] and
327+
'consent' in authz.authz_request['prompt'] and
328+
set(refresh_acrs).intersection(set(acrs))
328329
):
329330
refresh_token = {
330331
"sub": sub,
@@ -337,8 +338,8 @@ def get_refresh_token(
337338
refresh_token.update(commons)
338339
return refresh_token
339340

340-
def get_iss_token_data(self, session : OidcSession, issuer: FederationEntityConfiguration):
341-
_sub = session.pairwised_sub(provider_id = issuer.sub)
341+
def get_iss_token_data(self, session: OidcSession, issuer: FederationEntityConfiguration):
342+
_sub = session.pairwised_sub(provider_id=issuer.sub)
342343
iss_sub = issuer.sub
343344
commons = self.get_jwt_common_data()
344345
jwk = issuer.jwks_core[0]
@@ -363,7 +364,7 @@ def get_iss_token_data(self, session : OidcSession, issuer: FederationEntityConf
363364

364365
def get_expires_in(self, iat: int, exp: int):
365366
return timezone.timedelta(
366-
seconds = exp - iat
367+
seconds=exp - iat
367368
).seconds
368369

369370
def attributes_names_to_release(self, request, session: OidcSession) -> dict:
@@ -391,6 +392,23 @@ def attributes_names_to_release(self, request, session: OidcSession) -> dict:
391392
for i in filtered_user_claims.keys()
392393
]
393394
return dict(
394-
i18n_user_claims = i18n_user_claims,
395-
filtered_user_claims = filtered_user_claims
395+
i18n_user_claims=i18n_user_claims,
396+
filtered_user_claims=filtered_user_claims
396397
)
398+
399+
def get_client_organization_name(self, tc):
400+
rp_metadata = (
401+
tc.metadata.get(
402+
"federation_entity", {}
403+
) or
404+
tc.metadata.get(
405+
"openid_relying_party", {}
406+
)
407+
)
408+
if rp_metadata:
409+
name = (
410+
rp_metadata.get("organization_name", "") or
411+
rp_metadata.get("client_name", "") or
412+
rp_metadata.get("client_id", "")
413+
)
414+
return name

spid_cie_oidc/provider/views/authz_request_view.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,7 @@ def get(self, request, *args, **kwargs):
199199
# stores the authz request in a hidden field in the form
200200
form = self.get_login_form()()
201201
context = {
202-
"client_organization_name": tc.metadata.get(
203-
"client_name", self.payload["client_id"]
204-
),
202+
"client_organization_name": self.get_client_organization_name(tc),
205203
"hidden_form": AuthzHiddenForm(dict(authz_request_object=req)),
206204
"form": form,
207205
"redirect_uri": self.payload["redirect_uri"],

spid_cie_oidc/provider/views/consent_page_view.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ def get(self, request, *args, **kwargs):
5656
context = {
5757
"form": self.get_consent_form()(),
5858
"session": session,
59-
"client_organization_name": tc.metadata.get(
60-
"client_name", session.client_id
61-
),
59+
"client_organization_name": self.get_client_organization_name(tc),
6260
"user_claims": sorted(set(i18n_user_claims),),
6361
"redirect_uri": session.authz_request["redirect_uri"],
6462
"state": session.authz_request["state"]

0 commit comments

Comments
 (0)