Skip to content

Commit 3a3922c

Browse files
test: add request logging test
1 parent fe848a0 commit 3a3922c

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import json
2+
3+
import pytest
4+
from sqlalchemy import create_engine
5+
from sqlalchemy.orm import sessionmaker
6+
7+
from api.routes import predictor
8+
from db import Base
9+
from models.log import RequestLog
10+
from models.prediction import MachineLearningDataInput
11+
12+
13+
@pytest.fixture
14+
def anyio_backend():
15+
return "asyncio"
16+
17+
@pytest.mark.anyio
18+
async def test_predict_logs_request_response(monkeypatch):
19+
engine = create_engine("sqlite:///:memory:")
20+
TestingSessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
21+
Base.metadata.create_all(bind=engine)
22+
monkeypatch.setattr(predictor, "SessionLocal", TestingSessionLocal)
23+
monkeypatch.setattr(predictor, "get_prediction", lambda data: [1])
24+
25+
payload = {
26+
"feature1": 1.0,
27+
"feature2": 2.0,
28+
"feature3": 3.0,
29+
"feature4": 4.0,
30+
"feature5": 5.0,
31+
}
32+
data = MachineLearningDataInput(**payload)
33+
34+
response = await predictor.predict(data)
35+
assert response.prediction == 1.0
36+
37+
db = TestingSessionLocal()
38+
logs = db.query(RequestLog).all()
39+
assert len(logs) == 1
40+
log = logs[0]
41+
assert json.loads(log.request) == data.model_dump()
42+
assert json.loads(log.response) == response.model_dump()
43+
db.close()

0 commit comments

Comments
 (0)