-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
40 lines (28 loc) · 1.11 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from fastapi import FastAPI, HTTPException
import mlflow.pyfunc
import pandas as pd
import json
# FastAPI application object; the /predict route below is registered on it
# and it is the object handed to uvicorn by the script entry point.
app = FastAPI()
# Location of the serialized salary-prediction model, read by the script
# entry point when it builds the module-level ``model``.
model_path = "data/final_salary_prediction.pkl"
class MLflowModel(mlflow.pyfunc.PythonModel):
    """Pyfunc wrapper that delegates predictions to a loaded MLflow model."""

    def load_context(self, context):
        """Load the model referenced by ``context.artifacts["model"]``.

        The loaded model is stored on ``self.model`` for use by ``predict``.
        """
        artifact_uri = context.artifacts["model"]
        self.model = mlflow.pyfunc.load_model(artifact_uri)

    def predict(self, context, model_input):
        """Return the wrapped model's prediction for ``model_input``.

        ``context`` is accepted for interface compatibility and is not used.
        """
        wrapped = self.model
        return wrapped.predict(model_input)
@app.post("/predict")
async def predict(data: dict):
    """Predict the salary for the job identified by ``data["jobId"]``.

    Args:
        data: Parsed JSON request body; must contain a non-null ``"jobId"``.

    Returns:
        A dict with the echoed ``jobId`` and the model's ``predicted_salary``.

    Raises:
        HTTPException: 400 when ``jobId`` is missing or null; 500 when
            building the input frame or running the model fails.
    """
    job_id = data.get("jobId")
    if job_id is None:
        # Validate before entering the try block: raising this inside it let
        # the generic handler below rewrite the client error as a 500.
        raise HTTPException(status_code=400, detail="jobId is required in the input data")

    try:
        input_data = pd.DataFrame({"jobId": [job_id]})
        # ``model`` is the module-level instance created by the script entry
        # point; it is looked up at request time.
        prediction = model.predict(None, input_data)
        predicted_salary = prediction[0]
        # NumPy scalars (e.g. float32/int64) are not reliably JSON-serializable;
        # convert them to the equivalent native Python value.
        if hasattr(predicted_salary, "item"):
            predicted_salary = predicted_salary.item()
    except HTTPException:
        # Preserve any deliberate HTTP error (status code and detail) as-is.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"jobId": job_id, "predicted_salary": predicted_salary}
if __name__ == "__main__":
    from types import SimpleNamespace

    import uvicorn

    # MLflowModel defines no __init__ that accepts ``artifact_path``, and the
    # wrapped model is only populated by load_context(). Instantiate without
    # arguments and load explicitly, passing a minimal context object that
    # exposes the one attribute load_context() reads: ``artifacts["model"]``.
    # NOTE(review): load_context() hands this path to mlflow.pyfunc.load_model;
    # confirm that model_path points at something that loader accepts.
    model = MLflowModel()
    model.load_context(SimpleNamespace(artifacts={"model": model_path}))

    uvicorn.run(app, host="127.0.0.1", port=8000)