Skip to content

Commit c7bc2d7

Browse files
committed
checks: updated check component functions
1 parent 583c34d commit c7bc2d7

File tree

2 files changed

+253
-111
lines changed

2 files changed

+253
-111
lines changed

invenio_checks/api.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
from datetime import datetime
2+
3+
from flask import current_app, g
4+
from invenio_communities.proxies import current_communities
5+
from invenio_db import db
6+
from invenio_rdm_records.requests import CommunityInclusion, CommunitySubmission
7+
from invenio_requests.proxies import current_requests_service
8+
from invenio_search.engine import dsl
9+
from sqlalchemy import case
10+
11+
from .models import CheckConfig, CheckRun, CheckRunStatus
12+
from .proxies import current_checks_registry
13+
14+
15+
class CheckRunAPI:
    """Class for managing check runs.

    All methods are classmethods/staticmethods; the class only groups the
    helpers used to resolve, execute, persist and delete ``CheckRun`` rows.
    """

    @staticmethod
    def _parent_community_id(community_data):
        """Return the parent community ID stored in ``community_data``, if any.

        ``community_data`` is any mapping-like object exposing ``.get``. A
        missing or ``None`` ``"parent"`` entry yields ``None`` instead of
        raising.
        """
        return (community_data.get("parent") or {}).get("id")

    @staticmethod
    def _with_community_context(check_errors, community_id):
        """Return copies of ``check_errors`` tagged with their community.

        Each error's ``"context"`` key is set to
        ``{"community": community_id}``, replacing any existing value.
        """
        return [
            {**error, "context": {"community": community_id}}
            for error in check_errors
        ]

    @classmethod
    def delete_check_run(cls, record_uuid, is_draft):
        """Delete the check runs of a record for one draft/published state.

        Args:
            record_uuid: ID of the record whose check runs are deleted.
            is_draft: Only runs whose ``is_draft`` flag equals this value are
                deleted.

        Raises:
            Exception: Any error raised while deleting or committing is
                logged, the session is rolled back and the error re-raised.
        """
        try:
            # The bulk delete lives inside the ``try`` so that a failure while
            # issuing it is logged and rolled back just like a failed commit.
            CheckRun.query.filter_by(
                record_id=record_uuid, is_draft=is_draft
            ).delete()
            db.session.commit()
        except Exception:
            current_app.logger.exception(
                "Failed to delete check runs for record %s (is_draft=%s)",
                record_uuid,
                is_draft,
            )
            db.session.rollback()
            raise

    @classmethod
    def resolve_checks(cls, record_uuid, record, request, community=None):
        """Resolve the check runs of a draft/record for a community request.

        Args:
            record_uuid: ID of the record the check runs belong to.
            record: Object exposing ``data["is_draft"]`` and
                ``_record.has_draft``.
            request: Mapping describing the request; its ``"type"`` and
                ``"receiver"`` entries are read.
            community: Optional already-resolved community; when omitted it
                is read from the request's receiver using ``g.identity``.

        Returns:
            ``None`` when checks are disabled, the request is neither a
            community submission nor a community inclusion, the record or
            community cannot be determined, or no check is configured.
            Otherwise a list with at most one ``CheckRun`` per configuration,
            in configuration order.
        """
        if not current_app.config.get("CHECKS_ENABLED", False):
            return None

        request_type = request.get("type")
        supported_types = (CommunitySubmission.type_id, CommunityInclusion.type_id)
        if request_type not in supported_types:
            return None

        if not record_uuid:
            return None

        if not community:
            community_uuid = request.get("receiver", {}).get("community")
            if not community_uuid:
                return None
            community = current_communities.service.read(
                id_=community_uuid, identity=g.identity
            )

        # The parent community (when there is one) goes first so that its
        # configurations sort ahead of the community's own ones below.
        communities = []
        community_parent_id = cls._parent_community_id(community.to_dict())
        if community_parent_id:
            communities.append(community_parent_id)
        communities.append(community.id)

        check_configs = (
            CheckConfig.query.filter(CheckConfig.community_id.in_(communities))
            .order_by(
                case((CheckConfig.community_id == communities[0], 0), else_=1),
                CheckConfig.check_id,
            )
            .all()
        )
        if not check_configs:
            return None

        has_draft_run = record.data["is_draft"] or record._record.has_draft

        check_config_ids = [cfg.id for cfg in check_configs]
        check_runs = (
            CheckRun.query.filter(
                CheckRun.config_id.in_(check_config_ids),
                CheckRun.record_id == record_uuid,
                CheckRun.is_draft == has_draft_run,
            )
            # Newest first, so that ``setdefault`` below really keeps the
            # latest run per configuration instead of an arbitrary one.
            .order_by(CheckRun.start_time.desc())
            .all()
        )

        latest_checks = {}
        for run in check_runs:
            latest_checks.setdefault(run.config_id, run)

        return [
            latest_checks[cid] for cid in check_config_ids if cid in latest_checks
        ]

    @classmethod
    def get_community_ids(cls, record, identity):
        """Extract all relevant community IDs related to the record.

        Collects the communities (and their parents) from three places: the
        record's review request, its open community-inclusion requests and
        the communities it already belongs to.

        Args:
            record: Record/draft exposing ``parent.review``,
                ``parent.communities`` and ``pid.pid_value``.
            identity: Identity used to search requests and read communities.

        Returns:
            A set of community IDs.
        """
        community_ids = set()

        # Check draft review request
        if record.parent.review:
            community = record.parent.review.receiver.resolve()
            community_ids.add(str(community.id))
            community_parent_id = cls._parent_community_id(community)
            if community_parent_id:
                community_ids.add(community_parent_id)

        # Check inclusion requests
        results = current_requests_service.search(
            identity,
            extra_filter=dsl.query.Bool(
                "must",
                must=[
                    dsl.Q("term", **{"type": "community-inclusion"}),
                    dsl.Q("term", **{"topic.record": record.pid.pid_value}),
                    dsl.Q("term", **{"is_open": True}),
                ],
            ),
        )
        for result in results:
            community_id = result.get("receiver", {}).get("community")
            if not community_id:
                continue
            community_ids.add(community_id)
            # Read the community to find out whether it has a parent.
            community = current_communities.service.read(
                id_=community_id, identity=identity
            )
            community_parent_id = cls._parent_community_id(community.to_dict())
            if community_parent_id:
                community_ids.add(community_parent_id)

        # Check already included communities
        for community in record.parent.communities:
            community_ids.add(str(community.id))
            community_parent_id = cls._parent_community_id(community)
            if community_parent_id:
                community_ids.add(community_parent_id)

        return community_ids

    @classmethod
    def get_check_configs_from_communities(cls, community_ids):
        """Retrieve check configurations for the given community IDs.

        Args:
            community_ids: Iterable of community IDs to match against
                ``CheckConfig.community_id``.

        Returns:
            A list of matching ``CheckConfig`` rows.
        """
        return CheckConfig.query.filter(
            CheckConfig.community_id.in_(community_ids)
        ).all()

    @classmethod
    def run_checks(cls, identity, is_draft, record=None, errors=None, **kwargs):
        """Handler to run checks.

        Runs every check configured for the record's communities and stores
        one ``CheckRun`` per configuration (created or updated in place).
        A failure in one check is logged and does not stop the others.

        Args:
            identity: The identity of the user or system running the checks.
            is_draft: Value stored in, and used to look up, the ``is_draft``
                flag of the check runs.
            record: The record to run checks against.
            errors: Optional list to append any errors found. When ``None``
                the errors are only stored in the check run result.
            **kwargs: Ignored; accepted for call-site compatibility.
        """
        if not current_app.config.get("CHECKS_ENABLED", False):
            return

        community_ids = cls.get_community_ids(record, identity)
        all_check_configs = cls.get_check_configs_from_communities(community_ids)

        for check_config in all_check_configs:
            try:
                check_cls = current_checks_registry.get(check_config.check_id)
                start_time = datetime.utcnow()
                res = check_cls().run(record, check_config)
                if not res.sync:
                    continue

                # Callers such as ``publish`` pass no error container; the
                # result must still be persisted in that case.
                if errors is not None:
                    errors.extend(
                        cls._with_community_context(
                            res.errors, check_config.community_id
                        )
                    )

                latest_check = CheckRun.query.filter(
                    CheckRun.config_id == check_config.id,
                    CheckRun.record_id == record.id,
                    CheckRun.is_draft.is_(is_draft),
                ).first()

                if not latest_check:
                    latest_check = CheckRun(
                        config_id=check_config.id,
                        record_id=record.id,
                        is_draft=is_draft,
                        revision_id=record.revision_id,
                        start_time=start_time,
                        end_time=datetime.utcnow(),
                        status=CheckRunStatus.COMPLETED,
                        state="",
                        result=res.to_dict(),
                    )
                else:
                    latest_check.is_draft = is_draft
                    latest_check.revision_id = record.revision_id
                    latest_check.start_time = start_time
                    latest_check.end_time = datetime.utcnow()
                    latest_check.result = res.to_dict()

                db.session.add(latest_check)
                try:
                    db.session.commit()
                except Exception:
                    # A failed commit leaves the session unusable until it is
                    # rolled back; without this the remaining configurations
                    # would all fail as well.
                    db.session.rollback()
                    raise

            except Exception:
                current_app.logger.exception(
                    "Error running check on record",
                    extra={
                        "record_id": str(record.id),
                        "check_config_id": str(check_config.id),
                    },
                )

invenio_checks/components.py

Lines changed: 41 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -16,125 +16,55 @@
1616
from invenio_requests.proxies import current_requests_service
1717
from invenio_search.engine import dsl
1818

19+
from .api import CheckRunAPI
1920
from .models import CheckConfig, CheckRun, CheckRunStatus
2021
from .proxies import current_checks_registry
2122

2223

2324
class ChecksComponent(ServiceComponent):
    """Checks component.

    Service hooks that run the configured checks when a draft is created,
    updated or published, and that surface stored check results when a
    draft is read.
    """

    def read_draft(self, identity, draft=None, errors=None, **kwargs):
        """Populate ``errors`` with the stored check results of the draft.

        For every check configuration that applies to the draft's
        communities, the most recent check run is looked up and each of its
        errors that has a ``field`` is written into ``errors`` as a nested
        dict keyed by the dotted field path.

        Args:
            identity: Identity used to resolve the draft's communities.
            draft: The draft being read.
            errors: Dict that receives the nested error entries. Nothing is
                done when it is ``None``.
            **kwargs: Ignored.
        """
        # Same feature flag as ``CheckRunAPI.run_checks``; without a container
        # there is also nowhere to report the errors to.
        if errors is None or not current_app.config.get("CHECKS_ENABLED", False):
            return

        community_ids = CheckRunAPI.get_community_ids(draft, identity)
        check_configs = CheckRunAPI.get_check_configs_from_communities(community_ids)
        for config in check_configs:
            check_run = (
                CheckRun.query.filter(
                    CheckRun.config_id == config.id,
                    CheckRun.record_id == draft.id,
                )
                .order_by(CheckRun.start_time.desc())
                .first()
            )
            if not check_run:
                continue

            for error in check_run.result.get("errors", []):
                field = error.get("field")
                if not field:
                    continue

                # "a.b.c" -> errors["a"]["b"]["c"] = {...}
                *parents, leaf = field.split(".")
                current = errors
                for key in parents:
                    current = current.setdefault(key, {})

                current[leaf] = {
                    "context": {"community": str(config.community_id)},
                    "description": error.get("description", ""),
                    "message": error.get("messages", []),
                    "severity": error.get("severity", "error"),
                }

    def update_draft(self, identity, data=None, record=None, errors=None, **kwargs):
        """Run checks on draft update."""
        CheckRunAPI.run_checks(identity, is_draft=True, record=record, errors=errors)

    def create(self, identity, data=None, record=None, errors=None, **kwargs):
        """Run checks on draft create."""
        CheckRunAPI.run_checks(identity, is_draft=True, record=record, errors=errors)

    def publish(self, identity, draft=None, record=None, **kwargs):
        """Run checks on publish.

        The existing non-draft check runs of the record are deleted and the
        checks are run again against the published record.
        """
        # NOTE(review): both calls commit the DB session themselves rather
        # than going through the unit of work - confirm this is intended
        # during publish.
        CheckRunAPI.delete_check_run(record_uuid=record.id, is_draft=False)
        # ``run_checks`` extends the error container it is given; publish has
        # none to report into, so a throwaway list is passed to make sure the
        # check runs are still executed and stored.
        CheckRunAPI.run_checks(identity, is_draft=False, record=record, errors=[])

0 commit comments

Comments
 (0)