Skip to content

Commit 145e4b1

Browse files
authored
Enhance RAG with hybrid search (#62)
1 parent 8f44fae commit 145e4b1

12 files changed

+307
-103
lines changed

.env.example

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ OPENAI_API_KEY=
2525
ANTHROPIC_API_KEY=
2626

2727
# Embedding model. See the list of supported models: https://qdrant.github.io/fastembed/examples/Supported_Models/
28-
EMBEDDING_MODEL=BAAI/bge-small-en-v1.5
28+
DENSE_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5
29+
SPARSE_EMBEDDING_MODEL=prithivida/Splade_PP_en_v1
2930

3031

3132
# Langsmith: For llm observability

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ RAG is a technique for augmenting your agents' knowledge with additional data. A
182182

183183
#### Customising embedding models
184184

185-
By default, Tribe uses `BAAI/bge-small-en-v1.5`, which is a light and fast English embedding model that is better than `OpenAI Ada-002`. If your documents are multilingual or require image embedding, you may want to use another embedding model. You can easily do this by changing `EMBEDDING_MODEL` in your `.env` file:
185+
By default, Tribe uses `BAAI/bge-small-en-v1.5`, which is a light and fast English embedding model that is better than `OpenAI Ada-002`. If your documents are multilingual or require image embedding, you may want to use another embedding model. You can easily do this by changing `DENSE_EMBEDDING_MODEL` in your `.env` file:
186186

187187
```bash
188188
# See the list of supported models: https://qdrant.github.io/fastembed/examples/Supported_Models/
189-
EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 # Change this
189+
DENSE_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 # Change this
190190
```
191191

192192
> [!WARNING]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Add 'Failed' enum value to upload status
2+
3+
Revision ID: bfa5449b6bba
4+
Revises: eab5bf7ec514
5+
Create Date: 2024-06-28 15:18:51.744902
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = 'bfa5449b6bba'
15+
down_revision = 'eab5bf7ec514'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# Add new value to the enum type
22+
op.execute("ALTER TYPE uploadstatus ADD VALUE 'FAILED'")
23+
24+
def downgrade():
25+
# Downgrade logic to remove the 'FAILED' value is not straightforward
26+
# Enum types in PostgreSQL cannot remove a value directly
27+
# So, we need to create a new enum type without 'FAILED', convert the column, and drop the old type
28+
29+
# Create a new enum type without 'FAILED'
30+
op.execute("CREATE TYPE uploadstatus_tmp AS ENUM('IN_PROGRESS', 'COMPLETED')")
31+
32+
# Alter the column to use the new enum type
33+
op.alter_column(
34+
'upload',
35+
'status',
36+
existing_type=postgresql.ENUM('IN_PROGRESS', 'COMPLETED', 'FAILED', name='uploadstatus'),
37+
type_=postgresql.ENUM('IN_PROGRESS', 'COMPLETED', name='uploadstatus_tmp'),
38+
existing_nullable=True
39+
)
40+
41+
# Drop the old enum type
42+
op.execute("DROP TYPE uploadstatus")
43+
44+
# Rename the new enum type to the old name
45+
op.execute("ALTER TYPE uploadstatus_tmp RENAME TO uploadstatus")

backend/app/api/routes/uploads.py

+48-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Callable
12
from datetime import datetime
23
from tempfile import NamedTemporaryFile
34
from typing import IO, Annotated, Any
@@ -31,6 +32,8 @@
3132

3233
router = APIRouter()
3334

35+
qdrant_store = QdrantStore()
36+
3437

3538
async def valid_content_length(
3639
content_length: int = Header(..., le=settings.MAX_UPLOAD_SIZE),
@@ -70,11 +73,36 @@ def save_file_if_within_size_limit(file: UploadFile, file_size: int) -> IO[bytes
7073
return temp
7174

7275

73-
def update_upload_status(session: SessionDep, upload: Upload) -> None:
74-
"""Set upload status to completed"""
75-
upload.status = UploadStatus.COMPLETED
76-
session.add(upload)
77-
session.commit()
76+
def process_add(
77+
file_path: str,
78+
upload_id: int,
79+
user_id: int,
80+
chunk_size: int,
81+
chunk_overlap: int,
82+
update_status_callback: Callable[[UploadStatus], None],
83+
) -> None:
84+
try:
85+
qdrant_store.add(file_path, upload_id, user_id, chunk_size, chunk_overlap)
86+
update_status_callback(UploadStatus.COMPLETED)
87+
except Exception as e:
88+
update_status_callback(UploadStatus.FAILED)
89+
raise e
90+
91+
92+
def process_update(
93+
file_path: str,
94+
upload_id: int,
95+
user_id: int,
96+
chunk_size: int,
97+
chunk_overlap: int,
98+
update_status_callback: Callable[[UploadStatus], None],
99+
) -> None:
100+
try:
101+
qdrant_store.update(file_path, upload_id, user_id, chunk_size, chunk_overlap)
102+
update_status_callback(UploadStatus.COMPLETED)
103+
except Exception as e:
104+
update_status_callback(UploadStatus.FAILED)
105+
raise e
78106

79107

80108
@router.get("/", response_model=UploadsOut)
@@ -136,14 +164,19 @@ def create_upload(
136164
status_code=500, detail="Failed to retrieve user and upload ID"
137165
)
138166

167+
def update_status_callback(status: UploadStatus) -> None:
168+
upload.status = status
169+
session.add(upload)
170+
session.commit()
171+
139172
background_tasks.add_task(
140-
QdrantStore().create,
173+
process_add,
141174
temp_file.name,
142175
upload.id,
143176
current_user.id,
144177
chunk_size,
145178
chunk_overlap,
146-
lambda: update_upload_status(session, upload),
179+
update_status_callback,
147180
)
148181
except Exception as e:
149182
session.delete(upload)
@@ -201,14 +234,19 @@ def update_upload(
201234
session.add(upload)
202235
session.commit()
203236

237+
def update_status_callback(status: UploadStatus) -> None:
238+
upload.status = status
239+
session.add(upload)
240+
session.commit()
241+
204242
background_tasks.add_task(
205-
QdrantStore().update,
243+
process_update,
206244
temp_file.name,
207245
id,
208246
upload.owner_id,
209247
chunk_size,
210248
chunk_overlap,
211-
lambda: update_upload_status(session, upload),
249+
update_status_callback,
212250
)
213251

214252
session.commit()
@@ -228,7 +266,7 @@ def delete_upload(session: SessionDep, current_user: CurrentUser, id: int) -> Me
228266
session.delete(upload)
229267
if upload.owner_id is None:
230268
raise HTTPException(status_code=500, detail="Failed to retrieve owner ID")
231-
QdrantStore().delete(id, upload.owner_id)
269+
qdrant_store.delete(id, upload.owner_id)
232270
session.commit()
233271
except Exception as e:
234272
session.rollback()

backend/app/core/config.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def _enforce_non_default_secrets(self) -> Self:
130130
QDRANT_COLLECTION: str = "uploads"
131131

132132
# Embeddings
133-
EMBEDDING_MODEL: str
133+
DENSE_EMBEDDING_MODEL: str
134+
SPARSE_EMBEDDING_MODEL: str
134135

135136
MAX_UPLOAD_SIZE: int = 50_000_000
136137

backend/app/core/graph/rag/qdrant.py

+74-26
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
from collections.abc import Callable
2+
from typing import Any
23

34
import pymupdf4llm # type: ignore[import-untyped]
45
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
56
from langchain_core.documents import Document
6-
from langchain_core.vectorstores import VectorStoreRetriever
77
from langchain_qdrant import Qdrant
88
from langchain_text_splitters import MarkdownTextSplitter
9+
from qdrant_client import QdrantClient
910
from qdrant_client.http import models as rest
1011

1112
from app.core.config import settings
13+
from app.core.graph.rag.qdrant_retriever import QdrantRetriever
1214

1315

1416
class QdrantStore:
1517
"""
1618
A class to handle uploading and searching documents in a Qdrant vector store.
1719
"""
1820

19-
embeddings = FastEmbedEmbeddings(model_name=settings.EMBEDDING_MODEL) # type: ignore[call-arg]
21+
embeddings = FastEmbedEmbeddings(model_name=settings.DENSE_EMBEDDING_MODEL) # type: ignore[call-arg]
2022
collection_name = settings.QDRANT_COLLECTION
2123
url = settings.QDRANT_URL
2224

23-
def create(
25+
def __init__(self) -> None:
26+
self.client = self._create_collection()
27+
28+
def add(
2429
self,
2530
file_path: str,
2631
upload_id: int,
@@ -46,22 +51,46 @@ def create(
4651
[md_text],
4752
[{"user_id": user_id, "upload_id": upload_id}],
4853
)
49-
Qdrant.from_documents(
50-
documents=docs,
51-
embedding=self.embeddings,
52-
url=self.url,
53-
prefer_grpc=True,
54+
55+
doc_texts: list[str] = []
56+
metadata: list[dict[Any, Any]] = []
57+
for doc in docs:
58+
doc_texts.append(doc.page_content)
59+
metadata.append(doc.metadata)
60+
61+
self.client.add(
5462
collection_name=self.collection_name,
55-
api_key=settings.QDRANT__SERVICE__API_KEY,
63+
documents=doc_texts,
64+
metadata=metadata,
5665
)
66+
5767
callback() if callback else None
5868

69+
def _create_collection(self) -> QdrantClient:
70+
"""
71+
Creates a collection in Qdrant if it does not already exist, configured for hybrid search.
72+
73+
The collection uses both dense and sparse vector models. Returns an instance of the Qdrant client.
74+
75+
Returns:
76+
QdrantClient: An instance of the Qdrant client.
77+
"""
78+
client = QdrantClient(url=self.url, api_key=settings.QDRANT__SERVICE__API_KEY)
79+
client.set_model(settings.DENSE_EMBEDDING_MODEL)
80+
client.set_sparse_model(settings.SPARSE_EMBEDDING_MODEL)
81+
if not client.collection_exists(self.collection_name):
82+
client.create_collection(
83+
collection_name=self.collection_name,
84+
vectors_config=client.get_fastembed_vector_params(),
85+
sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
86+
)
87+
return client
88+
5989
def _get_collection(self) -> Qdrant:
60-
"""Get instance of an existing Qdrant collection."""
90+
"""Get instance of an existing Qdrant collection from langchain_qdrant."""
6191
return Qdrant.from_existing_collection(
6292
embedding=self.embeddings,
6393
url=self.url,
64-
prefer_grpc=True,
6594
collection_name=self.collection_name,
6695
api_key=settings.QDRANT__SERVICE__API_KEY,
6796
)
@@ -73,11 +102,11 @@ def delete(self, upload_id: int, user_id: int) -> bool | None:
73102
ids=rest.Filter( # type: ignore[arg-type]
74103
must=[
75104
rest.FieldCondition(
76-
key="metadata.user_id",
105+
key="user_id",
77106
match=rest.MatchValue(value=user_id),
78107
),
79108
rest.FieldCondition(
80-
key="metadata.upload_id",
109+
key="upload_id",
81110
match=rest.MatchValue(value=upload_id),
82111
),
83112
]
@@ -95,10 +124,10 @@ def update(
95124
) -> None:
96125
"""Delete and re-upload the new PDF document to the Qdrant vector store"""
97126
self.delete(user_id, upload_id)
98-
self.create(file_path, upload_id, user_id, chunk_size, chunk_overlap)
127+
self.add(file_path, upload_id, user_id, chunk_size, chunk_overlap)
99128
callback() if callback else None
100129

101-
def retriever(self, user_id: int, upload_id: int) -> VectorStoreRetriever:
130+
def retriever(self, user_id: int, upload_id: int) -> QdrantRetriever:
102131
"""
103132
Creates a VectorStoreRetriever that retrieves results containing the specified user_id and upload_id in the metadata.
104133
@@ -109,9 +138,21 @@ def retriever(self, user_id: int, upload_id: int) -> VectorStoreRetriever:
109138
Returns:
110139
VectorStoreRetriever: A VectorStoreRetriever instance.
111140
"""
112-
qdrant = self._get_collection()
113-
retriever = qdrant.as_retriever(
114-
search_kwargs={"filter": {"user_id": user_id, "upload_id": upload_id}}
141+
retriever = QdrantRetriever(
142+
client=self.client,
143+
collection_name=self.collection_name,
144+
search_kwargs=rest.Filter(
145+
must=[
146+
rest.FieldCondition(
147+
key="user_id",
148+
match=rest.MatchValue(value=user_id),
149+
),
150+
rest.FieldCondition(
151+
key="upload_id",
152+
match=rest.MatchValue(value=upload_id),
153+
),
154+
],
155+
),
115156
)
116157
return retriever
117158

@@ -127,20 +168,27 @@ def search(self, user_id: int, upload_ids: list[int], query: str) -> list[Docume
127168
Returns:
128169
List[Document]: A list of documents matching the search criteria.
129170
"""
130-
qdrant = self._get_collection()
131-
found_docs = qdrant.similarity_search(
132-
query,
133-
filter=rest.Filter(
171+
search_results = self.client.query(
172+
collection_name=self.collection_name,
173+
query_text=query,
174+
query_filter=rest.Filter(
134175
must=[
135176
rest.FieldCondition(
136-
key="metadata.user_id",
177+
key="user_id",
137178
match=rest.MatchValue(value=user_id),
138179
),
139180
rest.FieldCondition(
140-
key="metadata.upload_id",
181+
key="upload_id",
141182
match=rest.MatchAny(any=upload_ids),
142183
),
143-
]
184+
],
144185
),
145186
)
146-
return found_docs
187+
documents: list[Document] = []
188+
for result in search_results:
189+
document = Document(
190+
page_content=result.document,
191+
metadata={"score": result.score},
192+
)
193+
documents.append(document)
194+
return documents

0 commit comments

Comments
 (0)