Skip to content

Commit 29c169a

Browse files
feat: support model memoization
1 parent 73500f0 commit 29c169a

File tree

11 files changed

+45
-40
lines changed

11 files changed

+45
-40
lines changed

.github/workflows/ci.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
- name: Create Virtual Environment for UV
3232
run: |
3333
cd name-of-the-project
34-
uv venv .venv # Criando ambiente virtual
34+
make install
3535
echo "VIRTUAL_ENV=$(pwd)/.venv" >> $GITHUB_ENV
3636
echo "$(pwd)/.venv/bin" >> $GITHUB_PATH
3737
@@ -43,4 +43,4 @@ jobs:
4343
- name: Run Tests
4444
run: |
4545
cd name-of-the-project
46-
make test
46+
uv run pytest tests -vv --show-capture=all

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
33

4-
""" setup.py for cookiecutter-fastapi."""
4+
"""setup.py for cookiecutter-fastapi."""
55

66
from setuptools import setup
77

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
SECRET_KEY=secret
22
DEBUG=True
33
MODEL_PATH={{cookiecutter.machine_learn_model_path}}
4-
MODEL_NAME={{cookiecutter.machine_learn_model_name}}
4+
MODEL_NAME={{cookiecutter.machine_learn_model_name}}
5+
MEMOIZATION_FLAG=True

{{cookiecutter.project_slug}}/Makefile

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,18 @@ endif
2121

2222
all: clean test install run deploy down
2323

24-
test:
25-
uv pip install -e ".[dev]"
24+
venv:
25+
uv venv .venv
26+
27+
test: venv
2628
uv run pytest tests -vv --show-capture=all
2729

28-
install: generate_dot_env
29-
pip install --upgrade pip
30-
pip install uv
30+
install: generate_dot_env venv
31+
pip install uv --break-system-packages
3132
uv pip install -e ".[dev]"
3233

33-
run:
34-
PYTHONPATH=app/ uvicorn main:app --reload --host 0.0.0.0 --port 8080
34+
run: venv
35+
PYTHONPATH=app/ uv run uvicorn main:app --reload --host 0.0.0.0 --port 8080
3536

3637
deploy: generate_dot_env
3738
docker-compose build
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
# mpmqtcc
1+
# mpmqtcc

{{cookiecutter.project_slug}}/app/api/routes/predictor.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,47 @@
11
import json
22

33
import joblib
4-
from fastapi import APIRouter, HTTPException
5-
64
from core.config import INPUT_EXAMPLE
7-
from services.predict import MachineLearningModelHandlerScore as model
5+
from fastapi import APIRouter, HTTPException
86
from models.prediction import (
97
HealthResponse,
10-
MachineLearningResponse,
118
MachineLearningDataInput,
9+
MachineLearningResponse,
1210
)
11+
from services.predict import MachineLearningModelHandlerScore as model
1312

1413
router = APIRouter()
1514

1615

17-
## Change this portion for other types of models
18-
## Add the correct type hinting when completed
1916
def get_prediction(data_point):
2017
return model.predict(data_point, load_wrapper=joblib.load, method="predict")
2118

2219

20+
def get_prediction_label(prediction):
21+
if prediction == 1:
22+
return "label ok"
23+
return "label nok"
24+
25+
2326
@router.post(
2427
"/predict",
2528
response_model=MachineLearningResponse,
2629
name="predict:get-data",
2730
)
2831
async def predict(data_input: MachineLearningDataInput):
29-
3032
if not data_input:
3133
raise HTTPException(status_code=404, detail="'data_input' argument invalid!")
3234
try:
3335
data_point = data_input.get_np_array()
3436
prediction = get_prediction(data_point)
37+
prediction_label = get_prediction_label(prediction)
3538

3639
except Exception as err:
3740
raise HTTPException(status_code=500, detail=f"Exception: {err}")
3841

39-
return MachineLearningResponse(prediction=prediction)
42+
return MachineLearningResponse(
43+
prediction=prediction, prediction_label=prediction_label
44+
)
4045

4146

4247
@router.get(
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
import sys
21
import logging
2+
import sys
33

4+
from core.logging import InterceptHandler
45
from loguru import logger
56
from starlette.config import Config
67
from starlette.datastructures import Secret
78

8-
from core.logging import InterceptHandler
9-
109
config = Config(".env")
1110

1211
API_PREFIX = "/api"
13-
VERSION = "{{cookiecutter.version}}"
12+
VERSION = "0.1.0"
1413
DEBUG: bool = config("DEBUG", cast=bool, default=False)
1514
MAX_CONNECTIONS_COUNT: int = config("MAX_CONNECTIONS_COUNT", cast=int, default=10)
1615
MIN_CONNECTIONS_COUNT: int = config("MIN_CONNECTIONS_COUNT", cast=int, default=10)
1716
SECRET_KEY: Secret = config("SECRET_KEY", cast=Secret, default="")
17+
MEMOIZATION_FLAG: bool = config("MEMOIZATION_FLAG", cast=bool, default=True)
1818

19-
PROJECT_NAME: str = config("PROJECT_NAME", default="{{cookiecutter.project_name}}")
19+
PROJECT_NAME: str = config("PROJECT_NAME", default="manu")
2020

2121
# logging configuration
2222
LOGGING_LEVEL = logging.DEBUG if DEBUG else logging.INFO
@@ -25,6 +25,8 @@
2525
)
2626
logger.configure(handlers=[{"sink": sys.stderr, "level": LOGGING_LEVEL}])
2727

28-
MODEL_PATH = config("MODEL_PATH", default="{{cookiecutter.machine_learn_model_path}}")
29-
MODEL_NAME = config("MODEL_NAME", default="{{cookiecutter.machine_learn_model_name}}")
30-
INPUT_EXAMPLE = config("INPUT_EXAMPLE", default="{{cookiecutter.input_example_path}}")
28+
MODEL_PATH = config(
29+
"MODEL_PATH", default="/Users/arthur.dasilva/repos/arthurhenrique/n"
30+
)
31+
MODEL_NAME = config("MODEL_NAME", default="pregnancy_model_local.joblib")
32+
INPUT_EXAMPLE = config("INPUT_EXAMPLE", default="./ml/model/examples/example.json")
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
class PredictException(BaseException):
2-
...
1+
class PredictException(BaseException): ...
32

43

5-
class ModelLoadException(BaseException):
6-
...
4+
class ModelLoadException(BaseException): ...

{{cookiecutter.project_slug}}/app/core/events.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Callable
22

3+
import joblib
34
from fastapi import FastAPI
45

56

@@ -9,7 +10,7 @@ def preload_model():
910
"""
1011
from services.predict import MachineLearningModelHandlerScore
1112

12-
MachineLearningModelHandlerScore.get_model()
13+
MachineLearningModelHandlerScore.get_model(joblib.load)
1314

1415

1516
def create_start_app_handler(app: FastAPI) -> Callable:

{{cookiecutter.project_slug}}/app/main.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from fastapi import FastAPI
2-
31
from api.routes.api import router as api_router
2+
from core.config import API_PREFIX, DEBUG, MEMOIZATION_FLAG, PROJECT_NAME, VERSION
43
from core.events import create_start_app_handler
5-
from core.config import API_PREFIX, DEBUG, PROJECT_NAME, VERSION
4+
from fastapi import FastAPI
65

76

87
def get_application() -> FastAPI:
98
application = FastAPI(title=PROJECT_NAME, debug=DEBUG, version=VERSION)
109
application.include_router(api_router, prefix=API_PREFIX)
11-
pre_load = False
12-
if pre_load:
10+
if MEMOIZATION_FLAG:
1311
application.add_event_handler("startup", create_start_app_handler(application))
1412
return application
1513

0 commit comments

Comments
 (0)