Skip to content

Commit c7fecb8

Browse files
Move tests to template and remove sample duplicates (#164)
1 parent 7f64940 commit c7fecb8

File tree

5 files changed

+199
-0
lines changed

5 files changed

+199
-0
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import os
2+
import sys
3+
4+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "app"))
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
import pytest
3+
from fastapi.testclient import TestClient
4+
5+
from main import get_application
6+
import api.routes.predictor as predictor
7+
from core import config as app_config
8+
import main as app_main
9+
10+
11+
@pytest.fixture
12+
def client(monkeypatch):
13+
monkeypatch.setattr(app_config, "MEMOIZATION_FLAG", False)
14+
monkeypatch.setattr(app_main, "MEMOIZATION_FLAG", False)
15+
app = get_application()
16+
return TestClient(app)
17+
18+
19+
@pytest.fixture
20+
def anyio_backend():
21+
return "asyncio"
22+
23+
24+
def sample_payload():
25+
return {
26+
"feature1": 1.0,
27+
"feature2": 2.0,
28+
"feature3": 3.0,
29+
"feature4": 4.0,
30+
"feature5": 5.0,
31+
}
32+
33+
34+
@pytest.mark.anyio
35+
async def test_predict_endpoint_success(monkeypatch):
36+
monkeypatch.setattr(predictor, "get_prediction", lambda data: [1])
37+
data = predictor.MachineLearningDataInput(**sample_payload())
38+
resp = await predictor.predict(data)
39+
assert resp.prediction == 1.0
40+
assert resp.prediction_label == "label ok"
41+
42+
43+
def test_predict_endpoint_exception(client, monkeypatch):
44+
def raise_error(data):
45+
raise ValueError("fail")
46+
47+
monkeypatch.setattr(predictor, "get_prediction", raise_error)
48+
response = client.post("/api/v1/predict", json=sample_payload())
49+
assert response.status_code == 500
50+
51+
52+
def test_health_endpoint_success(client, monkeypatch, tmp_path):
53+
example = tmp_path / "example.json"
54+
example.write_text(json.dumps(sample_payload()))
55+
monkeypatch.setattr(predictor, "INPUT_EXAMPLE", str(example))
56+
monkeypatch.setattr(predictor, "get_prediction", lambda data: [0])
57+
response = client.get("/api/v1/health")
58+
assert response.status_code == 200
59+
assert response.json() == {"status": True}
60+
61+
62+
def test_health_endpoint_failure(client, monkeypatch):
63+
monkeypatch.setattr(predictor, "INPUT_EXAMPLE", "missing.json")
64+
response = client.get("/api/v1/health")
65+
assert response.status_code == 404
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import logging
2+
3+
import pytest
4+
5+
from core import config, errors
6+
7+
8+
def test_config_defaults():
9+
assert config.API_PREFIX == "/api"
10+
assert config.PROJECT_NAME == "{{cookiecutter.project_name}}"
11+
assert config.LOGGING_LEVEL in (logging.INFO, logging.DEBUG)
12+
13+
14+
def test_custom_exceptions():
15+
with pytest.raises(errors.PredictException):
16+
raise errors.PredictException("test")
17+
with pytest.raises(errors.ModelLoadException):
18+
raise errors.ModelLoadException("test")
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from fastapi import FastAPI
2+
3+
from core import events
4+
from main import get_application
5+
import services.predict as predict
6+
7+
8+
def test_preload_model(monkeypatch):
9+
called = {}
10+
11+
def fake_get_model(cls, loader):
12+
called["called"] = True
13+
14+
monkeypatch.setattr(
15+
predict.MachineLearningModelHandlerScore,
16+
"get_model",
17+
classmethod(fake_get_model),
18+
)
19+
events.preload_model()
20+
assert called.get("called") is True
21+
22+
23+
def test_create_start_app_handler(monkeypatch):
24+
called = {}
25+
26+
def fake_preload():
27+
called["called"] = True
28+
29+
monkeypatch.setattr(events, "preload_model", fake_preload)
30+
app = FastAPI()
31+
handler = events.create_start_app_handler(app)
32+
handler()
33+
assert called.get("called") is True
34+
35+
36+
def test_get_application():
37+
app = get_application()
38+
assert isinstance(app, FastAPI)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import os
2+
3+
import pytest
4+
5+
import services.predict as predict
6+
7+
8+
class DummyModel:
9+
def predict(self, data):
10+
return [42]
11+
12+
13+
class DummyScaler:
14+
def transform(self, data):
15+
return data
16+
17+
18+
def test_predict_success(monkeypatch):
19+
predict.MachineLearningModelHandlerScore.model = DummyModel()
20+
result = predict.MachineLearningModelHandlerScore.predict([[1]])
21+
assert result == [42]
22+
23+
24+
def test_predict_missing_method(monkeypatch):
25+
predict.MachineLearningModelHandlerScore.model = {"model": object(), "scaler": DummyScaler()}
26+
with pytest.raises(predict.PredictException):
27+
predict.MachineLearningModelHandlerScore.predict([[1]])
28+
29+
30+
def test_get_model_caches(monkeypatch):
31+
predict.MachineLearningModelHandlerScore.model = None
32+
monkeypatch.setattr(
33+
predict.MachineLearningModelHandlerScore,
34+
"load",
35+
staticmethod(lambda loader: {"model": DummyModel(), "scaler": DummyScaler()}),
36+
)
37+
model = predict.MachineLearningModelHandlerScore.get_model(lambda path: None)
38+
assert model["model"].__class__ is DummyModel
39+
model2 = predict.MachineLearningModelHandlerScore.get_model(None)
40+
assert model2 is model
41+
42+
43+
def test_load_model_success(tmp_path, monkeypatch):
44+
dummy = tmp_path / "model.joblib"
45+
dummy.write_text("data")
46+
monkeypatch.setattr(predict, "MODEL_PATH", str(tmp_path))
47+
monkeypatch.setattr(predict, "MODEL_NAME", "model.joblib")
48+
49+
def fake_loader(path):
50+
assert os.path.exists(path)
51+
return {"model": DummyModel(), "scaler": DummyScaler()}
52+
53+
model = predict.MachineLearningModelHandlerScore.load(fake_loader)
54+
assert model["model"].__class__ is DummyModel
55+
56+
57+
def test_load_model_missing(tmp_path, monkeypatch):
58+
monkeypatch.setattr(predict, "MODEL_PATH", str(tmp_path))
59+
monkeypatch.setattr(predict, "MODEL_NAME", "missing.joblib")
60+
with pytest.raises(FileNotFoundError):
61+
predict.MachineLearningModelHandlerScore.load(lambda p: None)
62+
63+
64+
def test_load_model_empty(tmp_path, monkeypatch):
65+
dummy = tmp_path / "model.joblib"
66+
dummy.write_text("data")
67+
monkeypatch.setattr(predict, "MODEL_PATH", str(tmp_path))
68+
monkeypatch.setattr(predict, "MODEL_NAME", "model.joblib")
69+
70+
def fake_loader(path):
71+
return None
72+
73+
with pytest.raises(predict.ModelLoadException):
74+
predict.MachineLearningModelHandlerScore.load(fake_loader)

0 commit comments

Comments
 (0)