11import json
2+ from pathlib import Path
23
34import joblib
45from core .config import INPUT_EXAMPLE
56from fastapi import APIRouter , HTTPException
7+ from fastapi .concurrency import run_in_threadpool
68from models .prediction import (
79 HealthResponse ,
810 MachineLearningDataInput ,
@@ -33,11 +35,14 @@ async def predict(data_input: MachineLearningDataInput):
3335 raise HTTPException (status_code = 404 , detail = "'data_input' argument invalid!" )
3436 try :
3537 data_point = data_input .get_np_array ()
36- prediction = get_prediction (data_point )
38+ prediction = await run_in_threadpool (get_prediction , data_point )
39+ try :
40+ prediction = float (prediction [0 ])
41+ except (TypeError , IndexError , KeyError ):
42+ prediction = float (prediction )
3743 prediction_label = get_prediction_label (prediction )
38-
3944 except Exception as err :
40- raise HTTPException (status_code = 500 , detail = f"Exception: { err } " )
45+ raise HTTPException (status_code = 500 , detail = f"Exception: { err } " ) from err
4146
4247 return MachineLearningResponse (
4348 prediction = prediction , prediction_label = prediction_label
@@ -50,14 +55,11 @@ async def predict(data_input: MachineLearningDataInput):
5055 name = "health:get-data" ,
5156)
5257async def health ():
53- is_health = False
5458 try :
55- test_input = MachineLearningDataInput (
56- ** json .loads (open (INPUT_EXAMPLE , "r" ).read ())
57- )
59+ content = await run_in_threadpool (Path (INPUT_EXAMPLE ).read_text )
60+ test_input = MachineLearningDataInput (** json .loads (content ))
5861 test_point = test_input .get_np_array ()
59- get_prediction (test_point )
60- is_health = True
61- return HealthResponse (status = is_health )
62+ await run_in_threadpool (get_prediction , test_point )
63+ return HealthResponse (status = True )
6264 except Exception :
6365 raise HTTPException (status_code = 404 , detail = "Unhealthy" )
0 commit comments