Skip to content

Commit

Permalink
Make things faster
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 15, 2024
1 parent 8518c10 commit 67a26fd
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 167 deletions.
4 changes: 3 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
"DEBUG": "True",
"AUTO_CREATE_SCHEMA": "True",
"ENVIRONMENT": "local",
"LOGGING_LEVEL": "DEBUG"
"LOGGING_LEVEL": "DEBUG",
"OIDC_CLIENT_SECRET": "uIeFiKFazNEIbpJ3wzj0lZLLSJXefeld",
"OIDC_CLIENT_ID": "AMT"
}
},
{
Expand Down
63 changes: 0 additions & 63 deletions :w

This file was deleted.

29 changes: 20 additions & 9 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from amt.schema.requirement import RequirementTask
from amt.schema.system_card import SystemCard
from amt.schema.task import MovedTask
from amt.services import task_registry
from amt.services.algorithms import AlgorithmsService
from amt.services.instruments_and_requirements_state import InstrumentStateService, RequirementsStateService
from amt.services.task_registry import fetch_measures, fetch_requirements
from amt.services.measures import MeasuresService, create_measures_service
from amt.services.requirements import RequirementsService, create_requirements_service
from amt.services.tasks import TasksService

router = APIRouter()
Expand All @@ -44,7 +44,10 @@ async def get_instrument_state(system_card: SystemCard) -> dict[str, Any]:


async def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
requirements = await fetch_requirements([requirement.urn for requirement in system_card.requirements])
requirements_service = create_requirements_service()
requirements = await requirements_service.fetch_requirements(
[requirement.urn for requirement in system_card.requirements]
)
requirements_state_service = RequirementsStateService(system_card)
requirements_state = requirements_state_service.get_requirements_state(requirements)

Expand Down Expand Up @@ -350,6 +353,8 @@ async def get_system_card_requirements(
request: Request,
algorithm_id: int,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
measures_service: Annotated[MeasuresService, Depends(create_measures_service)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = await get_instrument_state(algorithm.system_card)
Expand All @@ -367,13 +372,15 @@ async def get_system_card_requirements(
request,
)

requirements = await fetch_requirements([requirement.urn for requirement in algorithm.system_card.requirements])
requirements = await requirements_service.fetch_requirements(
[requirement.urn for requirement in algorithm.system_card.requirements]
)

# Get measures that correspond to the requirements and merge them with the measuretasks
requirements_and_measures = []
for requirement in requirements:
completed_measures_count = 0
linked_measures = await fetch_measures(requirement.links)
linked_measures = await measures_service.fetch_measures(requirement.links)
extended_linked_measures: list[ExtendedMeasureTask] = []
for measure in linked_measures:
measure_task = find_measure_task(algorithm.system_card, measure.urn)
Expand Down Expand Up @@ -424,10 +431,12 @@ async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure
requirement_mapper[requirement_task.urn] = requirement_task

requirement_tasks: list[RequirementTask] = []
measure = await fetch_measures([measure_urn])
measures_service = create_measures_service()
requirements_service = create_requirements_service()
measure = await measures_service.fetch_measures(measure_urn)
for requirement_urn in measure[0].links:
# TODO: This is because measure are linked to too many requirement not applicable in our use case
if len(await fetch_requirements([requirement_urn])) > 0:
if len(await requirements_service.fetch_requirements(requirement_urn)) > 0:
requirement_tasks.append(requirement_mapper[requirement_urn])

return requirement_tasks
Expand All @@ -451,7 +460,8 @@ async def get_measure(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
measure = await task_registry.fetch_measures([measure_urn])
measures_service = create_measures_service()
measure = await measures_service.fetch_measures([measure_urn])
measure_task = find_measure_task(algorithm.system_card, measure_urn)

context = {
Expand All @@ -476,6 +486,7 @@ async def update_measure_value(
measure_urn: str,
measure_update: MeasureUpdate,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)

Expand All @@ -486,7 +497,7 @@ async def update_measure_value(
# update for the linked requirements the state based on all it's measures
requirement_tasks = await find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn)
requirement_urns = [requirement_task.urn for requirement_task in requirement_tasks]
requirements = await fetch_requirements(requirement_urns)
requirements = await requirements_service.fetch_requirements(requirement_urns)

for requirement in requirements:
count_completed = 0
Expand Down
43 changes: 26 additions & 17 deletions amt/repositories/task_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from collections.abc import Sequence
from typing import Any
Expand Down Expand Up @@ -27,31 +28,39 @@ async def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | Non
@param urns: URNs of tasks to fetch. If None, function returns all tasks of the given type.
@return: List of task data dictionaries with the given URNs in 'urns'.
"""
if isinstance(urns, str):
urns = [urns]

all_valid_urns: list[str] = await self.fetch_urns(task_type)

if urns is None:
return [await self.client.get_task_by_urn(task_type, urn) for urn in all_valid_urns]
all_valid_urns: list[str] = await self._fetch_valid_urns(task_type)
return await self._fetch_tasks_by_urns(task_type, all_valid_urns)

tasks: list[dict[str, Any]] = []
for urn in urns:
# For backward compatibilty of this method we now simply ignore invalid URN's.
# We might want to refactor this later to throw exceptions when task with URN is not
# found.
try:
tasks.append(await self.client.get_task_by_urn(task_type, urn))
except AMTNotFound:
logger.warning(f"Cannot find {task_type.value} with URN {urn}")
if isinstance(urns, str):
urns = [urns]

return tasks
return await self._fetch_tasks_by_urns(task_type, urns)

async def fetch_urns(self, task_type: TaskType) -> list[str]:
async def _fetch_valid_urns(self, task_type: TaskType) -> list[str]:
"""
Fetches all valid URNs for the given task type.
"""
if task_type not in self._urn_cache:
content_list = await self.client.get_list_of_task(task_type)
self._urn_cache[task_type] = [content.urn for content in content_list.root]
return self._urn_cache[task_type]

async def _fetch_tasks_by_urns(self, task_type: TaskType, urns: Sequence[str]) -> list[dict[str, Any]]:
"""
Fetches tasks given a list of URN's.
If an URN is not valid, it is ignored.
"""
get_tasks = [self.client.get_task_by_urn(task_type, urn) for urn in urns]
results = await asyncio.gather(*get_tasks, return_exceptions=True)

tasks: list[dict[str, Any]] = []
for result in results:
if isinstance(result, dict):
tasks.append(result)
elif isinstance(result, AMTNotFound):
logger.warning(f"Cannot find {task_type.value}")
else:
raise result

return tasks
52 changes: 2 additions & 50 deletions amt/services/task_registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import logging
from collections.abc import Sequence
from functools import lru_cache

from amt.schema.measure import Measure, MeasureTask
from amt.schema.requirement import Requirement, RequirementTask
from amt.schema.measure import MeasureTask
from amt.schema.requirement import RequirementTask
from amt.schema.system_card import AiActProfile
from amt.services.measures import create_measures_service
from amt.services.requirements import create_requirements_service

logger = logging.getLogger(__name__)

Expand All @@ -22,47 +18,3 @@ def get_requirements_and_measures(
requirements_card: list[RequirementTask] = []

return requirements_card, measure_card


async def fetch_all_requirements() -> dict[str, Requirement]:
"""
Fetch requirements with URN in urns.
"""
requirement_service = create_requirements_service()
all_requirements = await requirement_service.fetch_requirements()
requirements: dict[str, Requirement] = {}

for requirement in all_requirements:
requirements[requirement.urn] = requirement

return requirements


async def fetch_requirements(urns: Sequence[str]) -> list[Requirement]:
"""
Fetch requirements with URN in urns.
"""
all_requirements = await fetch_all_requirements()
return [all_requirements[urn] for urn in urns if urn in all_requirements]


async def fetch_all_measures() -> dict[str, Measure]:
"""
Fetch measures with URN in urns.
"""
measure_service = create_measures_service()
all_measures = await measure_service.fetch_measures()
measures: dict[str, Measure] = {}

for measure in all_measures:
measures[measure.urn] = measure

return measures


async def fetch_measures(urns: Sequence[str]) -> list[Measure]:
"""
Fetch measures with URN in urns.
"""
all_measures = await fetch_all_measures()
return [all_measures[urn] for urn in urns if urn in all_measures]
1 change: 1 addition & 0 deletions tests/clients/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def test_task_registry_api_client_get_instrument(httpx_mock: HTTPXMock):
# then
assert result == json.loads(TASK_REGISTRY_CONTENT_PAYLOAD)


@pytest.mark.asyncio
async def test_task_registry_api_client_get_instrument_not_succesfull(httpx_mock: HTTPXMock):
task_registry_api_client = TaskRegistryAPIClient()
Expand Down
34 changes: 20 additions & 14 deletions tests/services/test_instruments_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,11 @@
from amt.core.exceptions import AMTInstrumentError
from amt.services.instruments import create_instrument_service
from pytest_httpx import HTTPXMock
from tests.constants import TASK_REGISTRY_LIST_PAYLOAD
from tests.constants import TASK_REGISTRY_AIIA_CONTENT_PAYLOAD, TASK_REGISTRY_LIST_PAYLOAD

# TODO(berry): made payloads to a better location


@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_urns.yml") # type: ignore
def test_fetch_urns():
# given
instruments_service = create_instrument_service()

## when
# result = instruments_service.fetch_urns()

## then
# assert len(result) == 4


@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_instruments.yml") # type: ignore
@pytest.mark.asyncio
async def test_fetch_instruments():
Expand Down Expand Up @@ -80,9 +68,27 @@ async def test_fetch_instruments_invalid(httpx_mock: HTTPXMock):
# given
instruments_service = create_instrument_service()
httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode()
url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest",
content=b'{"test": 1}',
)

# then
with pytest.raises(AMTInstrumentError):
await instruments_service.fetch_instruments("urn:nl:aivt:tr:iama:1.0")


@pytest.mark.asyncio
async def test_fetch_instruments_invalid_without_specific_urn(httpx_mock: HTTPXMock):
# given
instruments_service = create_instrument_service()
httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/",
content=TASK_REGISTRY_LIST_PAYLOAD.encode(),
)
httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:aiia:1.0?version=latest",
content=TASK_REGISTRY_AIIA_CONTENT_PAYLOAD.encode(),
)
httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest",
content=b'{"test": 1}',
Expand Down
Loading

0 comments on commit 67a26fd

Please sign in to comment.