Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 15, 2024
1 parent 67a26fd commit be429e2
Show file tree
Hide file tree
Showing 8 changed files with 8,313 additions and 0 deletions.
51 changes: 51 additions & 0 deletions amt/schema/requirement.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,54 @@
from enum import Enum

from pydantic import Field

from amt.schema.shared import BaseModel


class TypeEnum(Enum):
AI_systeem = "AI-systeem"
AI_systeem_voor_algemene_doeleinden = "AI-systeem voor algemene doeleinden"
AI_model_voor_algemene_doeleinden = "AI-model voor algemene doeleinden"
algoritme = "algoritme"


class OpenSourceEnum(Enum):
open_source = "open-source"
geen_open_source = "geen open-source"


class RiskCategoryEnum(Enum):
geen_hoog_risico_AI = "geen hoog-risico AI"
hoog_risico_AI = "hoog-risico AI"
verboden_AI = "verboden AI"


class SystemicRiskEnum(Enum):
systeemrisico = "systeemrisico"
geen_systeemrisico = "geen systeemrisico"


class TransparencyObligation(Enum):
transparantieverplichting = "transparantieverplichting"
geen_transparantieverplichting = "geen transparantieverplichting"


class RoleEnum(Enum):
aanbieder = "aanbieder"
gebruiksverantwoordelijke = "gebruiksverantwoordelijke"
importeur = "importeur"
distributeur = "distributeur"


class AiActProfileItem(BaseModel):
type: list[TypeEnum]
open_source: list[OpenSourceEnum] | None = None
risk_category: list[RiskCategoryEnum]
systemic_risk: list[SystemicRiskEnum] | None = None
transparency_obligations: list[TransparencyObligation] | None = None
role: list[RoleEnum]


class RequirementBase(BaseModel):
urn: str

Expand All @@ -16,3 +62,8 @@ class Requirement(RequirementBase):
name: str
description: str
links: list[str] = Field(default=[])
ai_act_profile: list[AiActProfileItem]
always_applicable: int = Field(
...,
description="1 if requirements applies to every system, 0 if only for specific systems",
)
6 changes: 6 additions & 0 deletions amt/services/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
logger = logging.getLogger(__name__)


def get_requirements(ai_act_profile: AiActProfile) -> list[RequirementTask]:
requirements: list[RequirementTask] = []

return requirements


def get_requirements_and_measures(
ai_act_profile: AiActProfile,
) -> tuple[
Expand Down
34 changes: 34 additions & 0 deletions tests/fixtures/vcr_cassettes/test_fetch_task_with_invalid_urn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
interactions:
- request:
body: ""
headers:
accept:
- "*/*"
accept-encoding:
- gzip, deflate
connection:
- keep-alive
host:
- task-registry.apps.digilab.network
user-agent:
- python-httpx/0.27.2
method: GET
uri: https://task-registry.apps.digilab.network/instruments/urn/invalid?version=latest
response:
body:
string: '{"detail":"invalid urn: invalid"}'
headers:
Connection:
- keep-alive
Content-Length:
- "33"
Content-Type:
- application/json
Date:
- Fri, 15 Nov 2024 14:26:47 GMT
Strict-Transport-Security:
- max-age=31536000; includeSubDomains
status:
code: 400
message: Bad Request
version: 1
291 changes: 291 additions & 0 deletions tests/fixtures/vcr_cassettes/test_fetch_task_with_urn.yml

Large diffs are not rendered by default.

724 changes: 724 additions & 0 deletions tests/fixtures/vcr_cassettes/test_fetch_task_with_urns.yml

Large diffs are not rendered by default.

Large diffs are not rendered by default.

6,747 changes: 6,747 additions & 0 deletions tests/fixtures/vcr_cassettes/test_fetch_tasks_all.yml

Large diffs are not rendered by default.

137 changes: 137 additions & 0 deletions tests/repositories/test_task_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import pytest
from pytest_httpx import HTTPXMock
import vcr # type: ignore
from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTInstrumentError
from amt.models.task import Task
from amt.repositories.task_registry import TaskRegistryRepository
from tests.constants import TASK_REGISTRY_AIIA_CONTENT_PAYLOAD, TASK_REGISTRY_LIST_PAYLOAD

# TODO(berry): made payloads to a better location


@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_tasks_all.yml") # type: ignore
@pytest.mark.asyncio
async def test_fetch_tasks_all():
# given
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)

# when
instrument_result = await repository.fetch_tasks(TaskType.INSTRUMENTS)
requirement_result = await repository.fetch_tasks(TaskType.REQUIREMENTS)
measure_result = await repository.fetch_tasks(TaskType.MEASURES)

# then
assert len(instrument_result) == 4
assert len(requirement_result) == 60
assert len(measure_result) == 70


@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_task_with_urn.yml") # type: ignore
@pytest.mark.asyncio
async def test_fetch_task_with_urn():
# given
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
urn = "urn:nl:aivt:tr:iama:1.0"

# when
result = await repository.fetch_tasks(TaskType.INSTRUMENTS, urns=urn)

# then
assert len(result) == 1
assert "urn" in result[0]
assert result[0]["urn"] == urn


@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_task_with_urns.yml") # type: ignore
@pytest.mark.asyncio
async def test_fetch_task_with_urns():
# given
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
urns = ["urn:nl:aivt:tr:iama:1.0", "urn:nl:aivt:tr:aiia:1.0"]

# when
result = await repository.fetch_tasks(TaskType.INSTRUMENTS, urns=urns)

# then
assert len(result) == 2
assert "urn" in result[0]
assert result[0]["urn"] == urns[0]
assert "urn" in result[1]
assert result[1]["urn"] == urns[1]


@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_task_with_invalid_urn.yml") # type: ignore
@pytest.mark.asyncio
async def test_fetch_task_with_invalid_urn():
# given
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
urn = "invalid"

# when
result = await repository.fetch_tasks(TaskType.INSTRUMENTS, urns=urn)

# then
assert len(result) == 0


@vcr.use_cassette("tests/fixtures/vcr_cassettes/test_fetch_task_with_valid_and_invalid_urn.yml") # type: ignore
@pytest.mark.asyncio
async def test_fetch_task_with_valid_and_invalid_urn():
# given
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
urns = ["urn:nl:aivt:tr:iama:1.0", "invalid"]

# when
result = await repository.fetch_tasks(TaskType.INSTRUMENTS, urns=urns)

# then
assert len(result) == 1
assert "urn" in result[0]
assert result[0]["urn"] == urns[0]


@pytest.mark.asyncio
async def test_fetch_tasks_invalid_response(httpx_mock: HTTPXMock):
# given
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
urn = "urn:nl:aivt:tr:iama:1.0"

httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest",
content=b'{"test": 1}',
)

# then
with pytest.raises(AMTInstrumentError):
await repository.fetch_tasks(TaskType.INSTRUMENTS, urn)


@pytest.mark.asyncio
async def test_fetch_tasks_valid_and_invalid_response(httpx_mock: HTTPXMock):
# given
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)

httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/",
content=TASK_REGISTRY_LIST_PAYLOAD.encode(),
)
httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:aiia:1.0?version=latest",
content=TASK_REGISTRY_AIIA_CONTENT_PAYLOAD.encode(),
)
httpx_mock.add_response(
url="https://task-registry.apps.digilab.network/instruments/urn/urn:nl:aivt:tr:iama:1.0?version=latest",
content=b'{"test": 1}',
)

# then
with pytest.raises(AMTInstrumentError):
await repository.fetch_tasks(TaskType.INSTRUMENTS)

0 comments on commit be429e2

Please sign in to comment.