Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Always run `pytest {{cookiecutter.project_slug}}/tests` and ensure all tests pass after any code changes.
- Always add tests, keep your branch rebased instead of merged, and adhere to the commit message recommendations from cbea.ms/git-commit.
1 change: 1 addition & 0 deletions {{cookiecutter.project_slug}}/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ DEBUG=True
MODEL_PATH={{cookiecutter.machine_learn_model_path}}
MODEL_NAME={{cookiecutter.machine_learn_model_name}}
MEMOIZATION_FLAG=False
DATABASE_URL=postgresql://postgres:postgres@db:5432/app
22 changes: 21 additions & 1 deletion {{cookiecutter.project_slug}}/app/api/routes/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from core.config import INPUT_EXAMPLE
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from db import SessionLocal
from models.log import RequestLog
from models.prediction import (
HealthResponse,
MachineLearningDataInput,
Expand Down Expand Up @@ -44,10 +46,28 @@ async def predict(data_input: MachineLearningDataInput):
except Exception as err:
raise HTTPException(status_code=500, detail=f"Exception: {err}") from err

return MachineLearningResponse(
response = MachineLearningResponse(
prediction=prediction, prediction_label=prediction_label
)

try:
db = SessionLocal()
log = RequestLog(
request=json.dumps(data_input.model_dump()),
response=json.dumps(response.model_dump()),
)
db.add(log)
db.commit()
except Exception:
pass
finally:
try:
db.close()
except Exception:
pass

return response


@router.get(
"/health",
Expand Down
1 change: 1 addition & 0 deletions {{cookiecutter.project_slug}}/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MIN_CONNECTIONS_COUNT: int = config("MIN_CONNECTIONS_COUNT", cast=int, default=10)
SECRET_KEY: Secret = config("SECRET_KEY", cast=Secret, default="")
MEMOIZATION_FLAG: bool = config("MEMOIZATION_FLAG", cast=bool, default=True)
DATABASE_URL: str = config("DATABASE_URL", default="sqlite:///./app.db")

PROJECT_NAME: str = config("PROJECT_NAME", default="{{cookiecutter.project_name}}")

Expand Down
6 changes: 5 additions & 1 deletion {{cookiecutter.project_slug}}/app/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import joblib
from fastapi import FastAPI
from core.config import MEMOIZATION_FLAG
from db import Base, engine


def preload_model():
Expand All @@ -15,6 +17,8 @@ def preload_model():

def create_start_app_handler(app: FastAPI) -> Callable:
def start_app() -> None:
preload_model()
if MEMOIZATION_FLAG:
preload_model()
Base.metadata.create_all(bind=engine)

return start_app
8 changes: 8 additions & 0 deletions {{cookiecutter.project_slug}}/app/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base

from core.config import DATABASE_URL

engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
3 changes: 1 addition & 2 deletions {{cookiecutter.project_slug}}/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
def get_application() -> FastAPI:
application = FastAPI(title=PROJECT_NAME, debug=DEBUG, version=VERSION)
application.include_router(api_router, prefix=API_PREFIX)
if MEMOIZATION_FLAG:
application.add_event_handler("startup", create_start_app_handler(application))
application.add_event_handler("startup", create_start_app_handler(application))
return application


Expand Down
11 changes: 11 additions & 0 deletions {{cookiecutter.project_slug}}/app/models/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlalchemy import Column, Integer, Text

from db import Base


class RequestLog(Base):
__tablename__ = "request_logs"

id = Column(Integer, primary_key=True, index=True)
request = Column(Text, nullable=False)
response = Column(Text, nullable=False)
17 changes: 16 additions & 1 deletion {{cookiecutter.project_slug}}/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,19 @@ services:
command: uvicorn main:app --reload --host 0.0.0.0 --port 8080
volumes:
- ./app:/app/
- ./ml/model/:/app/ml/model/
- ./ml/model/:/app/ml/model/
depends_on:
- db
db:
image: postgres:16
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: app
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data

volumes:
postgres_data:
4 changes: 3 additions & 1 deletion {{cookiecutter.project_slug}}/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ dependencies = [
"joblib>=1.2.0",
"scikit-learn>=1.1.3",
"pandas>=2.2.3",
"httpx>=0.27.0"
"httpx>=0.27.0",
"sqlalchemy>=2.0.0",
"psycopg2-binary>=2.9.0"
]

[project.optional-dependencies]
Expand Down
43 changes: 43 additions & 0 deletions {{cookiecutter.project_slug}}/tests/test_request_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from api.routes import predictor
from db import Base
from models.log import RequestLog
from models.prediction import MachineLearningDataInput


@pytest.fixture
def anyio_backend():
return "asyncio"

@pytest.mark.anyio
async def test_predict_logs_request_response(monkeypatch):
engine = create_engine("sqlite:///:memory:")
TestingSessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
Base.metadata.create_all(bind=engine)
monkeypatch.setattr(predictor, "SessionLocal", TestingSessionLocal)
monkeypatch.setattr(predictor, "get_prediction", lambda data: [1])

payload = {
"feature1": 1.0,
"feature2": 2.0,
"feature3": 3.0,
"feature4": 4.0,
"feature5": 5.0,
}
data = MachineLearningDataInput(**payload)

response = await predictor.predict(data)
assert response.prediction == 1.0

db = TestingSessionLocal()
logs = db.query(RequestLog).all()
assert len(logs) == 1
log = logs[0]
assert json.loads(log.request) == data.model_dump()
assert json.loads(log.response) == response.model_dump()
db.close()
Loading