Skip to content

Commit

Permalink
authx: fix wrong serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ntarocco committed Sep 30, 2024
1 parent 3db6984 commit a30d5d2
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 57 deletions.
19 changes: 10 additions & 9 deletions invenio_cern_sync/authz/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def get_authz_token(self):
"instituteName", # "CERN"
"preferredCernLanguage", # "EN"
"orcid",
"postOfficeBox",
"primaryAccountEmail",
# "postOfficeBox", # currently missing, maybe added later
]


Expand Down Expand Up @@ -111,7 +111,8 @@ def _fetch_all(self, url, headers):
_url = f"{url}&offset={offset}"
resp = request_with_retries(url=_url, method="GET", headers=headers)
total = resp.json()["pagination"]["total"]
yield resp.json()["data"]
total = 500 # for testing purposes
yield from resp.json()["data"]
offset += self.limit

max_threads = os.cpu_count()
Expand All @@ -122,14 +123,14 @@ def _fetch_all(self, url, headers):
_url = f"{url}&offset={offset}"
futures.append(
executor.submit(
request_with_retries(url=_url, method="GET", headers=headers)
request_with_retries, url=_url, method="GET", headers=headers
)
)
offset += self.limit

for future in concurrent.futures.as_completed(futures):
resp = future.result()
yield resp.json()["data"]
yield from resp.json()["data"]

def get_identities(self, fields=IDENTITY_FIELDS):
"""Get all identities.
Expand All @@ -144,13 +145,13 @@ def get_identities(self, fields=IDENTITY_FIELDS):
"accept": "application/json",
}

query_params = [("field", value) for value in fields]
query_params += [
query_params = [
("limit", self.limit),
("filter", "type:Person"),
("filter", "source:cern"),
("filter", "activeUser:true"),
]
query_params += [("field", value) for value in fields]
query_string = urlencode(query_params)

url_without_offset = f"{self.base_url}/api/v1.0/Identity?{query_string}"
Expand All @@ -164,11 +165,11 @@ def get_groups(self, fields=GROUPS_FIELDS):
"accept": "application/json",
}

query_params = [("field", value) for value in fields]
query_params += [
query_params = [
("limit", self.limit),
]
query_params += [("field", value) for value in fields]
query_string = urlencode(query_params)

url_without_offset = f"{self.base_url}/api/v1.0/Groups?{query_string}"
url_without_offset = f"{self.base_url}/api/v1.0/Group?{query_string}"
return self._fetch_all(url_without_offset, headers)
12 changes: 6 additions & 6 deletions invenio_cern_sync/authz/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ def userprofile_mapper(cern_identity):
The returned dict structure must match the user profile schema defined via
the config ACCOUNTS_USER_PROFILE_SCHEMA."""
return dict(
affiliations=cern_identity["instituteName"],
cern_department=cern_identity["cernDepartment"],
cern_group=cern_identity["cernGroup"],
cern_section=cern_identity["cernSection"],
affiliations=cern_identity["instituteName"] or "",
department=cern_identity["cernDepartment"] or "",
family_name=cern_identity["lastName"],
full_name=cern_identity["displayName"],
given_name=cern_identity["firstName"],
mailbox=cern_identity.get("postOfficeBox", ""),
orcid=cern_identity.get("orcid", ""),
group=cern_identity["cernGroup"] or "",
mailbox=cern_identity["postOfficeBox"] or "",
orcid=cern_identity["orcid"] or "",
person_id=cern_identity["personId"],
section=cern_identity["cernSection"] or "",
)


Expand Down
4 changes: 3 additions & 1 deletion invenio_cern_sync/authz/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ def serialize_cern_identity(cern_identity):
raise InvalidCERNIdentity("personId", "unknown")

try:
# cern_identity.get("preferredCernLanguage") or "en" # value can be None
language = "en" # Invenio supports only English for now
serialized = dict(
email=cern_identity["primaryAccountEmail"].lower(),
username=cern_identity["upn"].lower(),
user_profile=userprofile_mapper(cern_identity),
preferences=dict(locale=cern_identity["preferredCernLanguage"].lower()),
preferences=dict(locale=language.lower()),
user_identity_id=person_id,
remote_account_extra_data=extra_data_mapper(cern_identity),
)
Expand Down
6 changes: 3 additions & 3 deletions invenio_cern_sync/ldap/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def userprofile_mapper(ldap_user):
the config ACCOUNTS_USER_PROFILE_SCHEMA."""
return dict(
affiliations=first_or_default(ldap_user, "cernInstituteName"),
cern_department=first_or_default(ldap_user, "division"),
cern_group=first_or_default(ldap_user, "cernGroup"),
cern_section=first_or_default(ldap_user, "cernSection"),
department=first_or_default(ldap_user, "division"),
family_name=first_or_default(ldap_user, "sn"),
full_name=first_or_default(ldap_user, "displayName"),
given_name=first_or_default(ldap_user, "givenName"),
group=first_or_default(ldap_user, "cernGroup"),
mailbox=first_or_default(ldap_user, "postOfficeBox"),
person_id=first_or_default(ldap_user, "employeeID"),
section=first_or_default(ldap_user, "cernSection"),
)


Expand Down
6 changes: 3 additions & 3 deletions invenio_cern_sync/ldap/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def serialize_ldap_user(ldap_user, userprofile_mapper=None, extra_data_mapper=No
raise InvalidLdapUser("employeeID", "unknown")

try:
# first_or_default(ldap_user, "preferredLanguage", "en").lower()
language = "en" # Invenio supports only English for now
serialized = dict(
email=first_or_raise(ldap_user, "mail").lower(),
username=first_or_raise(ldap_user, "cn").lower(),
user_profile=userprofile_mapper(ldap_user),
preferences=dict(
locale=first_or_default(ldap_user, "preferredLanguage", "en").lower()
),
preferences=dict(locale=language),
user_identity_id=person_id,
remote_account_extra_data=extra_data_mapper(ldap_user),
)
Expand Down
2 changes: 2 additions & 0 deletions invenio_cern_sync/sso/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class _Form(Form):
######################################################################################
# User handler


def cern_setup_handler(remote, token, resp):
"""Perform additional setup after the user has been logged in."""
token_user_info, _ = get_user_info(remote, resp)
Expand Down Expand Up @@ -100,6 +101,7 @@ def cern_info_serializer(remote, resp, token_user_info, user_info):
######################################################################################
# Groups handler


def cern_groups_handler(remote, resp):
"""Retrieves groups from remote account.
Expand Down
12 changes: 4 additions & 8 deletions invenio_cern_sync/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,19 @@

"""Invenio-CERN-sync tasks."""

from flask import current_app
from celery import shared_task
from flask import current_app
from invenio_db import db

from .users.sync import sync as users_sync
from .groups.sync import sync as groups_sync
from .users.sync import sync as users_sync


@shared_task
def sync_users(*args, **kwargs):
"""Task to sync users with CERN database."""
if current_app.config.get("DEBUG", True):
current_app.logger.warning(
"Users sync disabled, the DEBUG env var is True."
)
current_app.logger.warning("Users sync disabled, the DEBUG env var is True.")
return

try:
Expand All @@ -35,9 +33,7 @@ def sync_users(*args, **kwargs):
def sync_groups(*args, **kwargs):
"""Task to sync groups with CERN database."""
if current_app.config.get("DEBUG", True):
current_app.logger.warning(
"Groups sync disabled, the DEBUG env var is True."
)
current_app.logger.warning("Groups sync disabled, the DEBUG env var is True.")
return

try:
Expand Down
38 changes: 25 additions & 13 deletions invenio_cern_sync/users/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from invenio_oauthclient.models import RemoteAccount, UserIdentity

from invenio_cern_sync.sso import cern_remote_app_name
from invenio_cern_sync.utils import _is_different
from invenio_cern_sync.utils import is_different


def _create_user(cern_user):
Expand Down Expand Up @@ -88,25 +88,25 @@ def create_user(cern_user, auto_confirm=True):

def _update_user(user, cern_user):
"""Update User table, when necessary."""
user_changed = (
user_updated = (
user.email != cern_user["email"]
or user.username != cern_user["username"].lower()
)
if user_changed:
if user_updated:
user.email = cern_user["email"]
user.username = cern_user["username"]

# check if any key/value in CERN is different from the local user.user_profile
local_up = user.user_profile
cern_up = cern_user["user_profile"]
up_changed = _is_different(cern_up, local_up)
if up_changed:
up_updated = is_different(cern_up, local_up)
if up_updated:
user.user_profile = {**dict(user.user_profile), **cern_up}

# check if any key/value in CERN is different from the local user.preferences
local_prefs = user.preferences
cern_prefs = cern_user["preferences"]
prefs_changed = (
prefs_updated = (
len(
[
key
Expand All @@ -116,23 +116,28 @@ def _update_user(user, cern_user):
)
> 0
)
if prefs_changed:
if prefs_updated:
user.preferences = {**dict(user.preferences), **cern_prefs}

return user_updated or up_updated or prefs_updated


def _update_useridentity(user_id, user_identity, cern_user):
"""Update User profile col, when necessary."""
changed = (
updated = (
user_identity.id != cern_user["user_identity_id"]
or user_identity.id_user != user_id
)
if changed:
if updated:
user_identity.id = cern_user["user_identity_id"]
user_identity.id_user = user_id

return updated


def _update_remote_account(user, cern_user):
"""Update RemoteAccount table."""
updated = False
extra_data = cern_user["remote_account_extra_data"]
client_id = current_app.config["CERN_APP_CREDENTIALS"]["consumer_key"]
assert client_id
Expand All @@ -141,12 +146,19 @@ def _update_remote_account(user, cern_user):
if not remote_account:
# should probably never happen
RemoteAccount.create(user.id, client_id, extra_data)
elif _is_different(remote_account.extra_data, extra_data):
updated = True
elif is_different(extra_data, remote_account.extra_data):
remote_account.extra_data.update(**extra_data)
updated = True

return updated


def update_existing_user(local_user, local_user_identity, cern_user):
"""Update all user tables, when necessary."""
_update_user(local_user, cern_user)
_update_useridentity(local_user.id, local_user_identity, cern_user)
_update_remote_account(local_user, cern_user)
user_updated = _update_user(local_user, cern_user)
identity_updated = _update_useridentity(
local_user.id, local_user_identity, cern_user
)
remote_updated = _update_remote_account(local_user, cern_user)
return user_updated or identity_updated or remote_updated
4 changes: 2 additions & 2 deletions invenio_cern_sync/users/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def _update_existing(users, serializer_fn, log_uuid):
user.id == user_identity.id_user
), f"User and UserIdentity are not correctly linked for user #{user.id} and user_identity #{user_identity.id}"

update_existing_user(user, user_identity, invenio_user)
updated.add(user.id)
if update_existing_user(user, user_identity, invenio_user):
updated.add(user.id)

# persist changes before starting with the inserting of missing users
db.session.commit()
Expand Down
17 changes: 5 additions & 12 deletions invenio_cern_sync/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,8 @@ def first_or_default(d, key, default=""):
return default


def _is_different(dict1, dict2):
"""Return true if they differ."""
return (
len(
[
key
for key in dict1.keys() | dict2.keys()
if dict1.get(key) != dict2.get(key)
]
)
> 0
)
def is_different(new_dict, existing_dict):
"""Return True new_dict has new keys or updated values."""
for key, value in new_dict.items():
if key not in existing_dict or existing_dict[key] != value:
return True

0 comments on commit a30d5d2

Please sign in to comment.