Skip to content

Commit

Permalink
refactor dialogue manager initialization & usage
Browse files Browse the repository at this point in the history
  • Loading branch information
alfredfrancis committed Jan 12, 2025
1 parent 4996d2a commit a2534b6
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 35 deletions.
6 changes: 3 additions & 3 deletions app/admin/train/routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi import APIRouter, HTTPException, BackgroundTasks
from app.admin.intents import store
from app.bot.nlu.training import train_pipeline

Expand Down Expand Up @@ -29,9 +29,9 @@ async def get_training_data(intent_id: str):
return intent.trainingData

@router.post('/build_models')
async def build_models(request: Request, background_tasks: BackgroundTasks):
async def build_models(background_tasks: BackgroundTasks):
"""
Build Intent classification and NER Models
"""
background_tasks.add_task(train_pipeline, request.app)
background_tasks.add_task(train_pipeline)
return {"status": "training started in the background"}
18 changes: 9 additions & 9 deletions app/bot/chat/routes.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from fastapi import APIRouter, HTTPException, Body, Request
from fastapi import APIRouter, HTTPException, Body, Request, Depends
from app.bot.dialogue_manager.models import ChatModel
from app.dependencies import get_dialogue_manager

router = APIRouter(prefix="/v1", tags=["bots"])

@router.post("/chat")
async def chat(request: Request, body: dict):
async def chat(request: Request, body: dict, dialogue_manager = Depends(get_dialogue_manager)):
"""
Endpoint to converse with the chatbot.
Delegates the request processing to DialogueManager.
:return: JSON response with the chatbot's reply and context.
"""
try:
# Access the dialogue manager from the fast api application state.
chat_request = ChatModel.from_json(body)
chat_response = await request.app.state.dialogue_manager.process(chat_request)
return chat_response.to_json()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing request: {e}")

# Access the dialogue manager from the fast api application state.
chat_request = ChatModel.from_json(body)
chat_response = await dialogue_manager.process(chat_request)
return chat_response.to_json()

11 changes: 7 additions & 4 deletions app/bot/nlu/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from app.bot.nlu.intent_classifiers import IntentClassifier
from app.bot.nlu.entity_extractors import EntityExtractor
from app.admin.entities.store import list_synonyms

from app.dependencies import set_dialogue_manager
from app.config import app_config

async def train_pipeline(app):
async def train_pipeline():
"""
Initiate NLU pipeline training
:return:
Expand Down Expand Up @@ -46,7 +46,10 @@ async def train_pipeline(app):
pipeline.train(training_data, models_dir)

# recreate dialogue manager with new data
app.state.dialogue_manager = await DialogueManager.from_config()
dialogue_manager = await DialogueManager.from_config()

# update dialogue manager with new models
app.state.dialogue_manager.update_model(models_dir)
dialogue_manager.update_model(models_dir)

await set_dialogue_manager(dialogue_manager)

20 changes: 20 additions & 0 deletions app/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Optional
from app.bot.dialogue_manager.dialogue_manager import DialogueManager
from app.config import app_config

_dialogue_manager: Optional[DialogueManager] = None

async def get_dialogue_manager():
global _dialogue_manager
return _dialogue_manager

async def set_dialogue_manager(dialogue_manager: DialogueManager):
global _dialogue_manager
_dialogue_manager = dialogue_manager

async def init_dialogue_manager():
global _dialogue_manager
print("initializing dialogue manager")
_dialogue_manager = await DialogueManager.from_config()
_dialogue_manager.update_model(app_config.MODELS_DIR)
print("dialogue manager initialized")
32 changes: 13 additions & 19 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,25 @@
from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
from app.bot.dialogue_manager.dialogue_manager import DialogueManager
from fastapi import FastAPI, APIRouter
from app.database import client as database_client
from app.dependencies import init_dialogue_manager

from app.admin.bots.routes import router as bots_router
from app.admin.entities.routes import router as entities_router
from app.admin.intents.routes import router as intents_router
from app.admin.train.routes import router as train_router
from app.bot.chat.routes import router as chat_router
from app.config import app_config

@asynccontextmanager
async def lifespan(app: FastAPI):
# initialize dialogue_manager
dialogue_manager = await DialogueManager.from_config()
dialogue_manager.update_model(app_config.MODELS_DIR)
app.state.dialogue_manager : DialogueManager = dialogue_manager
print("dialogue manager loaded")

yield

@asynccontextmanager
async def lifespan(_):
await init_dialogue_manager()
yield
database_client.close()

app = FastAPI(title="AI Chatbot Framework",lifespan=lifespan)

# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand All @@ -36,10 +29,8 @@ async def lifespan(app: FastAPI):
allow_headers=["*"],
)

# Static files
app.mount("/static", StaticFiles(directory="app/static"), name="static")


@app.get("/ready")
async def ready():
return {"status": "ok"}
Expand All @@ -48,9 +39,12 @@ async def ready():
async def root():
return {"message": "Welcome to AI Chatbot Framework API"}

# admin apis
admin_router = APIRouter(prefix="/admin", tags=["admin"])
admin_router.include_router(bots_router)
admin_router.include_router(intents_router)
admin_router.include_router(entities_router)
admin_router.include_router(train_router)
app.include_router(admin_router)

app.include_router(bots_router, prefix="/admin", tags=["bots"])
app.include_router(intents_router, prefix="/admin", tags=["intents"])
app.include_router(entities_router, prefix="/admin", tags=["entities"])
app.include_router(train_router, prefix="/admin", tags=["train"])
app.include_router(chat_router, prefix="/bots", tags=["bots"])

0 comments on commit a2534b6

Please sign in to comment.