Skip to content

Commit

Permalink
Make API client async
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 15, 2024
1 parent 65fe3e5 commit 8518c10
Show file tree
Hide file tree
Showing 15 changed files with 116 additions and 96 deletions.
60 changes: 30 additions & 30 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@
logger = logging.getLogger(__name__)


def get_instrument_state(system_card: SystemCard) -> dict[str, Any]:
async def get_instrument_state(system_card: SystemCard) -> dict[str, Any]:
instrument_state = InstrumentStateService(system_card)
instrument_states = instrument_state.get_state_per_instrument()
instrument_states = await instrument_state.get_state_per_instrument()
return {
"instrument_states": instrument_states,
"count_0": instrument_state.get_amount_completed_instruments(),
"count_1": instrument_state.get_amount_total_instruments(),
}


def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
requirements = fetch_requirements([requirement.urn for requirement in system_card.requirements])
async def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
requirements = await fetch_requirements([requirement.urn for requirement in system_card.requirements])
requirements_state_service = RequirementsStateService(system_card)
requirements_state = requirements_state_service.get_requirements_state(requirements)

Expand Down Expand Up @@ -106,8 +106,8 @@ async def get_tasks(
tasks_service: Annotated[TasksService, Depends(TasksService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)
tab_items = get_algorithm_details_tabs(request)
tasks_by_status = await gather_algorithm_tasks(algorithm_id, task_service=tasks_service)

Expand Down Expand Up @@ -170,8 +170,8 @@ async def get_algorithm_context(
algorithm_id: int, algorithms_service: AlgorithmsService, request: Request
) -> tuple[Algorithm, dict[str, Any]]:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)
tab_items = get_algorithm_details_tabs(request)
return algorithm, {
"last_edited": algorithm.last_edited,
Expand Down Expand Up @@ -279,8 +279,8 @@ async def get_system_card(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down Expand Up @@ -326,8 +326,8 @@ async def get_algorithm_inference(
request,
)

instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand All @@ -352,8 +352,8 @@ async def get_system_card_requirements(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)
tab_items = get_algorithm_details_tabs(request)

breadcrumbs = resolve_base_navigation_items(
Expand All @@ -367,13 +367,13 @@ async def get_system_card_requirements(
request,
)

requirements = fetch_requirements([requirement.urn for requirement in algorithm.system_card.requirements])
requirements = await fetch_requirements([requirement.urn for requirement in algorithm.system_card.requirements])

# Get measures that correspond to the requirements and merge them with the measuretasks
requirements_and_measures = []
for requirement in requirements:
completed_measures_count = 0
linked_measures = fetch_measures(requirement.links)
linked_measures = await fetch_measures(requirement.links)
extended_linked_measures: list[ExtendedMeasureTask] = []
for measure in linked_measures:
measure_task = find_measure_task(algorithm.system_card, measure.urn)
Expand Down Expand Up @@ -418,16 +418,16 @@ def find_requirement_task(system_card: SystemCard, requirement_urn: str) -> Requ
return None


def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure_urn: str) -> list[RequirementTask]:
async def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure_urn: str) -> list[RequirementTask]:
requirement_mapper: dict[str, RequirementTask] = {}
for requirement_task in system_card.requirements:
requirement_mapper[requirement_task.urn] = requirement_task

requirement_tasks: list[RequirementTask] = []
measure = fetch_measures([measure_urn])
measure = await fetch_measures([measure_urn])
for requirement_urn in measure[0].links:
# TODO: This is because measure are linked to too many requirement not applicable in our use case
if len(fetch_requirements([requirement_urn])) > 0:
if len(await fetch_requirements([requirement_urn])) > 0:
requirement_tasks.append(requirement_mapper[requirement_urn])

return requirement_tasks
Expand All @@ -451,7 +451,7 @@ async def get_measure(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
measure = task_registry.fetch_measures([measure_urn])
measure = await task_registry.fetch_measures([measure_urn])
measure_task = find_measure_task(algorithm.system_card, measure_urn)

context = {
Expand Down Expand Up @@ -484,9 +484,9 @@ async def update_measure_value(
measure_task.value = measure_update.measure_value # pyright: ignore [reportOptionalMemberAccess]

# update for the linked requirements the state based on all it's measures
requirement_tasks = find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn)
requirement_tasks = await find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn)
requirement_urns = [requirement_task.urn for requirement_task in requirement_tasks]
requirements = fetch_requirements(requirement_urns)
requirements = await fetch_requirements(requirement_urns)

for requirement in requirements:
count_completed = 0
Expand Down Expand Up @@ -520,8 +520,8 @@ async def get_system_card_data_page(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down Expand Up @@ -560,8 +560,8 @@ async def get_system_card_instruments(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down Expand Up @@ -596,8 +596,8 @@ async def get_assessment_card(
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

request.state.path_variables.update({"assessment_card": assessment_card})

Expand Down Expand Up @@ -649,8 +649,8 @@ async def get_model_card(
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
request.state.path_variables.update({"model_card": model_card})
instrument_state = get_instrument_state(algorithm.system_card)
requirements_state = get_requirements_state(algorithm.system_card)
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)

tab_items = get_algorithm_details_tabs(request)

Expand Down
4 changes: 3 additions & 1 deletion amt/api/routes/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ async def get_new(

template_files = get_template_files()

instruments = await instrument_service.fetch_instruments()

context: dict[str, Any] = {
"instruments": instrument_service.fetch_instruments(),
"instruments": instruments,
"ai_act_profile": ai_act_profile,
"breadcrumbs": breadcrumbs,
"sub_menu_items": {}, # sub_menu_items disabled for now,
Expand Down
3 changes: 2 additions & 1 deletion amt/cli/check_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -30,7 +31,7 @@ def get_tasks_by_priority(urns: list[str], system_card_path: Path) -> None:
try:
system_card = get_system_card(system_card_path)
instruments_service = create_instrument_service()
all_instruments = instruments_service.fetch_instruments()
all_instruments = asyncio.run(instruments_service.fetch_instruments())
instruments = get_requested_instruments(all_instruments, urns)
next_tasks = get_all_next_tasks(instruments, system_card)

Expand Down
16 changes: 8 additions & 8 deletions amt/clients/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class APIClient:

def __init__(self, base_url: str, max_retries: int = 3, timeout: int = 5) -> None:
self.base_url = base_url
transport = httpx.HTTPTransport(retries=max_retries)
self.client = httpx.Client(timeout=timeout, transport=transport)
transport = httpx.AsyncHTTPTransport(retries=max_retries)
self.client = httpx.AsyncClient(timeout=timeout, transport=transport)

def _make_request(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
response = self.client.get(f"{self.base_url}/{endpoint}", params=params)
async def _make_request(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
response = await self.client.get(f"{self.base_url}/{endpoint}", params=params)
if response.status_code != 200:
raise AMTNotFound()
return response.json()
Expand All @@ -42,12 +42,12 @@ def __init__(self, max_retries: int = 3, timeout: int = 5) -> None:
base_url="https://task-registry.apps.digilab.network", max_retries=max_retries, timeout=timeout
)

def get_list_of_task(self, task: TaskType = TaskType.INSTRUMENTS) -> RepositoryContent:
response_data = self._make_request(f"{task.value}/")
async def get_list_of_task(self, task: TaskType = TaskType.INSTRUMENTS) -> RepositoryContent:
response_data = await self._make_request(f"{task.value}/")
return RepositoryContent.model_validate(response_data["entries"])

def get_task_by_urn(self, task_type: TaskType, urn: str, version: str = "latest") -> dict[str, Any]:
response_data = self._make_request(f"{task_type.value}/urn/{urn}", params={"version": version})
async def get_task_by_urn(self, task_type: TaskType, urn: str, version: str = "latest") -> dict[str, Any]:
response_data = await self._make_request(f"{task_type.value}/urn/{urn}", params={"version": version})
if "urn" not in response_data:
logger.exception(f"Invalid task {task_type.value} fetched: key 'urn' must occur in task {task_type.value}.")
raise AMTInstrumentError()
Expand Down
17 changes: 10 additions & 7 deletions amt/repositories/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ class TaskRegistryRepository:

def __init__(self, client: TaskRegistryAPIClient) -> None:
self.client = client
self._urn_cache: dict[TaskType, list[str]] = {}

def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | None = None) -> list[dict[str, Any]]:
async def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | None = None) -> list[dict[str, Any]]:
"""
Fetches tasks (instruments, measures, etc.) with the given URNs.
Expand All @@ -29,26 +30,28 @@ def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | None = No
if isinstance(urns, str):
urns = [urns]

all_valid_urns: list[str] = self.fetch_urns(task_type)
all_valid_urns: list[str] = await self.fetch_urns(task_type)

if urns is None:
return [self.client.get_task_by_urn(task_type, urn) for urn in all_valid_urns]
return [await self.client.get_task_by_urn(task_type, urn) for urn in all_valid_urns]

tasks: list[dict[str, Any]] = []
for urn in urns:
# For backward compatibilty of this method we now simply ignore invalid URN's.
# We might want to refactor this later to throw exceptions when task with URN is not
# found.
try:
tasks.append(self.client.get_task_by_urn(task_type, urn))
tasks.append(await self.client.get_task_by_urn(task_type, urn))
except AMTNotFound:
logger.warning(f"Cannot find {task_type.value} with URN {urn}")

return tasks

def fetch_urns(self, task_type: TaskType) -> list[str]:
async def fetch_urns(self, task_type: TaskType) -> list[str]:
"""
Fetches all valid URNs for the given task type.
"""
content_list = self.client.get_list_of_task(task_type)
return [content.urn for content in content_list.root]
if task_type not in self._urn_cache:
content_list = await self.client.get_list_of_task(task_type)
self._urn_cache[task_type] = [content.urn for content in content_list.root]
return self._urn_cache[task_type]
2 changes: 1 addition & 1 deletion amt/services/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def create(self, algorithm_new: AlgorithmNew) -> Algorithm:
algorithm = Algorithm(name=algorithm_new.name, lifecycle=algorithm_new.lifecycle, system_card=system_card)
algorithm = await self.update(algorithm)

selected_instruments = self.instrument_service.fetch_instruments(
selected_instruments = await self.instrument_service.fetch_instruments(
[instrument.urn for instrument in algorithm.system_card.instruments]
)
for instrument in selected_instruments:
Expand Down
4 changes: 2 additions & 2 deletions amt/services/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class InstrumentsService:
def __init__(self, repository: TaskRegistryRepository) -> None:
self.repository = repository

def fetch_instruments(self, urns: str | Sequence[str] | None = None) -> list[Instrument]:
async def fetch_instruments(self, urns: str | Sequence[str] | None = None) -> list[Instrument]:
"""
Fetches instruments with the given URNs.
If urns contains an URN that is not a valid URN of an instrument, it is simply ignored.
@param urns: URNs of instruments to fetch. If None, function returns all instruments.
@return: List of instruments with the given URNs in 'urns'.
"""
task_data = self.repository.fetch_tasks(TaskType.INSTRUMENTS, urns)
task_data = await self.repository.fetch_tasks(TaskType.INSTRUMENTS, urns)
return [Instrument(**data) for data in task_data]


Expand Down
4 changes: 2 additions & 2 deletions amt/services/instruments_and_requirements_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ def __init__(self, system_card: SystemCard) -> None:
self.system_card = system_card
self.instrument_states = []

def get_state_per_instrument(self) -> list[dict[str, int]]:
async def get_state_per_instrument(self) -> list[dict[str, int]]:
# Returns dictionary with instrument urns with value 0 or 1, if 1 then the instrument is not completed yet
# Otherwise the instrument is completed as there are not any tasks left.

urns = [instrument.urn for instrument in self.system_card.instruments]
instruments_service = create_instrument_service()
instruments = instruments_service.fetch_instruments(urns)
instruments = await instruments_service.fetch_instruments(urns)
# TODO: refactor this data structure in 3 lines below (also change in get_all_next_tasks + check_state.py)
instruments_dict = {}
instrument_states: dict[str, Any] = {}
Expand Down
4 changes: 2 additions & 2 deletions amt/services/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class MeasuresService:
def __init__(self, repository: TaskRegistryRepository) -> None:
self.repository = repository

def fetch_measures(self, urns: str | Sequence[str] | None = None) -> list[Measure]:
async def fetch_measures(self, urns: str | Sequence[str] | None = None) -> list[Measure]:
"""
Fetches measures with the given URNs.
If urns contains an URN that is not a valid URN of an measure, it is simply ignored.
@param urns: URNs of instruments to fetch. If None, function returns all measures.
@return: List of measures with the given URNs in 'urns'.
"""
task_data = self.repository.fetch_tasks(TaskType.MEASURES, urns)
task_data = await self.repository.fetch_tasks(TaskType.MEASURES, urns)
return [Measure(**data) for data in task_data]


Expand Down
4 changes: 2 additions & 2 deletions amt/services/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class RequirementsService:
def __init__(self, repository: TaskRegistryRepository) -> None:
self.repository = repository

def fetch_requirements(self, urns: str | Sequence[str] | None = None) -> list[Requirement]:
async def fetch_requirements(self, urns: str | Sequence[str] | None = None) -> list[Requirement]:
"""
Fetches measures with the given URNs.
If urns contains an URN that is not a valid URN of an measure, it is simply ignored.
@param urns: URNs of instruments to fetch. If None, function returns all measures.
@return: List of measures with the given URNs in 'urns'.
"""
task_data = self.repository.fetch_tasks(TaskType.REQUIREMENTS, urns)
task_data = await self.repository.fetch_tasks(TaskType.REQUIREMENTS, urns)
return [Requirement(**data) for data in task_data]


Expand Down
Loading

0 comments on commit 8518c10

Please sign in to comment.