Skip to content

Commit 9c2e6b7

Browse files
authored
Chroma storage integration (#285)
1 parent aa75fa1 commit 9c2e6b7

File tree

12 files changed

+293
-28
lines changed

12 files changed

+293
-28
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
4343
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
4444
run: |
45-
poetry install -E dev -E postgres -E local
45+
poetry install -E dev -E postgres -E local -E chroma -E lancedb
4646
4747
- name: Set Poetry config
4848
env:

docs/storage.md

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,34 @@ pip install 'pymemgpt[postgres]'
1818
### Running Postgres
1919
You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation).
2020

21+
## Chroma
22+
To enable the Chroma storage backend, install the dependencies with:
23+
```
24+
pip install 'pymemgpt[chroma]'
25+
```
26+
You can configure Chroma with either the HTTP or the persistent storage client via `memgpt configure`. You will need to specify either a persistent storage path or a host/port, depending on your client choice. The example below shows how to configure Chroma with local persistent storage:
27+
```
28+
? Select LLM inference provider: openai
29+
? Override default endpoint: https://api.openai.com/v1
30+
? Select default model (recommended: gpt-4): gpt-4
31+
? Select embedding provider: openai
32+
? Select default preset: memgpt_chat
33+
? Select default persona: sam_pov
34+
? Select default human: cs_phd
35+
? Select storage backend for archival data: chroma
36+
? Select chroma backend: persistent
37+
? Enter persistent storage location: /Users/sarahwooders/.memgpt/config/chroma
38+
```
2139

2240
## LanceDB
23-
In order to use the LanceDB backend.
24-
25-
You have to enable the LanceDB backend by running
26-
27-
```
28-
memgpt configure
29-
```
30-
and selecting `lancedb` for archival storage, and database URI (e.g. `./.lancedb`"), Empty archival uri is also handled and default uri is set at `./.lancedb`.
31-
3241
To enable the LanceDB backend, make sure to install the required dependencies with:
3342
```
3443
pip install 'pymemgpt[lancedb]'
3544
```
36-
for more checkout [lancedb docs](https://lancedb.github.io/lancedb/)
45+
You have to enable the LanceDB backend by running
46+
```
47+
memgpt configure
48+
```
49+
and selecting `lancedb` for archival storage, then providing a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, with the default URI set to `./.lancedb`. For more, check out the [lancedb docs](https://lancedb.github.io/lancedb/).
3750

3851

39-
## Chroma
40-
(Coming soon)

memgpt/cli/cli_config.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,24 +241,40 @@ def configure_cli(config: MemGPTConfig):
241241

242242
def configure_archival_storage(config: MemGPTConfig):
    """Interactively configure the archival storage backend.

    Prompts the user to pick a backend and then, depending on the choice,
    a connection URI (postgres / lancedb / chroma-http) or a local path
    (chroma-persistent).

    Returns:
        Tuple of (archival_storage_type, archival_storage_uri,
        archival_storage_path); slots not used by the chosen backend are None.
    """
    # Configure archival storage backend
    archival_storage_options = ["local", "lancedb", "postgres", "chroma"]
    archival_storage_type = questionary.select(
        "Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
    ).ask()
    archival_storage_uri, archival_storage_path = None, None

    # configure postgres
    if archival_storage_type == "postgres":
        archival_storage_uri = questionary.text(
            "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
            default=config.archival_storage_uri if config.archival_storage_uri else "",
        ).ask()

    # configure lancedb
    if archival_storage_type == "lancedb":
        archival_storage_uri = questionary.text(
            # fix: prompt said "lanncedb" and had an unbalanced parenthesis
            "Enter lancedb connection string (e.g. ./.lancedb):",
            default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
        ).ask()

    # configure chroma
    if archival_storage_type == "chroma":
        chroma_type = questionary.select("Select chroma backend:", ["http", "persistent"], default="http").ask()
        if chroma_type == "http":
            archival_storage_uri = questionary.text("Enter chroma ip (e.g. localhost:8000):", default="localhost:8000").ask()
        if chroma_type == "persistent":
            # default to <config dir>/chroma unless a path was previously configured
            # (fix: removed two leftover debug print() calls)
            default_archival_storage_path = (
                config.archival_storage_path if config.archival_storage_path else os.path.join(config.config_path, "chroma")
            )
            archival_storage_path = questionary.text("Enter persistent storage location:", default=default_archival_storage_path).ask()

    return archival_storage_type, archival_storage_uri, archival_storage_path
262278

263279
# TODO: allow configuring embedding model
264280

@@ -275,7 +291,7 @@ def configure():
275291
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
276292
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
277293
default_preset, default_persona, default_human, default_agent = configure_cli(config)
278-
archival_storage_type, archival_storage_uri = configure_archival_storage(config)
294+
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
279295

280296
# check credentials
281297
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
@@ -322,6 +338,7 @@ def configure():
322338
# storage
323339
archival_storage_type=archival_storage_type,
324340
archival_storage_uri=archival_storage_uri,
341+
archival_storage_path=archival_storage_path,
325342
)
326343
print(f"Saving config to {config.config_path}")
327344
config.save()

memgpt/cli/cli_load.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ def load_directory(
102102
reader = SimpleDirectoryReader(input_files=input_files)
103103

104104
# load docs
105-
print("loading data")
106105
docs = reader.load_data()
107-
print("done loading data")
108106
store_docs(name, docs)
109107

110108

memgpt/connectors/chroma.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import chromadb
2+
import json
3+
import re
4+
from typing import Optional, List, Iterator
5+
from memgpt.connectors.storage import StorageConnector, Passage
6+
from memgpt.utils import printd
7+
from memgpt.config import AgentConfig, MemGPTConfig
8+
9+
10+
def create_chroma_client():
    """Build a Chroma client from the saved MemGPT config.

    A configured ``archival_storage_path`` selects the local persistent
    client; otherwise ``archival_storage_uri`` is interpreted as
    ``{ip}:{port}`` and an HTTP client is returned.
    """
    config = MemGPTConfig.load()
    # persistent (on-disk) client takes priority when a path is configured
    if config.archival_storage_path:
        return chromadb.PersistentClient(config.archival_storage_path)
    # assume uri={ip}:{port}
    parts = config.archival_storage_uri.split(":")
    return chromadb.HttpClient(host=parts[0], port=parts[1])
21+
22+
23+
class ChromaStorageConnector(StorageConnector):
    """Archival storage backend implemented on top of a Chroma collection.

    One connector instance wraps one Chroma collection, named after either
    an agent (``memgpt_agent_<name>``) or a loaded data source
    (``memgpt_<name>``).
    """

    # WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.

    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
        # determine table name: exactly one of `name` / `agent_config` must be given
        if agent_config:
            assert name is None, f"Cannot specify both agent config and name {name}"
            self.table_name = self.generate_table_name_agent(agent_config)
        elif name:
            assert agent_config is None, f"Cannot specify both agent config and name {name}"
            self.table_name = self.generate_table_name(name)
        else:
            raise ValueError("Must specify either agent config or name")

        printd(f"Using table name {self.table_name}")

        # create client (HTTP or persistent, depending on the saved config)
        self.client = create_chroma_client()

        # get a collection or create if it doesn't exist already
        self.collection = self.client.get_or_create_collection(self.table_name)

    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
        """Yield all stored passages in pages of at most `page_size`."""
        offset = 0
        while True:
            # Chroma's get() returns a dict of parallel lists keyed by field name
            results = self.collection.get(offset=offset, limit=page_size, include=["embeddings", "documents"])
            documents = results["documents"]
            embeddings = results["embeddings"]

            # fix: the result dict is always truthy, so the old emptiness check
            # never fired; check the payload instead. Also, the dict is not a
            # sequence of record objects — the old `p.text` access would crash.
            if not documents:
                break

            yield [Passage(text=text, embedding=embedding) for text, embedding in zip(documents, embeddings)]

            # advance to the next page
            offset += page_size

    def get_all(self) -> List[Passage]:
        """Return every passage in the collection."""
        results = self.collection.get(include=["embeddings", "documents"])
        return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])]

    def get(self, id: str) -> Optional[Passage]:
        """Return the passage stored under `id`, or None if it does not exist."""
        # fix: request embeddings explicitly (not in Chroma's default include)
        results = self.collection.get(ids=[id], include=["embeddings", "documents"])
        # fix: honor the declared Optional[Passage] contract instead of returning a list
        if not results["documents"]:
            return None
        return Passage(text=results["documents"][0], embedding=results["embeddings"][0])

    def insert(self, passage: Passage):
        """Add a single passage to the collection.

        NOTE(review): IDs are derived from the current count, so they could
        collide if individual records were ever removed — confirm deletion
        only happens via delete() (drops the whole collection).
        """
        self.collection.add(documents=[passage.text], embeddings=[passage.embedding], ids=[str(self.collection.count())])

    def insert_many(self, passages: List[Passage], show_progress=True):
        """Add a batch of passages in one call (`show_progress` is currently unused)."""
        count = self.collection.count()
        ids = [str(count + i) for i in range(len(passages))]
        self.collection.add(
            documents=[passage.text for passage in passages], embeddings=[passage.embedding for passage in passages], ids=ids
        )

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Return the `top_k` nearest passages to `query_vec` (the raw `query` text is unused here)."""
        results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=["embeddings", "documents"])
        # get index [0] since a single query embedding is passed as a list
        return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"][0], results["embeddings"][0])]

    def delete(self):
        """Drop the underlying collection."""
        self.client.delete_collection(name=self.table_name)

    def save(self):
        # save to persistence file (nothing needs to be done: Chroma persists writes itself)
        printd("Saving chroma")

    @staticmethod
    def list_loaded_data():
        """List data-source collections, excluding per-agent collections."""
        client = create_chroma_client()
        collections = client.list_collections()
        collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")]
        return collections

    def sanitize_table_name(self, name: str) -> str:
        """Normalize `name` into a safe, lowercase identifier (max 63 chars)."""
        # Remove leading and trailing whitespace
        name = name.strip()

        # Replace spaces and invalid characters with underscores
        name = re.sub(r"\s+|\W+", "_", name)

        # Truncate to the maximum identifier length (e.g., 63 for PostgreSQL)
        max_length = 63
        if len(name) > max_length:
            name = name[:max_length].rstrip("_")

        # Convert to lowercase
        return name.lower()

    def generate_table_name_agent(self, agent_config: AgentConfig):
        """Collection name for an agent's archival memory."""
        return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"

    def generate_table_name(self, name: str):
        """Collection name for a loaded data source."""
        return f"memgpt_{self.sanitize_table_name(name)}"

    def size(self) -> int:
        """Number of passages stored in the collection."""
        return self.collection.count()

memgpt/connectors/db.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ def list_loaded_data():
157157
inspector = inspect(engine)
158158
tables = inspector.get_table_names()
159159
tables = [table for table in tables if table.startswith("memgpt_") and not table.startswith("memgpt_agent_")]
160-
tables = [table.replace("memgpt_", "") for table in tables]
160+
start_chars = len("memgpt_")
161+
tables = [table[start_chars:] for table in tables]
161162
return tables
162163

163164
def sanitize_table_name(self, name: str) -> str:
@@ -300,7 +301,8 @@ def list_loaded_data():
300301

301302
tables = db.table_names()
302303
tables = [table for table in tables if table.startswith("memgpt_")]
303-
tables = [table.replace("memgpt_", "") for table in tables]
304+
start_chars = len("memgpt_")
305+
tables = [table[start_chars:] for table in tables]
304306
return tables
305307

306308
def sanitize_table_name(self, name: str) -> str:

memgpt/connectors/storage.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from memgpt.config import AgentConfig, MemGPTConfig
1818

1919

20+
from memgpt.config import AgentConfig, MemGPTConfig
21+
22+
2023
class Passage:
2124
"""A passage is a single unit of memory, and a standard format accross all storage backends.
2225
@@ -47,12 +50,14 @@ def get_storage_connector(name: Optional[str] = None, agent_config: Optional[Age
4750
from memgpt.connectors.db import PostgresStorageConnector
4851

4952
return PostgresStorageConnector(name=name, agent_config=agent_config)
53+
elif storage_type == "chroma":
54+
from memgpt.connectors.chroma import ChromaStorageConnector
5055

56+
return ChromaStorageConnector(name=name, agent_config=agent_config)
5157
elif storage_type == "lancedb":
5258
from memgpt.connectors.db import LanceDBConnector
5359

5460
return LanceDBConnector(name=name, agent_config=agent_config)
55-
5661
else:
5762
raise NotImplementedError(f"Storage type {storage_type} not implemented")
5863

@@ -67,7 +72,10 @@ def list_loaded_data():
6772
from memgpt.connectors.db import PostgresStorageConnector
6873

6974
return PostgresStorageConnector.list_loaded_data()
75+
elif storage_type == "chroma":
76+
from memgpt.connectors.chroma import ChromaStorageConnector
7077

78+
return ChromaStorageConnector.list_loaded_data()
7179
elif storage_type == "lancedb":
7280
from memgpt.connectors.db import LanceDBConnector
7381

memgpt/memory.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
from llama_index.node_parser import SimpleNodeParser
1212
from llama_index.node_parser import SimpleNodeParser
1313

14-
from memgpt.embeddings import embedding_model
15-
from memgpt.config import MemGPTConfig
16-
1714

1815
class CoreMemory(object):
1916
"""Held in-context inside the system message

poetry.lock

Lines changed: 13 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
4242
websockets = "^12.0"
4343
docstring-parser = "^0.15"
4444
lancedb = {version = "^0.3.3", optional = true}
45+
chroma = {version = "^0.2.0", optional = true} # NOTE(review): the code does `import chromadb`; the PyPI package "chroma" is a different project — this dependency (and the extra below) should likely be `chromadb`
4546
httpx = "^0.25.2"
4647
numpy = "^1.26.2"
4748
demjson3 = "^3.0.6"
@@ -54,6 +55,7 @@ pyyaml = "^6.0.1"
5455
local = ["torch", "huggingface-hub", "transformers"]
5556
lancedb = ["lancedb"]
5657
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
58+
chroma = ["chroma"] # NOTE(review): should become ["chromadb"] if the dependency above is renamed to match the actual import
5759
dev = ["pytest", "black", "pre-commit", "datasets"]
5860

5961
[build-system]

0 commit comments

Comments
 (0)