4 changes: 4 additions & 0 deletions {{cookiecutter.project_slug}}/tests/__init__.py
@@ -0,0 +1,4 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "app"))
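tests/__init__.py adds the application directory to sys.path so the test modules can import application code by its top-level module names. A rough sketch of the layout this implies (the app/ directory comes from the path above; the file names are inferred from the imports used in the tests that follow, so treat this as illustrative):

# Assumed layout (illustrative, not part of this PR):
#   {{cookiecutter.project_slug}}/
#       app/
#           main.py                  -> get_application()
#           core/                    -> config.py, errors.py, events.py
#           api/routes/predictor.py
#           services/predict.py
#       tests/                       -> the files added in this PR
#
# With app/ on sys.path, the tests can import application modules directly:
from main import get_application
import services.predict as predict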
65 changes: 65 additions & 0 deletions {{cookiecutter.project_slug}}/tests/test_api_predictor.py
@@ -0,0 +1,65 @@
import json
import pytest
from fastapi.testclient import TestClient

from main import get_application
import api.routes.predictor as predictor
from core import config as app_config
import main as app_main


@pytest.fixture
def client(monkeypatch):
monkeypatch.setattr(app_config, "MEMOIZATION_FLAG", False)
monkeypatch.setattr(app_main, "MEMOIZATION_FLAG", False)
app = get_application()
return TestClient(app)


@pytest.fixture
def anyio_backend():
return "asyncio"


def sample_payload():
return {
"feature1": 1.0,
"feature2": 2.0,
"feature3": 3.0,
"feature4": 4.0,
"feature5": 5.0,
}


@pytest.mark.anyio
async def test_predict_endpoint_success(monkeypatch):
monkeypatch.setattr(predictor, "get_prediction", lambda data: [1])
data = predictor.MachineLearningDataInput(**sample_payload())
resp = await predictor.predict(data)
assert resp.prediction == 1.0
assert resp.prediction_label == "label ok"


def test_predict_endpoint_exception(client, monkeypatch):
def raise_error(data):
raise ValueError("fail")

monkeypatch.setattr(predictor, "get_prediction", raise_error)
response = client.post("/api/v1/predict", json=sample_payload())
assert response.status_code == 500


def test_health_endpoint_success(client, monkeypatch, tmp_path):
example = tmp_path / "example.json"
example.write_text(json.dumps(sample_payload()))
monkeypatch.setattr(predictor, "INPUT_EXAMPLE", str(example))
monkeypatch.setattr(predictor, "get_prediction", lambda data: [0])
response = client.get("/api/v1/health")
assert response.status_code == 200
assert response.json() == {"status": True}


def test_health_endpoint_failure(client, monkeypatch):
monkeypatch.setattr(predictor, "INPUT_EXAMPLE", "missing.json")
response = client.get("/api/v1/health")
assert response.status_code == 404
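test_api_predictor.py exercises the route handlers both directly (predictor.predict) and through TestClient, with get_prediction monkeypatched so no real model is needed. The route module itself is not part of this diff; the sketch below shows the shape these tests assume. The response fields, the "label ok" text, and the health check reading INPUT_EXAMPLE are all inferred from the assertions above, so treat this as an illustration rather than the project's actual code.

# Hypothetical api/routes/predictor.py consistent with the tests above (not part of this PR).
import json

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

INPUT_EXAMPLE = "example.json"  # assumed default; the tests monkeypatch this

router = APIRouter()


class MachineLearningDataInput(BaseModel):
    feature1: float
    feature2: float
    feature3: float
    feature4: float
    feature5: float


class MachineLearningResponse(BaseModel):
    prediction: float
    prediction_label: str


class HealthResponse(BaseModel):
    status: bool


def get_prediction(data_input):
    # Placeholder: the real implementation delegates to the model handler.
    # The tests monkeypatch this symbol, so it is never called for real here.
    raise NotImplementedError


@router.post("/predict", response_model=MachineLearningResponse)
async def predict(data_input: MachineLearningDataInput):
    try:
        prediction = get_prediction(data_input)
    except Exception as err:
        # Any failure in the prediction path surfaces as a 500,
        # which is what test_predict_endpoint_exception checks.
        raise HTTPException(status_code=500, detail=f"Exception: {err}")
    label = "label ok" if prediction[0] == 1 else "label not ok"  # assumed mapping
    return MachineLearningResponse(prediction=prediction[0], prediction_label=label)


@router.get("/health", response_model=HealthResponse)
async def healthcheck():
    try:
        # Run a prediction on a bundled example payload; a missing file
        # (as in test_health_endpoint_failure) turns into a 404.
        with open(INPUT_EXAMPLE) as example_file:
            example = json.load(example_file)
        get_prediction(MachineLearningDataInput(**example))
    except Exception as err:
        raise HTTPException(status_code=404, detail=f"Health check failed: {err}")
    return HealthResponse(status=True)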
18 changes: 18 additions & 0 deletions {{cookiecutter.project_slug}}/tests/test_config_and_errors.py
@@ -0,0 +1,18 @@
import logging

import pytest

from core import config, errors


def test_config_defaults():
assert config.API_PREFIX == "/api"
assert config.PROJECT_NAME == "{{cookiecutter.project_name}}"
assert config.LOGGING_LEVEL in (logging.INFO, logging.DEBUG)


def test_custom_exceptions():
with pytest.raises(errors.PredictException):
raise errors.PredictException("test")
with pytest.raises(errors.ModelLoadException):
raise errors.ModelLoadException("test")
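test_config_and_errors.py only requires that core.config exposes API_PREFIX, PROJECT_NAME, and LOGGING_LEVEL, and that core.errors defines two exception classes that accept a message. A minimal core/errors.py consistent with that (illustrative; the real module may carry extra context such as error codes):

# Hypothetical core/errors.py, just enough to satisfy test_custom_exceptions.
class PredictException(Exception):
    """Raised when a prediction cannot be produced."""


class ModelLoadException(Exception):
    """Raised when the serialized model cannot be loaded."""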
38 changes: 38 additions & 0 deletions {{cookiecutter.project_slug}}/tests/test_events_and_main.py
@@ -0,0 +1,38 @@
from fastapi import FastAPI

from core import events
from main import get_application
import services.predict as predict


def test_preload_model(monkeypatch):
called = {}

def fake_get_model(cls, loader):
called["called"] = True

monkeypatch.setattr(
predict.MachineLearningModelHandlerScore,
"get_model",
classmethod(fake_get_model),
)
events.preload_model()
assert called.get("called") is True


def test_create_start_app_handler(monkeypatch):
called = {}

def fake_preload():
called["called"] = True

monkeypatch.setattr(events, "preload_model", fake_preload)
app = FastAPI()
handler = events.create_start_app_handler(app)
handler()
assert called.get("called") is True


def test_get_application():
app = get_application()
assert isinstance(app, FastAPI)
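test_events_and_main.py assumes core.events exposes preload_model() and create_start_app_handler(app), where the handler returned by the latter triggers the preload. A minimal sketch under those assumptions; the joblib loader is a guess suggested by the model.joblib file name used in the service tests, not something this PR confirms.

# Hypothetical core/events.py consistent with the tests above (not part of this PR).
from typing import Callable

from fastapi import FastAPI
from joblib import load as joblib_load  # assumed loader

from services.predict import MachineLearningModelHandlerScore


def preload_model() -> None:
    # Warm the class-level model cache so the first request does not pay the load cost.
    MachineLearningModelHandlerScore.get_model(joblib_load)


def create_start_app_handler(app: FastAPI) -> Callable[[], None]:
    def start_app() -> None:
        # preload_model is looked up at call time, which is what lets the
        # test replace it with a fake via monkeypatch.
        preload_model()

    return start_app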
74 changes: 74 additions & 0 deletions {{cookiecutter.project_slug}}/tests/test_predict_service.py
@@ -0,0 +1,74 @@
import os

import pytest

import services.predict as predict


class DummyModel:
def predict(self, data):
return [42]


class DummyScaler:
def transform(self, data):
return data


def test_predict_success(monkeypatch):
predict.MachineLearningModelHandlerScore.model = DummyModel()
result = predict.MachineLearningModelHandlerScore.predict([[1]])
assert result == [42]


def test_predict_missing_method(monkeypatch):
predict.MachineLearningModelHandlerScore.model = {"model": object(), "scaler": DummyScaler()}
with pytest.raises(predict.PredictException):
predict.MachineLearningModelHandlerScore.predict([[1]])


def test_get_model_caches(monkeypatch):
predict.MachineLearningModelHandlerScore.model = None
monkeypatch.setattr(
predict.MachineLearningModelHandlerScore,
"load",
staticmethod(lambda loader: {"model": DummyModel(), "scaler": DummyScaler()}),
)
model = predict.MachineLearningModelHandlerScore.get_model(lambda path: None)
assert model["model"].__class__ is DummyModel
model2 = predict.MachineLearningModelHandlerScore.get_model(None)
assert model2 is model


def test_load_model_success(tmp_path, monkeypatch):
dummy = tmp_path / "model.joblib"
dummy.write_text("data")
monkeypatch.setattr(predict, "MODEL_PATH", str(tmp_path))
monkeypatch.setattr(predict, "MODEL_NAME", "model.joblib")

def fake_loader(path):
assert os.path.exists(path)
return {"model": DummyModel(), "scaler": DummyScaler()}

model = predict.MachineLearningModelHandlerScore.load(fake_loader)
assert model["model"].__class__ is DummyModel


def test_load_model_missing(tmp_path, monkeypatch):
monkeypatch.setattr(predict, "MODEL_PATH", str(tmp_path))
monkeypatch.setattr(predict, "MODEL_NAME", "missing.joblib")
with pytest.raises(FileNotFoundError):
predict.MachineLearningModelHandlerScore.load(lambda p: None)


def test_load_model_empty(tmp_path, monkeypatch):
dummy = tmp_path / "model.joblib"
dummy.write_text("data")
monkeypatch.setattr(predict, "MODEL_PATH", str(tmp_path))
monkeypatch.setattr(predict, "MODEL_NAME", "model.joblib")

def fake_loader(path):
return None

with pytest.raises(predict.ModelLoadException):
predict.MachineLearningModelHandlerScore.load(fake_loader)
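Taken together, these tests pin down the handler's contract: predict delegates to the cached model and raises PredictException if the expected method is missing, get_model loads once and caches on the class, and load raises FileNotFoundError when the model file is absent and ModelLoadException when the loader returns nothing. A sketch of a handler satisfying that contract (illustrative; MODEL_PATH and MODEL_NAME are assumed to be module-level names sourced from core.config, which is what lets the tests monkeypatch them):

# Hypothetical services/predict.py consistent with the tests above (not part of this PR).
import os

from core.config import MODEL_NAME, MODEL_PATH  # assumed origin of these constants
from core.errors import ModelLoadException, PredictException


class MachineLearningModelHandlerScore:
    model = None

    @classmethod
    def get_model(cls, load_wrapper):
        # Load once and cache on the class; later calls return the cached object.
        if cls.model is None and load_wrapper:
            cls.model = cls.load(load_wrapper)
        return cls.model

    @classmethod
    def predict(cls, input_data, method="predict"):
        # Delegate to the loaded model; fail loudly if it lacks the method.
        if hasattr(cls.model, method):
            return getattr(cls.model, method)(input_data)
        raise PredictException(f"'{method}' attribute is missing")

    @staticmethod
    def load(load_wrapper):
        # MODEL_PATH and MODEL_NAME are read at call time, so the tests can
        # point them at a temporary directory.
        path = os.path.join(MODEL_PATH, MODEL_NAME)
        if not os.path.exists(path):
            raise FileNotFoundError(f"Machine learning model at {path} does not exist!")
        model = load_wrapper(path)
        if not model:
            raise ModelLoadException(f"Model could not be loaded from {path}!")
        return model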