Skip to content

Commit 3e277d5

Browse files
feat: support model memoization (#156)
1 parent 6547d9c commit 3e277d5

File tree

9 files changed

+35
-31
lines changed

9 files changed

+35
-31
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
33

4-
""" setup.py for cookiecutter-fastapi."""
4+
"""setup.py for cookiecutter-fastapi."""
55

66
from setuptools import setup
77

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
SECRET_KEY=secret
22
DEBUG=True
33
MODEL_PATH={{cookiecutter.machine_learn_model_path}}
4-
MODEL_NAME={{cookiecutter.machine_learn_model_name}}
4+
MODEL_NAME={{cookiecutter.machine_learn_model_name}}
5+
MEMOIZATION_FLAG=True
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
# mpmqtcc
1+
# mpmqtcc

{{cookiecutter.project_slug}}/app/api/routes/predictor.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,47 @@
11
import json
22

33
import joblib
4-
from fastapi import APIRouter, HTTPException
5-
64
from core.config import INPUT_EXAMPLE
7-
from services.predict import MachineLearningModelHandlerScore as model
5+
from fastapi import APIRouter, HTTPException
86
from models.prediction import (
97
HealthResponse,
10-
MachineLearningResponse,
118
MachineLearningDataInput,
9+
MachineLearningResponse,
1210
)
11+
from services.predict import MachineLearningModelHandlerScore as model
1312

1413
router = APIRouter()
1514

1615

17-
## Change this portion for other types of models
18-
## Add the correct type hinting when completed
1916
def get_prediction(data_point):
2017
return model.predict(data_point, load_wrapper=joblib.load, method="predict")
2118

2219

20+
def get_prediction_label(prediction):
21+
if prediction == 1:
22+
return "label ok"
23+
return "label nok"
24+
25+
2326
@router.post(
2427
"/predict",
2528
response_model=MachineLearningResponse,
2629
name="predict:get-data",
2730
)
2831
async def predict(data_input: MachineLearningDataInput):
29-
3032
if not data_input:
3133
raise HTTPException(status_code=404, detail="'data_input' argument invalid!")
3234
try:
3335
data_point = data_input.get_np_array()
3436
prediction = get_prediction(data_point)
37+
prediction_label = get_prediction_label(prediction)
3538

3639
except Exception as err:
3740
raise HTTPException(status_code=500, detail=f"Exception: {err}")
3841

39-
return MachineLearningResponse(prediction=prediction)
42+
return MachineLearningResponse(
43+
prediction=prediction, prediction_label=prediction_label
44+
)
4045

4146

4247
@router.get(
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
import sys
21
import logging
2+
import sys
33

4+
from core.logging import InterceptHandler
45
from loguru import logger
56
from starlette.config import Config
67
from starlette.datastructures import Secret
78

8-
from core.logging import InterceptHandler
9-
109
config = Config(".env")
1110

1211
API_PREFIX = "/api"
13-
VERSION = "{{cookiecutter.version}}"
12+
VERSION = "0.1.0"
1413
DEBUG: bool = config("DEBUG", cast=bool, default=False)
1514
MAX_CONNECTIONS_COUNT: int = config("MAX_CONNECTIONS_COUNT", cast=int, default=10)
1615
MIN_CONNECTIONS_COUNT: int = config("MIN_CONNECTIONS_COUNT", cast=int, default=10)
1716
SECRET_KEY: Secret = config("SECRET_KEY", cast=Secret, default="")
17+
MEMOIZATION_FLAG: bool = config("MEMOIZATION_FLAG", cast=bool, default=True)
1818

19-
PROJECT_NAME: str = config("PROJECT_NAME", default="{{cookiecutter.project_name}}")
19+
PROJECT_NAME: str = config("PROJECT_NAME", default="manu")
2020

2121
# logging configuration
2222
LOGGING_LEVEL = logging.DEBUG if DEBUG else logging.INFO
@@ -25,6 +25,8 @@
2525
)
2626
logger.configure(handlers=[{"sink": sys.stderr, "level": LOGGING_LEVEL}])
2727

28-
MODEL_PATH = config("MODEL_PATH", default="{{cookiecutter.machine_learn_model_path}}")
29-
MODEL_NAME = config("MODEL_NAME", default="{{cookiecutter.machine_learn_model_name}}")
30-
INPUT_EXAMPLE = config("INPUT_EXAMPLE", default="{{cookiecutter.input_example_path}}")
28+
MODEL_PATH = config(
29+
"MODEL_PATH", default="/Users/arthur.dasilva/repos/arthurhenrique/n"
30+
)
31+
MODEL_NAME = config("MODEL_NAME", default="pregnancy_model_local.joblib")
32+
INPUT_EXAMPLE = config("INPUT_EXAMPLE", default="./ml/model/examples/example.json")
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
class PredictException(BaseException):
2-
...
1+
class PredictException(BaseException): ...
32

43

5-
class ModelLoadException(BaseException):
6-
...
4+
class ModelLoadException(BaseException): ...

{{cookiecutter.project_slug}}/app/core/events.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Callable
22

3+
import joblib
34
from fastapi import FastAPI
45

56

@@ -9,7 +10,7 @@ def preload_model():
910
"""
1011
from services.predict import MachineLearningModelHandlerScore
1112

12-
MachineLearningModelHandlerScore.get_model()
13+
MachineLearningModelHandlerScore.get_model(joblib.load)
1314

1415

1516
def create_start_app_handler(app: FastAPI) -> Callable:

{{cookiecutter.project_slug}}/app/main.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from fastapi import FastAPI
2-
31
from api.routes.api import router as api_router
2+
from core.config import API_PREFIX, DEBUG, MEMOIZATION_FLAG, PROJECT_NAME, VERSION
43
from core.events import create_start_app_handler
5-
from core.config import API_PREFIX, DEBUG, PROJECT_NAME, VERSION
4+
from fastapi import FastAPI
65

76

87
def get_application() -> FastAPI:
98
application = FastAPI(title=PROJECT_NAME, debug=DEBUG, version=VERSION)
109
application.include_router(api_router, prefix=API_PREFIX)
11-
pre_load = False
12-
if pre_load:
10+
if MEMOIZATION_FLAG:
1311
application.add_event_handler("startup", create_start_app_handler(application))
1412
return application
1513

{{cookiecutter.project_slug}}/tests/test_pagination_behavior.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def test_pagenation_400_start_1():
3939

4040

4141
def test_pagenation_400_set_start_1_equals_True_and_init_as_pagenumber_as_0():
42-
"""Exception case
43-
"""
42+
"""Exception case"""
4443
with pytest.raises(Exception, match=r".* starts > 0. *"):
4544
d = pagenation(0, 20, 400, list(range(400)))

0 commit comments

Comments
 (0)