diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 55cbac5e1f..7f7d8875f2 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -96,7 +96,12 @@ def load_directory( user_id = uuid.UUID(config.anon_clientid) ms = MetadataStore(config) - source = Source(name=name, user_id=user_id) + source = Source( + name=name, + user_id=user_id, + embedding_model=config.default_embedding_config.embedding_model, + embedding_dim=config.default_embedding_config.embedding_dim, + ) ms.create_source(source) passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) # TODO: also get document store @@ -209,7 +214,12 @@ def load_vector_database( user_id = uuid.UUID(config.anon_clientid) ms = MetadataStore(config) - source = Source(name=name, user_id=user_id) + source = Source( + name=name, + user_id=user_id, + embedding_model=config.default_embedding_config.embedding_model, + embedding_dim=config.default_embedding_config.embedding_dim, + ) ms.create_source(source) passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) # TODO: also get document store diff --git a/memgpt/data_sources/connectors.py b/memgpt/data_sources/connectors.py index eaef6cc11a..7d2b544679 100644 --- a/memgpt/data_sources/connectors.py +++ b/memgpt/data_sources/connectors.py @@ -24,6 +24,12 @@ def load_data( document_store: Optional[StorageConnector] = None, ): """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id.""" + assert ( + source.embedding_model == embedding_config.embedding_model + ), f"Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}" + assert ( + source.embedding_dim == embedding_config.embedding_dim + ), f"Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}." # embedding model embed_model = embedding_model(embedding_config) @@ -55,8 +61,8 @@ def load_data( metadata_=passage_metadata, user_id=source.user_id, data_source=source.name, - embedding_dim=embedding_config.embedding_dim, - embedding_model=embedding_config.embedding_model, + embedding_dim=source.embedding_dim, + embedding_model=source.embedding_model, embedding=embedding, ) diff --git a/memgpt/main.py b/memgpt/main.py index 72c0625c6b..010af6af67 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -131,14 +131,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, ): valid_options.append(source.name) else: + # print warning about invalid sources + typer.secho( + f"Source {source.name} exists but has embedding dimentions {source.embedding_dim} from model {source.embedding_model}, while the agent uses embedding dimentions {memgpt_agent.agent_state.embedding_config.embedding_dim} and model {memgpt_agent.agent_state.embedding_config.embedding_model}", + fg=typer.colors.YELLOW, + ) invalid_options.append(source.name) - # print warning about invalid sources - typer.secho( - f"Warning: the following sources are not compatible with this agent's embedding model and dimension: {invalid_options}", - fg=typer.colors.YELLOW, - ) - # prompt user for data source selection data_source = questionary.select("Select data source", choices=valid_options).ask() diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 28f964ce10..99f4c98da4 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,4 +1,5 @@ import json +from datetime import datetime import logging import uuid from abc import abstractmethod @@ -1023,7 +1024,12 @@ def create_api_key_for_user(self, user_id: uuid.UUID) -> Token: def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields """Create a new data source""" - source = Source(name=name, user_id=user_id) + source = Source( + name=name, + user_id=user_id, + embedding_model=self.config.default_embedding_config.embedding_model, + embedding_dim=self.config.default_embedding_config.embedding_dim, + ) self.ms.create_source(source) return source