Skip to content

Commit 894cfcb

Browse files
authored
feat: tests and bug fixes AgentState.state (#1058)
1 parent 0712141 commit 894cfcb

File tree

5 files changed

+151
-91
lines changed

5 files changed

+151
-91
lines changed

memgpt/agent.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99

1010
from memgpt.metadata import MetadataStore
1111
from memgpt.agent_store.storage import StorageConnector, TableType
12-
from memgpt.data_types import AgentState, Message, EmbeddingConfig, Passage
12+
from memgpt.data_types import AgentState, Message, LLMConfig, EmbeddingConfig, Passage, Preset
1313
from memgpt.models import chat_completion_response
1414
from memgpt.interface import AgentInterface
1515
from memgpt.persistence_manager import LocalStateManager
1616
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
1717
from memgpt.memory import CoreMemory as InContextMemory, summarize_messages
1818
from memgpt.llm_api_tools import create, is_context_overflow_error
1919
from memgpt.utils import (
20+
create_random_username,
2021
get_tool_call_id,
2122
get_local_time,
2223
parse_json,
@@ -167,37 +168,81 @@ def initialize_message_sequence(
167168
class Agent(object):
168169
def __init__(
169170
self,
170-
agent_state: AgentState,
171171
interface: AgentInterface,
172+
# agents can be created from providing agent_state
173+
agent_state: Optional[AgentState] = None,
174+
# or from providing a preset (requires preset + extra fields)
175+
preset: Optional[Preset] = None,
176+
created_by: Optional[uuid.UUID] = None,
177+
name: Optional[str] = None,
178+
llm_config: Optional[LLMConfig] = None,
179+
embedding_config: Optional[EmbeddingConfig] = None,
172180
# extras
173181
messages_total: Optional[int] = None, # TODO remove?
174182
first_message_verify_mono: bool = True, # TODO move to config?
175183
):
184+
185+
# An agent can be created from a Preset object
186+
if preset is not None:
187+
assert agent_state is None, "Can create an agent from a Preset or AgentState (but both were provided)"
188+
assert created_by is not None, "Must provide created_by field when creating an Agent from a Preset"
189+
assert llm_config is not None, "Must provide llm_config field when creating an Agent from a Preset"
190+
assert embedding_config is not None, "Must provide embedding_config field when creating an Agent from a Preset"
191+
192+
# if agent_state is also provided, override any preset values
193+
init_agent_state = AgentState(
194+
name=name if name else create_random_username(),
195+
user_id=created_by,
196+
persona=preset.persona,
197+
human=preset.human,
198+
llm_config=llm_config,
199+
embedding_config=embedding_config,
200+
preset=preset.name, # TODO link via preset.id instead of name?
201+
state={
202+
"persona": preset.persona,
203+
"human": preset.human,
204+
"system": preset.system,
205+
"functions": preset.functions_schema,
206+
"messages": None,
207+
},
208+
)
209+
210+
# An agent can also be created directly from AgentState
211+
elif agent_state is not None:
212+
assert preset is None, "Can create an agent from a Preset or AgentState (but both were provided)"
213+
assert agent_state.state is not None and agent_state.state != {}, "AgentState.state cannot be empty"
214+
215+
# Assume the agent_state passed in is formatted correctly
216+
init_agent_state = agent_state
217+
218+
else:
219+
raise ValueError("Both Preset and AgentState were null (must provide one or the other)")
220+
176221
# Hold a copy of the state that was used to init the agent
177-
self.agent_state = agent_state
222+
self.agent_state = init_agent_state
178223

179224
# gpt-4, gpt-3.5-turbo, ...
180-
self.model = agent_state.llm_config.model
225+
self.model = self.agent_state.llm_config.model
181226

182227
# Store the system instructions (used to rebuild memory)
183-
if "system" not in agent_state.state:
228+
if "system" not in self.agent_state.state:
184229
raise ValueError(f"'system' not found in provided AgentState")
185-
self.system = agent_state.state["system"]
230+
self.system = self.agent_state.state["system"]
186231

187-
if "functions" not in agent_state.state:
232+
if "functions" not in self.agent_state.state:
188233
raise ValueError(f"'functions' not found in provided AgentState")
189234
# Store the functions schemas (this is passed as an argument to ChatCompletion)
190-
self.functions = agent_state.state["functions"] # these are the schema
235+
self.functions = self.agent_state.state["functions"] # these are the schema
191236
# Link the actual python functions corresponding to the schemas
192237
self.functions_python = {k: v["python_function"] for k, v in link_functions(function_schemas=self.functions).items()}
193238
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
194239

195240
# Initialize the memory object
196-
if "persona" not in agent_state.state:
241+
if "persona" not in self.agent_state.state:
197242
raise ValueError(f"'persona' not found in provided AgentState")
198-
if "human" not in agent_state.state:
243+
if "human" not in self.agent_state.state:
199244
raise ValueError(f"'human' not found in provided AgentState")
200-
self.memory = initialize_memory(ai_notes=agent_state.state["persona"], human_notes=agent_state.state["human"])
245+
self.memory = initialize_memory(ai_notes=self.agent_state.state["persona"], human_notes=self.agent_state.state["human"])
201246

202247
# Interface must implement:
203248
# - internal_monologue
@@ -210,7 +255,7 @@ def __init__(
210255

211256
# Create the persistence manager object based on the AgentState info
212257
# TODO
213-
self.persistence_manager = LocalStateManager(agent_state=agent_state)
258+
self.persistence_manager = LocalStateManager(agent_state=self.agent_state)
214259

215260
# State needed for heartbeat pausing
216261
self.pause_heartbeats_start = None
@@ -226,17 +271,17 @@ def __init__(
226271
self._messages: List[Message] = []
227272

228273
# Once the memory object is initialized, use it to "bake" the system message
229-
if "messages" in agent_state.state and agent_state.state["messages"] is not None:
274+
if "messages" in self.agent_state.state and self.agent_state.state["messages"] is not None:
230275
# print(f"Agent.__init__ :: loading, state={agent_state.state['messages']}")
231-
if not isinstance(agent_state.state["messages"], list):
232-
raise ValueError(f"'messages' in AgentState was bad type: {type(agent_state.state['messages'])}")
233-
assert all([isinstance(msg, str) for msg in agent_state.state["messages"]])
276+
if not isinstance(self.agent_state.state["messages"], list):
277+
raise ValueError(f"'messages' in AgentState was bad type: {type(self.agent_state.state['messages'])}")
278+
assert all([isinstance(msg, str) for msg in self.agent_state.state["messages"]])
234279

235280
# Convert to IDs, and pull from the database
236281
raw_messages = [
237-
self.persistence_manager.recall_memory.storage.get(id=uuid.UUID(msg_id)) for msg_id in agent_state.state["messages"]
282+
self.persistence_manager.recall_memory.storage.get(id=uuid.UUID(msg_id)) for msg_id in self.agent_state.state["messages"]
238283
]
239-
assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, agent_state.state["messages"])
284+
assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"])
240285
self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None])
241286

242287
else:

memgpt/cli/cli.py

Lines changed: 26 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def server(
412412

413413
def run(
414414
persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None,
415-
agent: Annotated[Optional[str], typer.Option(help="Specify agent save file")] = None,
415+
agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None,
416416
human: Annotated[Optional[str], typer.Option(help="Specify human")] = None,
417417
preset: Annotated[Optional[str], typer.Option(help="Specify preset")] = None,
418418
# model flags
@@ -605,18 +605,13 @@ def run(
605605
ms.update_agent(agent_state)
606606

607607
# create agent
608-
memgpt_agent = Agent(agent_state, interface=interface())
608+
memgpt_agent = Agent(agent_state=agent_state, interface=interface())
609609

610610
else: # create new agent
611611
# create new agent config: override defaults with args if provided
612612
typer.secho("\n🧬 Creating new agent...", fg=typer.colors.WHITE)
613613

614-
if agent is None:
615-
# determine agent name
616-
# agent_count = len(ms.list_agents(user_id=user.id))
617-
# agent = f"agent_{agent_count}"
618-
agent = utils.create_random_username()
619-
614+
agent_name = agent if agent else utils.create_random_username()
620615
llm_config = config.default_llm_config
621616
embedding_config = config.default_embedding_config # TODO allow overriding embedding params via CLI run
622617

@@ -649,68 +644,43 @@ def run(
649644
)
650645
llm_config.model_endpoint_type = model_endpoint_type
651646

652-
agent_state = AgentState(
653-
name=agent,
654-
user_id=user.id,
655-
persona=persona if persona else config.persona,
656-
human=human if human else config.human,
657-
preset=preset if preset else config.preset,
658-
llm_config=llm_config,
659-
embedding_config=embedding_config,
660-
)
661-
# ms.create_agent(agent_state)
662-
663-
typer.secho(f"-> 🤖 Using persona profile '{agent_state.persona}'", fg=typer.colors.WHITE)
664-
typer.secho(f"-> 🧑 Using human profile '{agent_state.human}'", fg=typer.colors.WHITE)
665-
666-
# Supress llama-index noise
667-
# TODO(swooders) add persistence manager code? or comment out?
668-
# with suppress_stdout():
669-
# TODO: allow configrable state manager (only local is supported right now)
670-
# persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill
671-
672647
# create agent
673648
try:
674-
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user.id)
675-
if preset is None:
649+
preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id)
650+
if preset_obj is None:
676651
# create preset records in metadata store
677652
from memgpt.presets.presets import add_default_presets
678653

679654
add_default_presets(user.id, ms)
680655
# try again
681-
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user.id)
682-
assert preset is not None, "Couldn't find presets in database, please run `memgpt configure`"
656+
preset_obj = ms.get_preset(preset_name=preset if preset else config.preset, user_id=user.id)
657+
if preset_obj is None:
658+
typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED)
659+
sys.exit(1)
683660

684-
memgpt_agent = presets.create_agent_from_preset(
685-
agent_state=agent_state,
686-
preset=preset,
661+
# Overwrite fields in the preset if they were specified
662+
preset_obj.human = human if human else config.human
663+
preset_obj.persona = persona if persona else config.persona
664+
665+
typer.secho(f"-> 🤖 Using persona profile '{preset_obj.persona}'", fg=typer.colors.WHITE)
666+
typer.secho(f"-> 🧑 Using human profile '{preset_obj.human}'", fg=typer.colors.WHITE)
667+
668+
memgpt_agent = Agent(
687669
interface=interface(),
670+
name=agent_name,
671+
created_by=user.id,
672+
preset=preset_obj,
673+
llm_config=llm_config,
674+
embedding_config=embedding_config,
675+
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
676+
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
688677
)
689678
save_agent(agent=memgpt_agent, ms=ms)
679+
690680
except ValueError as e:
691-
# TODO(swooders) what's the equivalent cleanup code for the new DB refactor?
692681
typer.secho(f"Failed to create agent from provided information:\n{e}", fg=typer.colors.RED)
693-
# # Delete the directory of the failed agent
694-
# try:
695-
# # Path to the specific file
696-
# agent_config_file = agent_config.agent_config_path
697-
698-
# # Check if the file exists
699-
# if os.path.isfile(agent_config_file):
700-
# # Delete the file
701-
# os.remove(agent_config_file)
702-
703-
# # Now, delete the directory along with any remaining files in it
704-
# agent_save_dir = os.path.join(MEMGPT_DIR, "agents", agent_config.name)
705-
# shutil.rmtree(agent_save_dir)
706-
# except:
707-
# typer.secho(f"Failed to delete agent directory during cleanup:\n{e}", fg=typer.colors.RED)
708682
sys.exit(1)
709-
typer.secho(f"🎉 Created new agent '{agent_state.name}' (id={agent_state.id})", fg=typer.colors.GREEN)
710-
711-
# pretty print agent config
712-
# printd(json.dumps(vars(agent_config), indent=4, sort_keys=True, ensure_ascii=JSON_ENSURE_ASCII))
713-
# printd(json.dumps(agent_init_state), indent=4, sort_keys=True, ensure_ascii=JSON_ENSURE_ASCII))
683+
typer.secho(f"🎉 Created new agent '{memgpt_agent.agent_state.name}' (id={memgpt_agent.agent_state.id})", fg=typer.colors.GREEN)
714684

715685
# start event loop
716686
from memgpt.main import run_agent_loop

memgpt/presets/presets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def create_agent_from_preset(
6868
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
6969
):
7070
"""Initialize a new agent from a preset (combination of system + function)"""
71+
raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead")
7172

7273
# Input validation
7374
if agent_state.persona is None:

memgpt/server/server.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -598,34 +598,43 @@ def create_agent(
598598
if not user:
599599
raise ValueError(f"cannot find user with associated client id: {user_id}")
600600

601-
agent_state = AgentState(
602-
user_id=user.id,
603-
name=name if name else utils.create_random_username(),
604-
preset=preset if preset else self.config.preset,
605-
# TODO we need to allow passing raw persona/human text via the server request
606-
persona=persona if persona else self.config.persona,
607-
human=human if human else self.config.human,
608-
llm_config=llm_config if llm_config else self.server_llm_config,
609-
embedding_config=embedding_config if embedding_config else self.server_embedding_config,
610-
)
611601
# NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation
612602
# TODO: fix this db dependency and remove
613603
# self.ms.create_agent(agent_state)
614604

615-
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
616605
try:
617-
preset = self.ms.get_preset(preset_name=agent_state.preset, user_id=user_id)
618-
assert preset is not None, f"preset {agent_state.preset} does not exist"
619-
620-
agent = presets.create_agent_from_preset(agent_state=agent_state, preset=preset, interface=interface)
606+
preset_obj = self.ms.get_preset(preset_name=preset if preset else self.config.preset, user_id=user_id)
607+
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
608+
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
609+
610+
# Overwrite fields in the preset if they were specified
611+
preset_obj.human = human if human else self.config.human
612+
preset_obj.persona = persona if persona else self.config.persona
613+
614+
llm_config = llm_config if llm_config else self.server_llm_config
615+
embedding_config = embedding_config if embedding_config else self.server_embedding_config
616+
617+
agent = Agent(
618+
interface=interface,
619+
preset=preset_obj,
620+
name=name,
621+
created_by=user.id,
622+
llm_config=llm_config,
623+
embedding_config=embedding_config,
624+
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
625+
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
626+
)
621627
save_agent(agent=agent, ms=self.ms)
622628

623629
# FIXME: this is a hacky way to get the system prompts injected into agent into the DB
624630
# self.ms.update_agent(agent.agent_state)
625631
except Exception as e:
626632
logger.exception(e)
627-
self.ms.delete_agent(agent_id=agent_state.id)
628-
raise
633+
try:
634+
self.ms.delete_agent(agent_id=agent.agent_state.id)
635+
except Exception as delete_e:
636+
logger.exception(f"Failed to delete_agent:\n{delete_e}")
637+
raise e
629638

630639
save_agent(agent, self.ms)
631640

tests/test_metadata_store.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
33
import pytest
44

5+
from memgpt.agent import Agent, save_agent
56
from memgpt.metadata import MetadataStore
67
from memgpt.config import MemGPTConfig
78
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
9+
from memgpt.utils import get_human_text, get_persona_text
810

911

1012
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
@@ -50,6 +52,39 @@ def test_storage(storage_connector):
5052
len(ms.list_sources(user_id=user_1.id)) == 1
5153
len(ms.list_sources(user_id=user_2.id)) == 0
5254

55+
# test agent_state saving
56+
agent_state = ms.get_agent(agent_1.id).state
57+
assert agent_state is None, agent_state # when created via create_agent, it should be empty
58+
59+
from memgpt.presets.presets import add_default_presets
60+
61+
add_default_presets(user_1.id, ms)
62+
preset_obj = ms.get_preset(preset_name=DEFAULT_PRESET, user_id=user_1.id)
63+
from memgpt.interface import CLIInterface as interface # for printing to terminal
64+
65+
# Overwrite fields in the preset if they were specified
66+
preset_obj.human = get_human_text(DEFAULT_HUMAN)
67+
preset_obj.persona = get_persona_text(DEFAULT_PERSONA)
68+
69+
# Create the agent
70+
agent = Agent(
71+
interface=interface(),
72+
created_by=user_1.id,
73+
name="agent_test_agent_state",
74+
preset=preset_obj,
75+
llm_config=config.default_llm_config,
76+
embedding_config=config.default_embedding_config,
77+
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
78+
first_message_verify_mono=(
79+
True if (config.default_llm_config.model is not None and "gpt-4" in config.default_llm_config.model) else False
80+
),
81+
)
82+
agent_with_agent_state = agent.agent_state
83+
save_agent(agent=agent, ms=ms)
84+
85+
agent_state = ms.get_agent(agent_with_agent_state.id).state
86+
assert agent_state is not None, agent_state # when created via create_agent_from_preset, it should be non-empty
87+
5388
# test: updating
5489

5590
# test: update JSON-stored LLMConfig class

0 commit comments

Comments
 (0)