Skip to content

Commit 01269af

Browse files
authored
feat: direct search (#99)
Added methods `direct_search` and `direct_search_scroll`. Added example for scrolling through all chunks in `scroll_all_chunks.py`.
1 parent 74050c3 commit 01269af

File tree

5 files changed

+226
-2
lines changed

5 files changed

+226
-2
lines changed

cohere/compass/clients/compass.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
CompassSdkStage,
5050
CreateDataSource,
5151
DataSource,
52+
DirectSearchInput,
53+
DirectSearchResponse,
54+
DirectSearchScrollInput,
5255
Document,
5356
DocumentStatus,
5457
ParseableDocument,
@@ -115,7 +118,6 @@ def __init__(
115118
)
116119
self.default_max_retries = default_max_retries
117120
self.default_sleep_retry_seconds = default_sleep_retry_seconds
118-
119121
self.api_method = {
120122
"create_index": self.session.put,
121123
"list_indexes": self.session.get,
@@ -130,6 +132,8 @@ def __init__(
130132
"refresh": self.session.post,
131133
"upload_documents": self.session.post,
132134
"update_group_authorization": self.session.post,
135+
"direct_search": self.session.post,
136+
"direct_search_scroll": self.session.post,
133137
# Data Sources APIs
134138
"create_datasource": self.session.post,
135139
"list_datasources": self.session.get,
@@ -152,6 +156,8 @@ def __init__(
152156
"refresh": "/api/v1/indexes/{index_name}/_refresh",
153157
"upload_documents": "/api/v1/indexes/{index_name}/documents/_upload",
154158
"update_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501
159+
"direct_search": "/api/v1/indexes/{index_name}/_direct_search",
160+
"direct_search_scroll": "/api/v1/indexes/_direct_search/scroll",
155161
# Data Sources APIs
156162
"create_datasource": "/api/v1/datasources",
157163
"list_datasources": "/api/v1/datasources",
@@ -881,6 +887,77 @@ def update_group_authorization(
881887
raise CompassError(result.error)
882888
return PutDocumentsResponse.model_validate(result.result)
883889

890+
def direct_search(
891+
self,
892+
*,
893+
index_name: str,
894+
query: dict[str, Any],
895+
size: int = 100,
896+
scroll: Optional[str] = None,
897+
max_retries: Optional[int] = None,
898+
sleep_retry_seconds: Optional[int] = None,
899+
) -> DirectSearchResponse:
900+
"""
901+
Perform a direct search query against the Compass API.
902+
903+
:param index_name: the name of the index
904+
:param query: the direct search query (e.g. {"match_all": {}})
905+
:param size: the number of results to return
906+
:param scroll: the scroll duration (e.g. "1m" for 1 minute)
907+
:param max_retries: the maximum number of times to retry the request
908+
:param sleep_retry_seconds: the number of seconds to sleep between retries
909+
910+
:returns: the direct search results
911+
:raises CompassError: if the search fails
912+
"""
913+
data = DirectSearchInput(query=query, size=size, scroll=scroll)
914+
915+
result = self._send_request(
916+
api_name="direct_search",
917+
index_name=index_name,
918+
data=data,
919+
max_retries=max_retries,
920+
sleep_retry_seconds=sleep_retry_seconds,
921+
)
922+
923+
if result.error:
924+
raise CompassError(result.error)
925+
926+
return DirectSearchResponse.model_validate(result.result)
927+
928+
def direct_search_scroll(
929+
self,
930+
*,
931+
scroll_id: str,
932+
scroll: str = "1m",
933+
max_retries: Optional[int] = None,
934+
sleep_retry_seconds: Optional[int] = None,
935+
) -> DirectSearchResponse:
936+
"""
937+
Continue a search using a scroll ID from a previous direct_search call.
938+
939+
:param scroll_id: the scroll ID from a previous direct_search call
940+
:param scroll: the scroll duration (e.g. "1m" for 1 minute)
941+
:param max_retries: the maximum number of times to retry the request
942+
:param sleep_retry_seconds: the number of seconds to sleep between retries
943+
944+
:returns: the next batch of search results
945+
:raises CompassError: if the scroll search fails
946+
"""
947+
data = DirectSearchScrollInput(scroll_id=scroll_id, scroll=scroll)
948+
949+
result = self._send_request(
950+
api_name="direct_search_scroll",
951+
data=data,
952+
max_retries=max_retries,
953+
sleep_retry_seconds=sleep_retry_seconds,
954+
)
955+
956+
if result.error:
957+
raise CompassError(result.error)
958+
959+
return DirectSearchResponse.model_validate(result.result)
960+
884961
# todo Simplify this method so we don't have to ignore the C901 complexity warning.
885962
def _send_request( # noqa: C901
886963
self,

cohere/compass/models/search.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,25 @@ class SearchInput(BaseModel):
8484
query: str
8585
top_k: int
8686
filters: Optional[list[SearchFilter]] = None
87+
88+
89+
class DirectSearchInput(BaseModel):
90+
"""Input to direct search APIs."""
91+
92+
query: dict[str, Any]
93+
size: int
94+
scroll: Optional[str] = None
95+
96+
97+
class DirectSearchScrollInput(BaseModel):
98+
"""Input to direct search scroll API."""
99+
100+
scroll_id: str
101+
scroll: str
102+
103+
104+
class DirectSearchResponse(BaseModel):
105+
"""Response object for direct search APIs."""
106+
107+
hits: list[RetrievedChunkExtended]
108+
scroll_id: Optional[str] = None
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import argparse
2+
import json
3+
4+
from compass_sdk_examples.utils import get_compass_api
5+
6+
7+
def parse_args():
8+
"""
9+
Parse the user arguments using argparse.
10+
"""
11+
parser = argparse.ArgumentParser(
12+
description="""
13+
This script retrieves all chunks from an existing index in Compass using pagination.
14+
""".strip(),
15+
add_help=True,
16+
)
17+
18+
parser.add_argument(
19+
"--index-name",
20+
type=str,
21+
help="Specify the name of the index to retrieve chunks from.",
22+
required=True,
23+
)
24+
parser.add_argument(
25+
"--query",
26+
type=str,
27+
help='JSON string of the query to use (default: {"match_all": {}})',
28+
default='{"match_all": {}}',
29+
)
30+
parser.add_argument(
31+
"--batch-size",
32+
type=int,
33+
help="Number of documents to retrieve per batch (default: 100)",
34+
default=100,
35+
)
36+
parser.add_argument(
37+
"--scroll",
38+
type=str,
39+
help="Scroll duration (default: '1m')",
40+
default="1m",
41+
)
42+
43+
return parser.parse_args()
44+
45+
46+
def main():
47+
args = parse_args()
48+
index_name = args.index_name
49+
50+
try:
51+
query = json.loads(args.query)
52+
except json.JSONDecodeError:
53+
print("Error: Invalid JSON in query argument")
54+
return
55+
56+
client = get_compass_api()
57+
58+
# Inline scroll_all_chunks functionality
59+
if query is None:
60+
query = {"match_all": {}} # type: ignore
61+
62+
response = client.direct_search(
63+
index_name=index_name,
64+
query=query, # type: ignore
65+
size=args.batch_size,
66+
scroll=args.scroll,
67+
)
68+
69+
all_chunks = response.hits
70+
71+
while response.hits and response.scroll_id:
72+
response = client.direct_search_scroll(
73+
scroll_id=response.scroll_id,
74+
scroll=args.scroll,
75+
)
76+
all_chunks.extend(response.hits)
77+
78+
results = all_chunks
79+
80+
print(f"Retrieved {len(results)} total documents")
81+
82+
if results:
83+
print("\nPreview of first document:")
84+
print(results[0])
85+
86+
87+
if __name__ == "__main__":
88+
main()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "compass-sdk"
3-
version = "0.17.0"
3+
version = "0.18.0"
44
authors = []
55
description = "Compass SDK"
66
readme = "README.md"

tests/test_compass_client.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,40 @@ def test_get_document_asset_image(requests_mock: Mocker):
166166
assert isinstance(asset, bytes)
167167
assert asset == b"test"
168168
assert content_type == "image/png"
169+
170+
171+
def test_direct_search_is_valid(requests_mock: Mocker):
172+
# Register mock response for the direct_search endpoint
173+
requests_mock.post(
174+
"http://test.com/api/v1/indexes/test_index/_direct_search",
175+
json={"hits": [], "scroll_id": "test_scroll_id"},
176+
)
177+
178+
compass = CompassClient(index_url="http://test.com")
179+
compass.direct_search(index_name="test_index", query={"match_all": {}})
180+
assert requests_mock.request_history[0].method == "POST"
181+
assert (
182+
requests_mock.request_history[0].url
183+
== "http://test.com/api/v1/indexes/test_index/_direct_search"
184+
)
185+
assert "query" in requests_mock.request_history[0].json()
186+
assert "size" in requests_mock.request_history[0].json()
187+
188+
189+
def test_direct_search_scroll_is_valid(requests_mock: Mocker):
190+
# Register mock response for the direct_search_scroll endpoint
191+
requests_mock.post(
192+
"http://test.com/api/v1/indexes/_direct_search/scroll",
193+
json={"hits": [], "scroll_id": "test_scroll_id"},
194+
)
195+
196+
compass = CompassClient(index_url="http://test.com")
197+
compass.direct_search_scroll(scroll_id="test_scroll_id")
198+
assert requests_mock.request_history[0].method == "POST"
199+
assert (
200+
requests_mock.request_history[0].url
201+
== "http://test.com/api/v1/indexes/_direct_search/scroll"
202+
)
203+
request_body = requests_mock.request_history[0].json()
204+
assert request_body["scroll_id"] == "test_scroll_id"
205+
assert request_body["scroll"] == "1m"

0 commit comments

Comments
 (0)