Skip to content

Commit

Permalink
fix: bug with storing embedding info in metadata store (#1101)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Mar 6, 2024
1 parent 2885002 commit 9280568
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
14 changes: 12 additions & 2 deletions memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ def load_directory(
user_id = uuid.UUID(config.anon_clientid)

ms = MetadataStore(config)
source = Source(name=name, user_id=user_id)
source = Source(
name=name,
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
Expand Down Expand Up @@ -209,7 +214,12 @@ def load_vector_database(
user_id = uuid.UUID(config.anon_clientid)

ms = MetadataStore(config)
source = Source(name=name, user_id=user_id)
source = Source(
name=name,
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
Expand Down
10 changes: 8 additions & 2 deletions memgpt/data_sources/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def load_data(
document_store: Optional[StorageConnector] = None,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
assert (
source.embedding_model == embedding_config.embedding_model
), f"Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}"
assert (
source.embedding_dim == embedding_config.embedding_dim
), f"Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}."

# embedding model
embed_model = embedding_model(embedding_config)
Expand Down Expand Up @@ -55,8 +61,8 @@ def load_data(
metadata_=passage_metadata,
user_id=source.user_id,
data_source=source.name,
embedding_dim=embedding_config.embedding_dim,
embedding_model=embedding_config.embedding_model,
embedding_dim=source.embedding_dim,
embedding_model=source.embedding_model,
embedding=embedding,
)

Expand Down
11 changes: 5 additions & 6 deletions memgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
):
valid_options.append(source.name)
else:
# print warning about invalid sources
typer.secho(
f"Source {source.name} exists but has embedding dimentions {source.embedding_dim} from model {source.embedding_model}, while the agent uses embedding dimentions {memgpt_agent.agent_state.embedding_config.embedding_dim} and model {memgpt_agent.agent_state.embedding_config.embedding_model}",
fg=typer.colors.YELLOW,
)
invalid_options.append(source.name)

# print warning about invalid sources
typer.secho(
f"Warning: the following sources are not compatible with this agent's embedding model and dimension: {invalid_options}",
fg=typer.colors.YELLOW,
)

# prompt user for data source selection
data_source = questionary.select("Select data source", choices=valid_options).ask()

Expand Down
8 changes: 7 additions & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from datetime import datetime
import logging
import uuid
from abc import abstractmethod
Expand Down Expand Up @@ -1023,7 +1024,12 @@ def create_api_key_for_user(self, user_id: uuid.UUID) -> Token:

def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields
"""Create a new data source"""
source = Source(name=name, user_id=user_id)
source = Source(
name=name,
user_id=user_id,
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_dim=self.config.default_embedding_config.embedding_dim,
)
self.ms.create_source(source)
return source

Expand Down

0 comments on commit 9280568

Please sign in to comment.