Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance RAG with hybrid search #62

Merged
merged 11 commits into from
Jun 28, 2024
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ OPENAI_API_KEY=
ANTHROPIC_API_KEY=

# Embedding model. See the list of supported models: https://qdrant.github.io/fastembed/examples/Supported_Models/
EMBEDDING_MODEL=BAAI/bge-small-en-v1.5
DENSE_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5
SPARSE_EMBEDDING_MODEL=prithivida/Splade_PP_en_v1


# Langsmith: For llm observability
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,11 @@ RAG is a technique for augmenting your agents' knowledge with additional data. A

#### Customising embedding models

By default, Tribe uses `BAAI/bge-small-en-v1.5`, which is a light and fast English embedding model that is better than `OpenAI Ada-002`. If your documents are multilingual or require image embedding, you may want to use another embedding model. You can easily do this by changing `EMBEDDING_MODEL` in your `.env` file:
By default, Tribe uses `BAAI/bge-small-en-v1.5`, which is a light and fast English embedding model that is better than `OpenAI Ada-002`. If your documents are multilingual or require image embedding, you may want to use another embedding model. You can easily do this by changing `DENSE_EMBEDDING_MODEL` in your `.env` file:

```bash
# See the list of supported models: https://qdrant.github.io/fastembed/examples/Supported_Models/
EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 # Change this
DENSE_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 # Change this
```

> [!WARNING]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Add 'Failed' enum value to upload status

Revision ID: bfa5449b6bba
Revises: eab5bf7ec514
Create Date: 2024-06-28 15:18:51.744902

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` is the migration it
# revises (eab5bf7ec514, matching the "Revises" line in the module docstring).
revision = 'bfa5449b6bba'
down_revision = 'eab5bf7ec514'
# No branch labels or dependencies on other revisions are declared here.
branch_labels = None
depends_on = None


def upgrade():
    """Add the 'FAILED' label to the PostgreSQL ``uploadstatus`` enum type."""
    # ALTER TYPE ... ADD VALUE cannot be executed inside a transaction block on
    # PostgreSQL versions before 12, and Alembic normally runs migrations inside
    # one. autocommit_block() commits the surrounding transaction (if any) and
    # runs the statement in autocommit mode.
    # IF NOT EXISTS keeps the migration re-runnable if the label was already
    # added (e.g. after a partially applied upgrade).
    with op.get_context().autocommit_block():
        op.execute("ALTER TYPE uploadstatus ADD VALUE IF NOT EXISTS 'FAILED'")

def downgrade():
    """Remove the 'FAILED' label from the ``uploadstatus`` enum type.

    PostgreSQL cannot drop a single label from an enum type, so this creates a
    replacement type without 'FAILED', converts ``upload.status`` to it, drops
    the old type and renames the replacement to the original name.
    """
    # Rows that still hold 'FAILED' cannot be cast to the reduced enum and
    # would abort the ALTER below, so move them to a label that survives.
    # NOTE(review): IN_PROGRESS is chosen as the closest pre-migration state;
    # confirm this mapping is acceptable for existing data.
    op.execute("UPDATE upload SET status = 'IN_PROGRESS' WHERE status = 'FAILED'")

    # Create a new enum type without 'FAILED'
    op.execute("CREATE TYPE uploadstatus_tmp AS ENUM('IN_PROGRESS', 'COMPLETED')")

    # Alter the column to use the new enum type. PostgreSQL has no implicit
    # cast between two distinct enum types, so an explicit USING expression
    # (via text) is required.
    op.alter_column(
        'upload',
        'status',
        existing_type=postgresql.ENUM('IN_PROGRESS', 'COMPLETED', 'FAILED', name='uploadstatus'),
        type_=postgresql.ENUM('IN_PROGRESS', 'COMPLETED', name='uploadstatus_tmp'),
        existing_nullable=True,
        postgresql_using='status::text::uploadstatus_tmp',
    )

    # Drop the old enum type
    op.execute("DROP TYPE uploadstatus")

    # Rename the new enum type to the old name
    op.execute("ALTER TYPE uploadstatus_tmp RENAME TO uploadstatus")
58 changes: 48 additions & 10 deletions backend/app/api/routes/uploads.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime
from tempfile import NamedTemporaryFile
from typing import IO, Annotated, Any
Expand Down Expand Up @@ -31,6 +32,8 @@

# Router on which the upload endpoints in this module are registered.
router = APIRouter()

# Single module-level store shared by the background-task helpers and the
# delete endpoint in this module.
# NOTE(review): this is instantiated at import time — confirm that
# QdrantStore() can be constructed without a reachable Qdrant server, otherwise
# importing this module depends on that service being up.
qdrant_store = QdrantStore()


async def valid_content_length(
content_length: int = Header(..., le=settings.MAX_UPLOAD_SIZE),
Expand Down Expand Up @@ -70,11 +73,36 @@ def save_file_if_within_size_limit(file: UploadFile, file_size: int) -> IO[bytes
return temp


def update_upload_status(session: SessionDep, upload: Upload) -> None:
    """Set upload status to completed.

    Assigns ``UploadStatus.COMPLETED`` to ``upload.status``, adds the object to
    ``session`` and commits.
    """
    upload.status = UploadStatus.COMPLETED
    session.add(upload)
    session.commit()
def process_add(
    file_path: str,
    upload_id: int,
    user_id: int,
    chunk_size: int,
    chunk_overlap: int,
    update_status_callback: Callable[[UploadStatus], None],
) -> None:
    """Add a file to the vector store and report the outcome via a callback.

    Calls ``qdrant_store.add`` with the given arguments. On success
    ``update_status_callback`` is invoked with ``UploadStatus.COMPLETED``; if
    ``qdrant_store.add`` raises, it is invoked with ``UploadStatus.FAILED`` and
    the exception is re-raised.

    Args:
        file_path: Path of the file passed to ``qdrant_store.add``.
        upload_id: ID of the upload the file belongs to.
        user_id: ID of the user the upload belongs to.
        chunk_size: Chunk size forwarded to ``qdrant_store.add``.
        chunk_overlap: Chunk overlap forwarded to ``qdrant_store.add``.
        update_status_callback: Called once with the resulting status.

    Raises:
        Exception: Whatever ``qdrant_store.add`` or the callback raises.
    """
    try:
        qdrant_store.add(file_path, upload_id, user_id, chunk_size, chunk_overlap)
    except Exception:
        update_status_callback(UploadStatus.FAILED)
        # Bare ``raise`` re-raises the active exception unchanged.
        raise
    else:
        # Reported outside the ``try`` so that an error while recording
        # COMPLETED is not treated as an indexing failure and followed by a
        # second status write.
        update_status_callback(UploadStatus.COMPLETED)


def process_update(
    file_path: str,
    upload_id: int,
    user_id: int,
    chunk_size: int,
    chunk_overlap: int,
    update_status_callback: Callable[[UploadStatus], None],
) -> None:
    """Update a file in the vector store and report the outcome via a callback.

    Calls ``qdrant_store.update`` with the given arguments. On success
    ``update_status_callback`` is invoked with ``UploadStatus.COMPLETED``; if
    ``qdrant_store.update`` raises, it is invoked with ``UploadStatus.FAILED``
    and the exception is re-raised.

    Args:
        file_path: Path of the file passed to ``qdrant_store.update``.
        upload_id: ID of the upload being replaced.
        user_id: ID of the user the upload belongs to.
        chunk_size: Chunk size forwarded to ``qdrant_store.update``.
        chunk_overlap: Chunk overlap forwarded to ``qdrant_store.update``.
        update_status_callback: Called once with the resulting status.

    Raises:
        Exception: Whatever ``qdrant_store.update`` or the callback raises.
    """
    try:
        qdrant_store.update(file_path, upload_id, user_id, chunk_size, chunk_overlap)
    except Exception:
        update_status_callback(UploadStatus.FAILED)
        # Bare ``raise`` re-raises the active exception unchanged.
        raise
    else:
        # Reported outside the ``try`` so that an error while recording
        # COMPLETED is not treated as an indexing failure and followed by a
        # second status write.
        update_status_callback(UploadStatus.COMPLETED)


@router.get("/", response_model=UploadsOut)
Expand Down Expand Up @@ -136,14 +164,19 @@ def create_upload(
status_code=500, detail="Failed to retrieve user and upload ID"
)

def update_status_callback(status: UploadStatus) -> None:
    """Store ``status`` on the enclosing ``upload`` and commit it."""
    # NOTE(review): this closure is passed to a background task and reuses the
    # enclosing ``session`` — confirm that session is still usable when the
    # task runs.
    upload.status = status
    session.add(upload)
    session.commit()

background_tasks.add_task(
QdrantStore().create,
process_add,
temp_file.name,
upload.id,
current_user.id,
chunk_size,
chunk_overlap,
lambda: update_upload_status(session, upload),
update_status_callback,
)
except Exception as e:
session.delete(upload)
Expand Down Expand Up @@ -201,14 +234,19 @@ def update_upload(
session.add(upload)
session.commit()

def update_status_callback(status: UploadStatus) -> None:
    """Store ``status`` on the enclosing ``upload`` and commit it."""
    # NOTE(review): this closure is passed to a background task and reuses the
    # enclosing ``session`` — confirm that session is still usable when the
    # task runs.
    upload.status = status
    session.add(upload)
    session.commit()

background_tasks.add_task(
QdrantStore().update,
process_update,
temp_file.name,
id,
upload.owner_id,
chunk_size,
chunk_overlap,
lambda: update_upload_status(session, upload),
update_status_callback,
)

session.commit()
Expand All @@ -228,7 +266,7 @@ def delete_upload(session: SessionDep, current_user: CurrentUser, id: int) -> Me
session.delete(upload)
if upload.owner_id is None:
raise HTTPException(status_code=500, detail="Failed to retrieve owner ID")
QdrantStore().delete(id, upload.owner_id)
qdrant_store.delete(id, upload.owner_id)
session.commit()
except Exception as e:
session.rollback()
Expand Down
3 changes: 2 additions & 1 deletion backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def _enforce_non_default_secrets(self) -> Self:
QDRANT_COLLECTION: str = "uploads"

# Embeddings
EMBEDDING_MODEL: str
DENSE_EMBEDDING_MODEL: str
SPARSE_EMBEDDING_MODEL: str

MAX_UPLOAD_SIZE: int = 50_000_000

Expand Down
100 changes: 74 additions & 26 deletions backend/app/core/graph/rag/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
from collections.abc import Callable
from typing import Any

import pymupdf4llm # type: ignore[import-untyped]
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_qdrant import Qdrant
from langchain_text_splitters import MarkdownTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

from app.core.config import settings
from app.core.graph.rag.qdrant_retriever import QdrantRetriever


class QdrantStore:
"""
A class to handle uploading and searching documents in a Qdrant vector store.
"""

embeddings = FastEmbedEmbeddings(model_name=settings.EMBEDDING_MODEL) # type: ignore[call-arg]
embeddings = FastEmbedEmbeddings(model_name=settings.DENSE_EMBEDDING_MODEL) # type: ignore[call-arg]
collection_name = settings.QDRANT_COLLECTION
url = settings.QDRANT_URL

def create(
def __init__(self) -> None:
    """Initialise the store, keeping the client returned by ``_create_collection``."""
    self.client = self._create_collection()

def add(
self,
file_path: str,
upload_id: int,
Expand All @@ -46,22 +51,46 @@ def create(
[md_text],
[{"user_id": user_id, "upload_id": upload_id}],
)
Qdrant.from_documents(
documents=docs,
embedding=self.embeddings,
url=self.url,
prefer_grpc=True,

doc_texts: list[str] = []
metadata: list[dict[Any, Any]] = []
for doc in docs:
doc_texts.append(doc.page_content)
metadata.append(doc.metadata)

self.client.add(
collection_name=self.collection_name,
api_key=settings.QDRANT__SERVICE__API_KEY,
documents=doc_texts,
metadata=metadata,
)

callback() if callback else None

def _create_collection(self) -> QdrantClient:
    """
    Build a Qdrant client and make sure the configured collection exists.

    The client is set up with the dense and sparse embedding models from the
    settings. When the collection is missing, it is created with the client's
    fastembed dense and sparse vector parameters (hybrid search).

    Returns:
        QdrantClient: The configured client.
    """
    qdrant = QdrantClient(url=self.url, api_key=settings.QDRANT__SERVICE__API_KEY)
    qdrant.set_model(settings.DENSE_EMBEDDING_MODEL)
    qdrant.set_sparse_model(settings.SPARSE_EMBEDDING_MODEL)

    # Nothing more to do when the collection is already there.
    if qdrant.collection_exists(self.collection_name):
        return qdrant

    dense_params = qdrant.get_fastembed_vector_params()
    sparse_params = qdrant.get_fastembed_sparse_vector_params()
    qdrant.create_collection(
        collection_name=self.collection_name,
        vectors_config=dense_params,
        sparse_vectors_config=sparse_params,
    )
    return qdrant

def _get_collection(self) -> Qdrant:
"""Get instance of an existing Qdrant collection."""
"""Get instance of an existing Qdrant collection from langchain_qdrant."""
return Qdrant.from_existing_collection(
embedding=self.embeddings,
url=self.url,
prefer_grpc=True,
collection_name=self.collection_name,
api_key=settings.QDRANT__SERVICE__API_KEY,
)
Expand All @@ -73,11 +102,11 @@ def delete(self, upload_id: int, user_id: int) -> bool | None:
ids=rest.Filter( # type: ignore[arg-type]
must=[
rest.FieldCondition(
key="metadata.user_id",
key="user_id",
match=rest.MatchValue(value=user_id),
),
rest.FieldCondition(
key="metadata.upload_id",
key="upload_id",
match=rest.MatchValue(value=upload_id),
),
]
Expand All @@ -95,10 +124,10 @@ def update(
) -> None:
"""Delete and re-upload the new PDF document to the Qdrant vector store"""
self.delete(user_id, upload_id)
self.create(file_path, upload_id, user_id, chunk_size, chunk_overlap)
self.add(file_path, upload_id, user_id, chunk_size, chunk_overlap)
callback() if callback else None

def retriever(self, user_id: int, upload_id: int) -> VectorStoreRetriever:
def retriever(self, user_id: int, upload_id: int) -> QdrantRetriever:
"""
Creates a VectorStoreRetriever that retrieves results containing the specified user_id and upload_id in the metadata.

Expand All @@ -109,9 +138,21 @@ def retriever(self, user_id: int, upload_id: int) -> VectorStoreRetriever:
Returns:
VectorStoreRetriever: A VectorStoreRetriever instance.
"""
qdrant = self._get_collection()
retriever = qdrant.as_retriever(
search_kwargs={"filter": {"user_id": user_id, "upload_id": upload_id}}
retriever = QdrantRetriever(
client=self.client,
collection_name=self.collection_name,
search_kwargs=rest.Filter(
must=[
rest.FieldCondition(
key="user_id",
match=rest.MatchValue(value=user_id),
),
rest.FieldCondition(
key="upload_id",
match=rest.MatchValue(value=upload_id),
),
],
),
)
return retriever

Expand All @@ -127,20 +168,27 @@ def search(self, user_id: int, upload_ids: list[int], query: str) -> list[Docume
Returns:
List[Document]: A list of documents matching the search criteria.
"""
qdrant = self._get_collection()
found_docs = qdrant.similarity_search(
query,
filter=rest.Filter(
search_results = self.client.query(
collection_name=self.collection_name,
query_text=query,
query_filter=rest.Filter(
must=[
rest.FieldCondition(
key="metadata.user_id",
key="user_id",
match=rest.MatchValue(value=user_id),
),
rest.FieldCondition(
key="metadata.upload_id",
key="upload_id",
match=rest.MatchAny(any=upload_ids),
),
]
],
),
)
return found_docs
documents: list[Document] = []
for result in search_results:
document = Document(
page_content=result.document,
metadata={"score": result.score},
)
documents.append(document)
return documents
Loading