|
| 1 | +from datetime import datetime |
| 2 | + |
| 3 | +from flask import current_app, g |
| 4 | +from invenio_communities.proxies import current_communities |
| 5 | +from invenio_db import db |
| 6 | +from invenio_rdm_records.requests import CommunityInclusion, CommunitySubmission |
| 7 | +from invenio_requests.proxies import current_requests_service |
| 8 | +from invenio_search.engine import dsl |
| 9 | +from sqlalchemy import case |
| 10 | + |
| 11 | +from .models import CheckConfig, CheckRun, CheckRunStatus |
| 12 | +from .proxies import current_checks_registry |
| 13 | + |
| 14 | + |
| 15 | +class CheckRunAPI: |
| 16 | + """Class for managing check runs.""" |
| 17 | + |
| 18 | + @classmethod |
| 19 | + def delete_check_run(cls, record_uuid, is_draft): |
| 20 | + """Delete all draft check runs for the record.""" |
| 21 | + CheckRun.query.filter_by(record_id=record_uuid, is_draft=is_draft).delete() |
| 22 | + try: |
| 23 | + db.session.commit() |
| 24 | + except Exception: |
| 25 | + current_app.logger.exception( |
| 26 | + "Failed to delete draft run for record %s", record_uuid |
| 27 | + ) |
| 28 | + db.session.rollback() |
| 29 | + raise |
| 30 | + |
| 31 | + @classmethod |
| 32 | + def resolve_checks(cls, record_uuid, record, request, community=None): |
| 33 | + """Resolve the checks for this draft/record related to the community and the request.""" |
| 34 | + enabled = current_app.config.get("CHECKS_ENABLED", False) |
| 35 | + if not enabled: |
| 36 | + return None |
| 37 | + |
| 38 | + request_type = request.get("type") |
| 39 | + is_draft_submission = request_type == CommunitySubmission.type_id |
| 40 | + is_record_inclusion = request_type == CommunityInclusion.type_id |
| 41 | + |
| 42 | + if not is_draft_submission and not is_record_inclusion: |
| 43 | + return None |
| 44 | + |
| 45 | + if not record_uuid: |
| 46 | + return None |
| 47 | + |
| 48 | + if not community: |
| 49 | + community_uuid = request.get("receiver", {}).get("community") |
| 50 | + if not community_uuid: |
| 51 | + return None |
| 52 | + community = current_communities.service.read( |
| 53 | + id_=community_uuid, identity=g.identity |
| 54 | + ) |
| 55 | + |
| 56 | + communities = [] |
| 57 | + community_parent_id = community.to_dict().get("parent", {}).get("id") |
| 58 | + if community_parent_id: |
| 59 | + communities.append(community_parent_id) |
| 60 | + communities.append(community.id) |
| 61 | + |
| 62 | + check_configs = ( |
| 63 | + CheckConfig.query.filter(CheckConfig.community_id.in_(communities)) |
| 64 | + .order_by( |
| 65 | + case((CheckConfig.community_id == communities[0], 0), else_=1), |
| 66 | + CheckConfig.check_id, |
| 67 | + ) |
| 68 | + .all() |
| 69 | + ) |
| 70 | + if not check_configs: |
| 71 | + return None |
| 72 | + |
| 73 | + has_draft_run = record.data["is_draft"] or record._record.has_draft |
| 74 | + |
| 75 | + check_config_ids = [cfg.id for cfg in check_configs] |
| 76 | + check_runs = CheckRun.query.filter( |
| 77 | + CheckRun.config_id.in_(check_config_ids), |
| 78 | + CheckRun.record_id == record_uuid, |
| 79 | + CheckRun.is_draft == has_draft_run, |
| 80 | + ).all() |
| 81 | + |
| 82 | + latest_checks = {} |
| 83 | + for run in check_runs: |
| 84 | + latest_checks.setdefault(run.config_id, run) |
| 85 | + |
| 86 | + return [latest_checks[cid] for cid in check_config_ids if cid in latest_checks] |
| 87 | + |
| 88 | + @classmethod |
| 89 | + def get_community_ids(cls , record, identity): |
| 90 | + """Extract all relevant community IDs related to the record.""" |
| 91 | + community_ids = set() |
| 92 | + |
| 93 | + # Check draft review request |
| 94 | + if record.parent.review: |
| 95 | + community = record.parent.review.receiver.resolve() |
| 96 | + community_ids.add(str(community.id)) |
| 97 | + community_parent_id = community.get("parent", {}).get("id") |
| 98 | + if community_parent_id: |
| 99 | + community_ids.add(community_parent_id) |
| 100 | + |
| 101 | + # Check inclusion requests |
| 102 | + results = current_requests_service.search( |
| 103 | + identity, |
| 104 | + extra_filter=dsl.query.Bool( |
| 105 | + "must", |
| 106 | + must=[ |
| 107 | + dsl.Q("term", **{"type": "community-inclusion"}), |
| 108 | + dsl.Q("term", **{"topic.record": record.pid.pid_value}), |
| 109 | + dsl.Q("term", **{"is_open": True}), |
| 110 | + ], |
| 111 | + ), |
| 112 | + ) |
| 113 | + for result in results: |
| 114 | + community_id = result.get("receiver", {}).get("community") |
| 115 | + if community_id: |
| 116 | + community_ids.add(community_id) |
| 117 | + community = current_communities.service.read( |
| 118 | + id_=community_id, identity=identity |
| 119 | + ) |
| 120 | + community_parent_id = community.to_dict().get("parent", {}).get("id") |
| 121 | + if community_parent_id: |
| 122 | + community_ids.add(community_parent_id) |
| 123 | + |
| 124 | + # Check already included communities |
| 125 | + for community in record.parent.communities: |
| 126 | + community_ids.add(str(community.id)) |
| 127 | + community_parent_id = community.get("parent", {}).get("id") |
| 128 | + if community_parent_id: |
| 129 | + community_ids.add(community_parent_id) |
| 130 | + |
| 131 | + return community_ids |
| 132 | + |
| 133 | + @classmethod |
| 134 | + def get_check_configs_from_communities(cls, community_ids): |
| 135 | + """Retrieve check configurations for the given community IDs.""" |
| 136 | + return CheckConfig.query.filter( |
| 137 | + CheckConfig.community_id.in_(community_ids) |
| 138 | + ).all() |
| 139 | + |
| 140 | + @classmethod |
| 141 | + def run_checks(cls, identity, is_draft, record=None, errors=None, **kwargs): |
| 142 | + """Handler to run checks. |
| 143 | +
|
| 144 | + Args: |
| 145 | + identity: The identity of the user or system running the checks. |
| 146 | + record: The record to run checks against. |
| 147 | + errors: A list to append any errors found. |
| 148 | + community_ids: A set of community IDs to consider for running checks. |
| 149 | + """ |
| 150 | + if not current_app.config.get("CHECKS_ENABLED", False): |
| 151 | + return |
| 152 | + |
| 153 | + community_ids = CheckRunAPI.get_community_ids(record, identity) |
| 154 | + |
| 155 | + all_check_configs = CheckRunAPI.get_check_configs_from_communities(community_ids) |
| 156 | + |
| 157 | + for check_config in all_check_configs: |
| 158 | + try: |
| 159 | + check_cls = current_checks_registry.get(check_config.check_id) |
| 160 | + start_time = datetime.utcnow() |
| 161 | + res = check_cls().run(record, check_config) |
| 162 | + if not res.sync: |
| 163 | + continue |
| 164 | + |
| 165 | + check_errors = [ |
| 166 | + { |
| 167 | + **error, |
| 168 | + "context": {"community": check_config.community_id}, |
| 169 | + } |
| 170 | + for error in res.errors |
| 171 | + ] |
| 172 | + errors.extend(check_errors) |
| 173 | + |
| 174 | + latest_check = ( |
| 175 | + CheckRun.query.filter( |
| 176 | + CheckRun.config_id == check_config.id, |
| 177 | + CheckRun.record_id == record.id, |
| 178 | + CheckRun.is_draft.is_(is_draft), |
| 179 | + ) |
| 180 | + .first() |
| 181 | + ) |
| 182 | + |
| 183 | + if not latest_check: |
| 184 | + latest_check = CheckRun( |
| 185 | + config_id=check_config.id, |
| 186 | + record_id=record.id, |
| 187 | + is_draft=is_draft, |
| 188 | + revision_id=record.revision_id, |
| 189 | + start_time=start_time, |
| 190 | + end_time=datetime.utcnow(), |
| 191 | + status=CheckRunStatus.COMPLETED, |
| 192 | + state="", |
| 193 | + result=res.to_dict(), |
| 194 | + ) |
| 195 | + else: |
| 196 | + latest_check.is_draft = is_draft |
| 197 | + latest_check.revision_id = record.revision_id |
| 198 | + latest_check.start_time = start_time |
| 199 | + latest_check.end_time = datetime.utcnow() |
| 200 | + latest_check.result = res.to_dict() |
| 201 | + |
| 202 | + db.session.add(latest_check) |
| 203 | + db.session.commit() |
| 204 | + |
| 205 | + except Exception: |
| 206 | + current_app.logger.exception( |
| 207 | + "Error running check on record", |
| 208 | + extra={ |
| 209 | + "record_id": str(record.id), |
| 210 | + "check_config_id": str(check_config.id), |
| 211 | + }, |
| 212 | + ) |
0 commit comments