From 67a26fdf0f9a0c4df3052cf816c60edb2e608af4 Mon Sep 17 00:00:00 2001 From: ChristopherSpelt Date: Fri, 15 Nov 2024 13:28:21 +0100 Subject: [PATCH] Make things faster --- .vscode/launch.json | 4 +- :w | 63 ---------------------- amt/api/routes/algorithm.py | 29 ++++++---- amt/repositories/task_registry.py | 43 +++++++++------ amt/services/task_registry.py | 52 +----------------- tests/clients/test_clients.py | 1 + tests/services/test_instruments_service.py | 34 +++++++----- tests/services/test_instruments_state.py | 13 ----- 8 files changed, 72 insertions(+), 167 deletions(-) delete mode 100644 :w diff --git a/.vscode/launch.json b/.vscode/launch.json index e78cc94f..18d50b42 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -21,7 +21,9 @@ "DEBUG": "True", "AUTO_CREATE_SCHEMA": "True", "ENVIRONMENT": "local", - "LOGGING_LEVEL": "DEBUG" + "LOGGING_LEVEL": "DEBUG", + "OIDC_CLIENT_SECRET": "uIeFiKFazNEIbpJ3wzj0lZLLSJXefeld", + "OIDC_CLIENT_ID": "AMT" } }, { diff --git a/:w b/:w deleted file mode 100644 index 9519ce8a..00000000 --- a/:w +++ /dev/null @@ -1,63 +0,0 @@ -import logging -from collections.abc import Sequence - -from amt.clients.clients import TaskRegistryAPIClient, TaskType -from amt.schema.requirement import Requirement - -logger = logging.getLogger(__name__) - - -class RequirementsService: - def __init__(self, repository: TaskRegistryRepository) -> None: - self.repository = repository - - def fetch_measures(self, urns: str | Sequence[str] | None = None) -> list[Requirement]: - """ - Fetches measures with the given URNs. - If urns contains an URN that is not a valid URN of an measure, it is simply ignored. - @param urns: URNs of instruments to fetch. If None, function returns all measures. - @return: List of measures with the given URNs in 'urns'. - """ - task_data = self.repository.fetch_tasks(TaskType.REQUIREMENTS, urns) - return [Requirement(**data) for data in task_data] - - -def create_requirements_service() -> RequirementsService: - client = TaskRegistryAPIClient() - repository = TaskRegistryRepository(client) - return RequirementsService(repository) - - -class RequirementsService: - def __init__(self) -> None: - self.client = TaskRegistryAPIClient() - - def fetch_requirements(self, urns: str | Sequence[str] | None = None) -> list[Requirement]: - """ - This functions returns requirement with given URN's. If urns contains an URN that is not a - valid URN of an requirement it is simply ignored. - - @param: URN's of requirements to fetch. If empty, function returns all requirements. - @return: List of requirements with given URN's in 'urns'. - """ - - if isinstance(urns, str): - urns = [urns] - - all_valid_urns = self.fetch_urns() - - if urns is not None: - return [ - Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn)) - for urn in urns - if urn in all_valid_urns - ] - - return [Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn)) for urn in all_valid_urns] - - def fetch_urns(self) -> list[str]: - """ - Fetches all valid requirement URN's. - """ - content_list = self.client.get_list_of_task(TaskType.REQUIREMENTS) - return [content.urn for content in content_list.root] diff --git a/amt/api/routes/algorithm.py b/amt/api/routes/algorithm.py index 96bd6a23..ad90780c 100644 --- a/amt/api/routes/algorithm.py +++ b/amt/api/routes/algorithm.py @@ -23,10 +23,10 @@ from amt.schema.requirement import RequirementTask from amt.schema.system_card import SystemCard from amt.schema.task import MovedTask -from amt.services import task_registry from amt.services.algorithms import AlgorithmsService from amt.services.instruments_and_requirements_state import InstrumentStateService, RequirementsStateService -from amt.services.task_registry import fetch_measures, fetch_requirements +from amt.services.measures import MeasuresService, create_measures_service +from amt.services.requirements import RequirementsService, create_requirements_service from amt.services.tasks import TasksService router = APIRouter() @@ -44,7 +44,10 @@ async def get_instrument_state(system_card: SystemCard) -> dict[str, Any]: async def get_requirements_state(system_card: SystemCard) -> dict[str, Any]: - requirements = await fetch_requirements([requirement.urn for requirement in system_card.requirements]) + requirements_service = create_requirements_service() + requirements = await requirements_service.fetch_requirements( + [requirement.urn for requirement in system_card.requirements] + ) requirements_state_service = RequirementsStateService(system_card) requirements_state = requirements_state_service.get_requirements_state(requirements) @@ -350,6 +353,8 @@ async def get_system_card_requirements( request: Request, algorithm_id: int, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], + requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)], + measures_service: Annotated[MeasuresService, Depends(create_measures_service)], ) -> HTMLResponse: algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request) instrument_state = await get_instrument_state(algorithm.system_card) @@ -367,13 +372,15 @@ async def get_system_card_requirements( request, ) - requirements = await fetch_requirements([requirement.urn for requirement in algorithm.system_card.requirements]) + requirements = await requirements_service.fetch_requirements( + [requirement.urn for requirement in algorithm.system_card.requirements] + ) # Get measures that correspond to the requirements and merge them with the measuretasks requirements_and_measures = [] for requirement in requirements: completed_measures_count = 0 - linked_measures = await fetch_measures(requirement.links) + linked_measures = await measures_service.fetch_measures(requirement.links) extended_linked_measures: list[ExtendedMeasureTask] = [] for measure in linked_measures: measure_task = find_measure_task(algorithm.system_card, measure.urn) @@ -424,10 +431,12 @@ async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure requirement_mapper[requirement_task.urn] = requirement_task requirement_tasks: list[RequirementTask] = [] - measure = await fetch_measures([measure_urn]) + measures_service = create_measures_service() + requirements_service = create_requirements_service() + measure = await measures_service.fetch_measures(measure_urn) for requirement_urn in measure[0].links: # TODO: This is because measure are linked to too many requirement not applicable in our use case - if len(await fetch_requirements([requirement_urn])) > 0: + if len(await requirements_service.fetch_requirements(requirement_urn)) > 0: requirement_tasks.append(requirement_mapper[requirement_urn]) return requirement_tasks @@ -451,7 +460,8 @@ async def get_measure( algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], ) -> HTMLResponse: algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request) - measure = await task_registry.fetch_measures([measure_urn]) + measures_service = create_measures_service() + measure = await measures_service.fetch_measures([measure_urn]) measure_task = find_measure_task(algorithm.system_card, measure_urn) context = { @@ -476,6 +486,7 @@ async def update_measure_value( measure_urn: str, measure_update: MeasureUpdate, algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)], + requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)], ) -> HTMLResponse: algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request) @@ -486,7 +497,7 @@ async def update_measure_value( # update for the linked requirements the state based on all it's measures requirement_tasks = await find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn) requirement_urns = [requirement_task.urn for requirement_task in requirement_tasks] - requirements = await fetch_requirements(requirement_urns) + requirements = await requirements_service.fetch_requirements(requirement_urns) for requirement in requirements: count_completed = 0 diff --git a/amt/repositories/task_registry.py b/amt/repositories/task_registry.py index f6d92fb2..140a2857 100644 --- a/amt/repositories/task_registry.py +++ b/amt/repositories/task_registry.py @@ -1,3 +1,4 @@ +import asyncio import logging from collections.abc import Sequence from typing import Any @@ -27,27 +28,16 @@ async def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | Non @param urns: URNs of tasks to fetch. If None, function returns all tasks of the given type. @return: List of task data dictionaries with the given URNs in 'urns'. """ - if isinstance(urns, str): - urns = [urns] - - all_valid_urns: list[str] = await self.fetch_urns(task_type) - if urns is None: - return [await self.client.get_task_by_urn(task_type, urn) for urn in all_valid_urns] + all_valid_urns: list[str] = await self._fetch_valid_urns(task_type) + return await self._fetch_tasks_by_urns(task_type, all_valid_urns) - tasks: list[dict[str, Any]] = [] - for urn in urns: - # For backward compatibilty of this method we now simply ignore invalid URN's. - # We might want to refactor this later to throw exceptions when task with URN is not - # found. - try: - tasks.append(await self.client.get_task_by_urn(task_type, urn)) - except AMTNotFound: - logger.warning(f"Cannot find {task_type.value} with URN {urn}") + if isinstance(urns, str): + urns = [urns] - return tasks + return await self._fetch_tasks_by_urns(task_type, urns) - async def fetch_urns(self, task_type: TaskType) -> list[str]: + async def _fetch_valid_urns(self, task_type: TaskType) -> list[str]: """ Fetches all valid URNs for the given task type. """ @@ -55,3 +45,22 @@ async def fetch_urns(self, task_type: TaskType) -> list[str]: content_list = await self.client.get_list_of_task(task_type) self._urn_cache[task_type] = [content.urn for content in content_list.root] return self._urn_cache[task_type] + + async def _fetch_tasks_by_urns(self, task_type: TaskType, urns: Sequence[str]) -> list[dict[str, Any]]: + """ + Fetches tasks given a list of URN's. + If an URN is not valid, it is ignored. + """ + get_tasks = [self.client.get_task_by_urn(task_type, urn) for urn in urns] + results = await asyncio.gather(*get_tasks, return_exceptions=True) + + tasks: list[dict[str, Any]] = [] + for result in results: + if isinstance(result, dict): + tasks.append(result) + elif isinstance(result, AMTNotFound): + logger.warning(f"Cannot find {task_type.value}") + else: + raise result + + return tasks diff --git a/amt/services/task_registry.py b/amt/services/task_registry.py index 1cfcf1da..a1ce9732 100644 --- a/amt/services/task_registry.py +++ b/amt/services/task_registry.py @@ -1,12 +1,8 @@ import logging -from collections.abc import Sequence -from functools import lru_cache -from amt.schema.measure import Measure, MeasureTask -from amt.schema.requirement import Requirement, RequirementTask +from amt.schema.measure import MeasureTask +from amt.schema.requirement import RequirementTask from amt.schema.system_card import AiActProfile -from amt.services.measures import create_measures_service -from amt.services.requirements import create_requirements_service logger = logging.getLogger(__name__) @@ -22,47 +18,3 @@ def get_requirements_and_measures( requirements_card: list[RequirementTask] = [] return requirements_card, measure_card - - -async def fetch_all_requirements() -> dict[str, Requirement]: - """ - Fetch requirements with URN in urns. - """ - requirement_service = create_requirements_service() - all_requirements = await requirement_service.fetch_requirements() - requirements: dict[str, Requirement] = {} - - for requirement in all_requirements: - requirements[requirement.urn] = requirement - - return requirements - - -async def fetch_requirements(urns: Sequence[str]) -> list[Requirement]: - """ - Fetch requirements with URN in urns. - """ - all_requirements = await fetch_all_requirements() - return [all_requirements[urn] for urn in urns if urn in all_requirements] - - -async def fetch_all_measures() -> dict[str, Measure]: - """ - Fetch measures with URN in urns. - """ - measure_service = create_measures_service() - all_measures = await measure_service.fetch_measures() - measures: dict[str, Measure] = {} - - for measure in all_measures: - measures[measure.urn] = measure - - return measures - - -async def fetch_measures(urns: Sequence[str]) -> list[Measure]: - """ - Fetch measures with URN in urns. - """ - all_measures = await fetch_all_measures() - return [all_measures[urn] for urn in urns if urn in all_measures] diff --git a/tests/clients/test_clients.py b/tests/clients/test_clients.py index f3b14890..42b25737 100644 --- a/tests/clients/test_clients.py +++ b/tests/clients/test_clients.py @@ -46,6 +46,7 @@ async def test_task_registry_api_client_get_instrument(httpx_mock: HTTPXMock): # then assert result == json.loads(TASK_REGISTRY_CONTENT_PAYLOAD) + @pytest.mark.asyncio async def test_task_registry_api_client_get_instrument_not_succesfull(httpx_mock: HTTPXMock): task_registry_api_client = TaskRegistryAPIClient() diff --git a/tests/services/test_instruments_service.py b/tests/services/test_instruments_service.py index 451811c6..019aa062 100644 --- a/tests/services/test_instruments_service.py +++ b/tests/services/test_instruments_service.py @@ -3,23 +3,11 @@ from amt.core.exceptions import AMTInstrumentError from amt.services.instruments import create_instrument_service from pytest_httpx import HTTPXMock -from tests.constants import TASK_REGISTRY_LIST_PAYLOAD +from tests.constants import TASK_REGISTRY_AIIA_CONTENT_PAYLOAD, TASK_REGISTRY_LIST_PAYLOAD # TODO(berry): made payloads to a better location -@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_urns.yml") # type: ignore -def test_fetch_urns(): - # given - instruments_service = create_instrument_service() - - ## when - # result = instruments_service.fetch_urns() - - ## then - # assert len(result) == 4 - - @vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_instruments.yml") # type: ignore @pytest.mark.asyncio async def test_fetch_instruments(): @@ -80,9 +68,27 @@ async def test_fetch_instruments_invalid(httpx_mock: HTTPXMock): # given instruments_service = create_instrument_service() httpx_mock.add_response( - url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() + url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest", + content=b'{"test": 1}', ) + # then + with pytest.raises(AMTInstrumentError): + await instruments_service.fetch_instruments("urn:nl:aivt:tr:iama:1.0") + + +@pytest.mark.asyncio +async def test_fetch_instruments_invalid_without_specific_urn(httpx_mock: HTTPXMock): + # given + instruments_service = create_instrument_service() + httpx_mock.add_response( + url="https://task-registry.apps.digilab.network/instruments/", + content=TASK_REGISTRY_LIST_PAYLOAD.encode(), + ) + httpx_mock.add_response( + url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:aiia:1.0?version=latest", + content=TASK_REGISTRY_AIIA_CONTENT_PAYLOAD.encode(), + ) httpx_mock.add_response( url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest", content=b'{"test": 1}', diff --git a/tests/services/test_instruments_state.py b/tests/services/test_instruments_state.py index 3806d02b..094ab5c8 100644 --- a/tests/services/test_instruments_state.py +++ b/tests/services/test_instruments_state.py @@ -17,7 +17,6 @@ from tests.constants import ( TASK_REGISTRY_AIIA_CONTENT_PAYLOAD, TASK_REGISTRY_CONTENT_PAYLOAD, - TASK_REGISTRY_LIST_PAYLOAD, default_instrument, ) @@ -180,9 +179,6 @@ def test_find_next_tasks_for_instrument_correct_lifecycle(system_card: SystemCar @pytest.mark.asyncio async def test_get_state_per_instrument(system_card: SystemCard, httpx_mock: HTTPXMock): instrument_state_service = InstrumentStateService(system_card) - httpx_mock.add_response( - url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() - ) httpx_mock.add_response( url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest", content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), @@ -210,9 +206,6 @@ async def test_get_state_per_instrument(system_card: SystemCard, httpx_mock: HTT @pytest.mark.asyncio async def test_get_amount_completed_instruments(system_card: SystemCard, httpx_mock: HTTPXMock): instrument_state_service = InstrumentStateService(system_card) - httpx_mock.add_response( - url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() - ) httpx_mock.add_response( url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest", content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), @@ -234,9 +227,6 @@ async def test_get_amount_completed_instruments(system_card: SystemCard, httpx_m @pytest.mark.asyncio async def test_get_amount_total_instruments(system_card: SystemCard, httpx_mock: HTTPXMock): instrument_state_service = InstrumentStateService(system_card) - httpx_mock.add_response( - url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() - ) httpx_mock.add_response( url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest", content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), @@ -258,9 +248,6 @@ async def test_get_amount_total_instruments(system_card: SystemCard, httpx_mock: @pytest.mark.asyncio async def test_get_amount_completed_instruments_one_completed(system_card: SystemCard, httpx_mock: HTTPXMock): instrument_state_service = InstrumentStateService(system_card) - httpx_mock.add_response( - url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() - ) httpx_mock.add_response( url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest", content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(),