Skip to content

Commit e6a85fb

Browse files
test: cover predictor routes
1 parent c7492d9 commit e6a85fb

File tree

3 files changed

+51
-10
lines changed

3 files changed

+51
-10
lines changed

{{cookiecutter.project_slug}}/app/api/routes/predictor.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
2+
from pathlib import Path
23

34
import joblib
45
from core.config import INPUT_EXAMPLE
56
from fastapi import APIRouter, HTTPException
7+
from fastapi.concurrency import run_in_threadpool
68
from models.prediction import (
79
HealthResponse,
810
MachineLearningDataInput,
@@ -33,11 +35,14 @@ async def predict(data_input: MachineLearningDataInput):
3335
raise HTTPException(status_code=404, detail="'data_input' argument invalid!")
3436
try:
3537
data_point = data_input.get_np_array()
36-
prediction = get_prediction(data_point)
38+
prediction = await run_in_threadpool(get_prediction, data_point)
39+
try:
40+
prediction = float(prediction[0])
41+
except (TypeError, IndexError, KeyError):
42+
prediction = float(prediction)
3743
prediction_label = get_prediction_label(prediction)
38-
3944
except Exception as err:
40-
raise HTTPException(status_code=500, detail=f"Exception: {err}")
45+
raise HTTPException(status_code=500, detail=f"Exception: {err}") from err
4146

4247
return MachineLearningResponse(
4348
prediction=prediction, prediction_label=prediction_label
@@ -50,14 +55,11 @@ async def predict(data_input: MachineLearningDataInput):
5055
name="health:get-data",
5156
)
5257
async def health():
53-
is_health = False
5458
try:
55-
test_input = MachineLearningDataInput(
56-
**json.loads(open(INPUT_EXAMPLE, "r").read())
57-
)
59+
content = await run_in_threadpool(Path(INPUT_EXAMPLE).read_text)
60+
test_input = MachineLearningDataInput(**json.loads(content))
5861
test_point = test_input.get_np_array()
59-
get_prediction(test_point)
60-
is_health = True
61-
return HealthResponse(status=is_health)
62+
await run_in_threadpool(get_prediction, test_point)
63+
return HealthResponse(status=True)
6264
except Exception:
6365
raise HTTPException(status_code=404, detail="Unhealthy")

{{cookiecutter.project_slug}}/app/models/prediction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
class MachineLearningResponse(BaseModel):
77
prediction: float
8+
prediction_label: str
89

910

1011
class HealthResponse(BaseModel):
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import json
2+
import sys
3+
from pathlib import Path
4+
5+
import pytest
6+
from fastapi.testclient import TestClient
7+
8+
sys.path.append(str(Path(__file__).resolve().parents[1] / "app"))
9+
10+
from api.routes import predictor
11+
from main import app
12+
13+
14+
client = TestClient(app)
15+
16+
17+
def sample_input():
18+
example_path = Path(__file__).resolve().parents[1] / "ml" / "model" / "examples" / "example.json"
19+
return json.loads(example_path.read_text())
20+
21+
22+
def test_predict_endpoint(monkeypatch):
23+
monkeypatch.setattr(predictor, "get_prediction", lambda data_point: [1.0])
24+
response = client.post("/api/v1/predict", json=sample_input())
25+
assert response.status_code == 200
26+
assert response.json() == {"prediction": 1.0, "prediction_label": "label ok"}
27+
28+
29+
def test_health_endpoint(monkeypatch):
30+
monkeypatch.setattr(predictor, "get_prediction", lambda data_point: [1.0])
31+
monkeypatch.setattr(
32+
predictor,
33+
"INPUT_EXAMPLE",
34+
str(Path(__file__).resolve().parents[1] / "ml" / "model" / "examples" / "example.json"),
35+
)
36+
response = client.get("/api/v1/health")
37+
assert response.status_code == 200
38+
assert response.json() == {"status": True}

0 commit comments

Comments
 (0)