From a30d5d2700cf9a88235b6d869d5b9c5dbf6a737d Mon Sep 17 00:00:00 2001 From: Nicola Tarocco Date: Tue, 1 Oct 2024 00:00:46 +0200 Subject: [PATCH] authx: fix wrong serialization --- invenio_cern_sync/authz/client.py | 19 +++++++------- invenio_cern_sync/authz/mapper.py | 12 ++++----- invenio_cern_sync/authz/serializer.py | 4 ++- invenio_cern_sync/ldap/mapper.py | 6 ++--- invenio_cern_sync/ldap/serializer.py | 6 ++--- invenio_cern_sync/sso/api.py | 2 ++ invenio_cern_sync/tasks.py | 12 +++------ invenio_cern_sync/users/api.py | 38 ++++++++++++++++++--------- invenio_cern_sync/users/sync.py | 4 +-- invenio_cern_sync/utils.py | 17 ++++-------- 10 files changed, 63 insertions(+), 57 deletions(-) diff --git a/invenio_cern_sync/authz/client.py b/invenio_cern_sync/authz/client.py index f539267..67f8ce5 100644 --- a/invenio_cern_sync/authz/client.py +++ b/invenio_cern_sync/authz/client.py @@ -80,8 +80,8 @@ def get_authz_token(self): "instituteName", # "CERN" "preferredCernLanguage", # "EN" "orcid", + "postOfficeBox", "primaryAccountEmail", - # "postOfficeBox", # currently missing, maybe added later ] @@ -111,7 +111,8 @@ def _fetch_all(self, url, headers): _url = f"{url}&offset={offset}" resp = request_with_retries(url=_url, method="GET", headers=headers) total = resp.json()["pagination"]["total"] - yield resp.json()["data"] + total = 500 # for testing purposes + yield from resp.json()["data"] offset += self.limit max_threads = os.cpu_count() @@ -122,14 +123,14 @@ def _fetch_all(self, url, headers): _url = f"{url}&offset={offset}" futures.append( executor.submit( - request_with_retries(url=_url, method="GET", headers=headers) + request_with_retries, url=_url, method="GET", headers=headers ) ) offset += self.limit for future in concurrent.futures.as_completed(futures): resp = future.result() - yield resp.json()["data"] + yield from resp.json()["data"] def get_identities(self, fields=IDENTITY_FIELDS): """Get all identities. 
@@ -144,13 +145,13 @@ def get_identities(self, fields=IDENTITY_FIELDS): "accept": "application/json", } - query_params = [("field", value) for value in fields] - query_params += [ + query_params = [ ("limit", self.limit), ("filter", "type:Person"), ("filter", "source:cern"), ("filter", "activeUser:true"), ] + query_params += [("field", value) for value in fields] query_string = urlencode(query_params) url_without_offset = f"{self.base_url}/api/v1.0/Identity?{query_string}" @@ -164,11 +165,11 @@ def get_groups(self, fields=GROUPS_FIELDS): "accept": "application/json", } - query_params = [("field", value) for value in fields] - query_params += [ + query_params = [ ("limit", self.limit), ] + query_params += [("field", value) for value in fields] query_string = urlencode(query_params) - url_without_offset = f"{self.base_url}/api/v1.0/Groups?{query_string}" + url_without_offset = f"{self.base_url}/api/v1.0/Group?{query_string}" return self._fetch_all(url_without_offset, headers) diff --git a/invenio_cern_sync/authz/mapper.py b/invenio_cern_sync/authz/mapper.py index fcbc8d0..8d2370d 100644 --- a/invenio_cern_sync/authz/mapper.py +++ b/invenio_cern_sync/authz/mapper.py @@ -14,16 +14,16 @@ def userprofile_mapper(cern_identity): The returned dict structure must match the user profile schema defined via the config ACCOUNTS_USER_PROFILE_SCHEMA.""" return dict( - affiliations=cern_identity["instituteName"], - cern_department=cern_identity["cernDepartment"], - cern_group=cern_identity["cernGroup"], - cern_section=cern_identity["cernSection"], + affiliations=cern_identity["instituteName"] or "", + department=cern_identity["cernDepartment"] or "", family_name=cern_identity["lastName"], full_name=cern_identity["displayName"], given_name=cern_identity["firstName"], - mailbox=cern_identity.get("postOfficeBox", ""), - orcid=cern_identity.get("orcid", ""), + group=cern_identity["cernGroup"] or "", + mailbox=cern_identity["postOfficeBox"] or "", + orcid=cern_identity["orcid"] or "", 
person_id=cern_identity["personId"], + section=cern_identity["cernSection"] or "", ) diff --git a/invenio_cern_sync/authz/serializer.py b/invenio_cern_sync/authz/serializer.py index 53a5e0c..faee16b 100644 --- a/invenio_cern_sync/authz/serializer.py +++ b/invenio_cern_sync/authz/serializer.py @@ -24,11 +24,13 @@ def serialize_cern_identity(cern_identity): raise InvalidCERNIdentity("personId", "unknown") try: + # cern_identity.get("preferredCernLanguage") or "en" # value can be None + language = "en" # Invenio supports only English for now serialized = dict( email=cern_identity["primaryAccountEmail"].lower(), username=cern_identity["upn"].lower(), user_profile=userprofile_mapper(cern_identity), - preferences=dict(locale=cern_identity["preferredCernLanguage"].lower()), + preferences=dict(locale=language.lower()), user_identity_id=person_id, remote_account_extra_data=extra_data_mapper(cern_identity), ) diff --git a/invenio_cern_sync/ldap/mapper.py b/invenio_cern_sync/ldap/mapper.py index a9a16b4..88a395d 100644 --- a/invenio_cern_sync/ldap/mapper.py +++ b/invenio_cern_sync/ldap/mapper.py @@ -17,14 +17,14 @@ def userprofile_mapper(ldap_user): the config ACCOUNTS_USER_PROFILE_SCHEMA.""" return dict( affiliations=first_or_default(ldap_user, "cernInstituteName"), - cern_department=first_or_default(ldap_user, "division"), - cern_group=first_or_default(ldap_user, "cernGroup"), - cern_section=first_or_default(ldap_user, "cernSection"), + department=first_or_default(ldap_user, "division"), family_name=first_or_default(ldap_user, "sn"), full_name=first_or_default(ldap_user, "displayName"), given_name=first_or_default(ldap_user, "givenName"), + group=first_or_default(ldap_user, "cernGroup"), mailbox=first_or_default(ldap_user, "postOfficeBox"), person_id=first_or_default(ldap_user, "employeeID"), + section=first_or_default(ldap_user, "cernSection"), ) diff --git a/invenio_cern_sync/ldap/serializer.py b/invenio_cern_sync/ldap/serializer.py index d372860..0e94266 100644 --- 
a/invenio_cern_sync/ldap/serializer.py +++ b/invenio_cern_sync/ldap/serializer.py @@ -29,13 +29,13 @@ def serialize_ldap_user(ldap_user, userprofile_mapper=None, extra_data_mapper=No raise InvalidLdapUser("employeeID", "unknown") try: + # first_or_default(ldap_user, "preferredLanguage", "en").lower() + language = "en" # Invenio supports only English for now serialized = dict( email=first_or_raise(ldap_user, "mail").lower(), username=first_or_raise(ldap_user, "cn").lower(), user_profile=userprofile_mapper(ldap_user), - preferences=dict( - locale=first_or_default(ldap_user, "preferredLanguage", "en").lower() - ), + preferences=dict(locale=language), user_identity_id=person_id, remote_account_extra_data=extra_data_mapper(ldap_user), ) diff --git a/invenio_cern_sync/sso/api.py b/invenio_cern_sync/sso/api.py index 0997e13..d9fd738 100644 --- a/invenio_cern_sync/sso/api.py +++ b/invenio_cern_sync/sso/api.py @@ -35,6 +35,7 @@ class _Form(Form): ###################################################################################### # User handler + def cern_setup_handler(remote, token, resp): """Perform additional setup after the user has been logged in.""" token_user_info, _ = get_user_info(remote, resp) @@ -100,6 +101,7 @@ def cern_info_serializer(remote, resp, token_user_info, user_info): ###################################################################################### # Groups handler + def cern_groups_handler(remote, resp): """Retrieves groups from remote account. 
diff --git a/invenio_cern_sync/tasks.py b/invenio_cern_sync/tasks.py index c54ae54..88e5196 100644 --- a/invenio_cern_sync/tasks.py +++ b/invenio_cern_sync/tasks.py @@ -7,21 +7,19 @@ """Invenio-CERN-sync tasks.""" -from flask import current_app from celery import shared_task +from flask import current_app from invenio_db import db -from .users.sync import sync as users_sync from .groups.sync import sync as groups_sync +from .users.sync import sync as users_sync @shared_task def sync_users(*args, **kwargs): """Task to sync users with CERN database.""" if current_app.config.get("DEBUG", True): - current_app.logger.warning( - "Users sync disabled, the DEBUG env var is True." - ) + current_app.logger.warning("Users sync disabled, the DEBUG env var is True.") return try: @@ -35,9 +33,7 @@ def sync_users(*args, **kwargs): def sync_groups(*args, **kwargs): """Task to sync groups with CERN database.""" if current_app.config.get("DEBUG", True): - current_app.logger.warning( - "Groups sync disabled, the DEBUG env var is True." 
- ) + current_app.logger.warning("Groups sync disabled, the DEBUG env var is True.") return try: diff --git a/invenio_cern_sync/users/api.py b/invenio_cern_sync/users/api.py index 4398bc0..0b02e19 100644 --- a/invenio_cern_sync/users/api.py +++ b/invenio_cern_sync/users/api.py @@ -14,7 +14,7 @@ from invenio_oauthclient.models import RemoteAccount, UserIdentity from invenio_cern_sync.sso import cern_remote_app_name -from invenio_cern_sync.utils import _is_different +from invenio_cern_sync.utils import is_different def _create_user(cern_user): @@ -88,25 +88,25 @@ def create_user(cern_user, auto_confirm=True): def _update_user(user, cern_user): """Update User table, when necessary.""" - user_changed = ( + user_updated = ( user.email != cern_user["email"] or user.username != cern_user["username"].lower() ) - if user_changed: + if user_updated: user.email = cern_user["email"] user.username = cern_user["username"] # check if any key/value in CERN is different from the local user.user_profile local_up = user.user_profile cern_up = cern_user["user_profile"] - up_changed = _is_different(cern_up, local_up) - if up_changed: + up_updated = is_different(cern_up, local_up) + if up_updated: user.user_profile = {**dict(user.user_profile), **cern_up} # check if any key/value in CERN is different from the local user.preferences local_prefs = user.preferences cern_prefs = cern_user["preferences"] - prefs_changed = ( + prefs_updated = ( len( [ key @@ -116,23 +116,28 @@ def _update_user(user, cern_user): ) > 0 ) - if prefs_changed: + if prefs_updated: user.preferences = {**dict(user.preferences), **cern_prefs} + return user_updated or up_updated or prefs_updated + def _update_useridentity(user_id, user_identity, cern_user): """Update User profile col, when necessary.""" - changed = ( + updated = ( user_identity.id != cern_user["user_identity_id"] or user_identity.id_user != user_id ) - if changed: + if updated: user_identity.id = cern_user["user_identity_id"] user_identity.id_user = 
user_id + return updated + def _update_remote_account(user, cern_user): """Update RemoteAccount table.""" + updated = False extra_data = cern_user["remote_account_extra_data"] client_id = current_app.config["CERN_APP_CREDENTIALS"]["consumer_key"] assert client_id @@ -141,12 +146,19 @@ def _update_remote_account(user, cern_user): if not remote_account: # should probably never happen RemoteAccount.create(user.id, client_id, extra_data) - elif _is_different(remote_account.extra_data, extra_data): + updated = True + elif is_different(extra_data, remote_account.extra_data): remote_account.extra_data.update(**extra_data) + updated = True + + return updated def update_existing_user(local_user, local_user_identity, cern_user): """Update all user tables, when necessary.""" - _update_user(local_user, cern_user) - _update_useridentity(local_user.id, local_user_identity, cern_user) - _update_remote_account(local_user, cern_user) + user_updated = _update_user(local_user, cern_user) + identity_updated = _update_useridentity( + local_user.id, local_user_identity, cern_user + ) + remote_updated = _update_remote_account(local_user, cern_user) + return user_updated or identity_updated or remote_updated diff --git a/invenio_cern_sync/users/sync.py b/invenio_cern_sync/users/sync.py index 5e25944..0d556fe 100644 --- a/invenio_cern_sync/users/sync.py +++ b/invenio_cern_sync/users/sync.py @@ -149,8 +149,8 @@ def _update_existing(users, serializer_fn, log_uuid): user.id == user_identity.id_user ), f"User and UserIdentity are not correctly linked for user #{user.id} and user_identity #{user_identity.id}" - update_existing_user(user, user_identity, invenio_user) - updated.add(user.id) + if update_existing_user(user, user_identity, invenio_user): + updated.add(user.id) # persist changes before starting with the inserting of missing users db.session.commit() diff --git a/invenio_cern_sync/utils.py b/invenio_cern_sync/utils.py index 92d70b5..cc40f63 100644 --- a/invenio_cern_sync/utils.py +++ 
b/invenio_cern_sync/utils.py @@ -21,15 +21,8 @@ def first_or_default(d, key, default=""): return default -def _is_different(dict1, dict2): - """Return true if they differ.""" - return ( - len( - [ - key - for key in dict1.keys() | dict2.keys() - if dict1.get(key) != dict2.get(key) - ] - ) - > 0 - ) +def is_different(new_dict, existing_dict): + """Return True if new_dict has new keys or updated values.""" + for key, value in new_dict.items(): + if key not in existing_dict or existing_dict[key] != value: + return True