|
| 1 | +import chromadb |
| 2 | +import json |
| 3 | +import re |
| 4 | +from typing import Optional, List, Iterator |
| 5 | +from memgpt.connectors.storage import StorageConnector, Passage |
| 6 | +from memgpt.utils import printd |
| 7 | +from memgpt.config import AgentConfig, MemGPTConfig |
| 8 | + |
| 9 | + |
| 10 | +def create_chroma_client(): |
| 11 | + config = MemGPTConfig.load() |
| 12 | + # create chroma client |
| 13 | + if config.archival_storage_path: |
| 14 | + client = chromadb.PersistentClient(config.archival_storage_path) |
| 15 | + else: |
| 16 | + # assume uri={ip}:{port} |
| 17 | + ip = config.archival_storage_uri.split(":")[0] |
| 18 | + port = config.archival_storage_uri.split(":")[1] |
| 19 | + client = chromadb.HttpClient(host=ip, port=port) |
| 20 | + return client |
| 21 | + |
| 22 | + |
| 23 | +class ChromaStorageConnector(StorageConnector): |
| 24 | + """Storage via Chroma""" |
| 25 | + |
| 26 | + # WARNING: This is not thread safe. Do NOT do concurrent access to the same collection. |
| 27 | + |
| 28 | + def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): |
| 29 | + # determine table name |
| 30 | + if agent_config: |
| 31 | + assert name is None, f"Cannot specify both agent config and name {name}" |
| 32 | + self.table_name = self.generate_table_name_agent(agent_config) |
| 33 | + elif name: |
| 34 | + assert agent_config is None, f"Cannot specify both agent config and name {name}" |
| 35 | + self.table_name = self.generate_table_name(name) |
| 36 | + else: |
| 37 | + raise ValueError("Must specify either agent config or name") |
| 38 | + |
| 39 | + printd(f"Using table name {self.table_name}") |
| 40 | + |
| 41 | + # create client |
| 42 | + self.client = create_chroma_client() |
| 43 | + |
| 44 | + # get a collection or create if it doesn't exist already |
| 45 | + self.collection = self.client.get_or_create_collection(self.table_name) |
| 46 | + |
| 47 | + def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]: |
| 48 | + offset = 0 |
| 49 | + while True: |
| 50 | + # Retrieve a chunk of records with the given page_size |
| 51 | + db_passages_chunk = self.collection.get(offset=offset, limit=page_size, include=["embeddings", "documents"]) |
| 52 | + |
| 53 | + # If the chunk is empty, we've retrieved all records |
| 54 | + if not db_passages_chunk: |
| 55 | + break |
| 56 | + |
| 57 | + # Yield a list of Passage objects converted from the chunk |
| 58 | + yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk] |
| 59 | + |
| 60 | + # Increment the offset to get the next chunk in the next iteration |
| 61 | + offset += page_size |
| 62 | + |
| 63 | + def get_all(self) -> List[Passage]: |
| 64 | + results = self.collection.get(include=["embeddings", "documents"]) |
| 65 | + return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])] |
| 66 | + |
| 67 | + def get(self, id: str) -> Optional[Passage]: |
| 68 | + results = self.collection.get(ids=[id]) |
| 69 | + return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])] |
| 70 | + |
| 71 | + def insert(self, passage: Passage): |
| 72 | + self.collection.add(documents=[passage.text], embeddings=[passage.embedding], ids=[str(self.collection.count())]) |
| 73 | + |
| 74 | + def insert_many(self, passages: List[Passage], show_progress=True): |
| 75 | + count = self.collection.count() |
| 76 | + ids = [str(count + i) for i in range(len(passages))] |
| 77 | + self.collection.add( |
| 78 | + documents=[passage.text for passage in passages], embeddings=[passage.embedding for passage in passages], ids=ids |
| 79 | + ) |
| 80 | + |
| 81 | + def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]: |
| 82 | + results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=["embeddings", "documents"]) |
| 83 | + # get index [0] since query is passed as list |
| 84 | + return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"][0], results["embeddings"][0])] |
| 85 | + |
| 86 | + def delete(self): |
| 87 | + self.client.delete_collection(name=self.table_name) |
| 88 | + |
| 89 | + def save(self): |
| 90 | + # save to persistence file (nothing needs to be done) |
| 91 | + printd("Saving chroma") |
| 92 | + pass |
| 93 | + |
| 94 | + @staticmethod |
| 95 | + def list_loaded_data(): |
| 96 | + client = create_chroma_client() |
| 97 | + collections = client.list_collections() |
| 98 | + collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")] |
| 99 | + return collections |
| 100 | + |
| 101 | + def sanitize_table_name(self, name: str) -> str: |
| 102 | + # Remove leading and trailing whitespace |
| 103 | + name = name.strip() |
| 104 | + |
| 105 | + # Replace spaces and invalid characters with underscores |
| 106 | + name = re.sub(r"\s+|\W+", "_", name) |
| 107 | + |
| 108 | + # Truncate to the maximum identifier length (e.g., 63 for PostgreSQL) |
| 109 | + max_length = 63 |
| 110 | + if len(name) > max_length: |
| 111 | + name = name[:max_length].rstrip("_") |
| 112 | + |
| 113 | + # Convert to lowercase |
| 114 | + name = name.lower() |
| 115 | + |
| 116 | + return name |
| 117 | + |
| 118 | + def generate_table_name_agent(self, agent_config: AgentConfig): |
| 119 | + return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}" |
| 120 | + |
| 121 | + def generate_table_name(self, name: str): |
| 122 | + return f"memgpt_{self.sanitize_table_name(name)}" |
| 123 | + |
| 124 | + def size(self) -> int: |
| 125 | + return self.collection.count() |
0 commit comments