99
1010from memgpt .metadata import MetadataStore
1111from memgpt .agent_store .storage import StorageConnector , TableType
12- from memgpt .data_types import AgentState , Message , EmbeddingConfig , Passage
12+ from memgpt .data_types import AgentState , Message , LLMConfig , EmbeddingConfig , Passage , Preset
1313from memgpt .models import chat_completion_response
1414from memgpt .interface import AgentInterface
1515from memgpt .persistence_manager import LocalStateManager
1616from memgpt .system import get_login_event , package_function_response , package_summarize_message , get_initial_boot_messages
1717from memgpt .memory import CoreMemory as InContextMemory , summarize_messages
1818from memgpt .llm_api_tools import create , is_context_overflow_error
1919from memgpt .utils import (
20+ create_random_username ,
2021 get_tool_call_id ,
2122 get_local_time ,
2223 parse_json ,
@@ -167,37 +168,81 @@ def initialize_message_sequence(
167168class Agent (object ):
168169 def __init__ (
169170 self ,
170- agent_state : AgentState ,
171171 interface : AgentInterface ,
172+ # agents can be created from providing agent_state
173+ agent_state : Optional [AgentState ] = None ,
174+ # or from providing a preset (requires preset + extra fields)
175+ preset : Optional [Preset ] = None ,
176+ created_by : Optional [uuid .UUID ] = None ,
177+ name : Optional [str ] = None ,
178+ llm_config : Optional [LLMConfig ] = None ,
179+ embedding_config : Optional [EmbeddingConfig ] = None ,
172180 # extras
173181 messages_total : Optional [int ] = None , # TODO remove?
174182 first_message_verify_mono : bool = True , # TODO move to config?
175183 ):
184+
185+ # An agent can be created from a Preset object
186+ if preset is not None :
187+ assert agent_state is None , "Can create an agent from a Preset or AgentState (but both were provided)"
188+ assert created_by is not None , "Must provide created_by field when creating an Agent from a Preset"
189+ assert llm_config is not None , "Must provide llm_config field when creating an Agent from a Preset"
190+ assert embedding_config is not None , "Must provide embedding_config field when creating an Agent from a Preset"
191+
192+ # if agent_state is also provided, override any preset values
193+ init_agent_state = AgentState (
194+ name = name if name else create_random_username (),
195+ user_id = created_by ,
196+ persona = preset .persona ,
197+ human = preset .human ,
198+ llm_config = llm_config ,
199+ embedding_config = embedding_config ,
200+ preset = preset .name , # TODO link via preset.id instead of name?
201+ state = {
202+ "persona" : preset .persona ,
203+ "human" : preset .human ,
204+ "system" : preset .system ,
205+ "functions" : preset .functions_schema ,
206+ "messages" : None ,
207+ },
208+ )
209+
210+ # An agent can also be created directly from AgentState
211+ elif agent_state is not None :
212+ assert preset is None , "Can create an agent from a Preset or AgentState (but both were provided)"
213+ assert agent_state .state is not None and agent_state .state != {}, "AgentState.state cannot be empty"
214+
215+ # Assume the agent_state passed in is formatted correctly
216+ init_agent_state = agent_state
217+
218+ else :
219+ raise ValueError ("Both Preset and AgentState were null (must provide one or the other)" )
220+
176221 # Hold a copy of the state that was used to init the agent
177- self .agent_state = agent_state
222+ self .agent_state = init_agent_state
178223
179224 # gpt-4, gpt-3.5-turbo, ...
180- self .model = agent_state .llm_config .model
225+ self .model = self . agent_state .llm_config .model
181226
182227 # Store the system instructions (used to rebuild memory)
183- if "system" not in agent_state .state :
228+ if "system" not in self . agent_state .state :
184229 raise ValueError (f"'system' not found in provided AgentState" )
185- self .system = agent_state .state ["system" ]
230+ self .system = self . agent_state .state ["system" ]
186231
187- if "functions" not in agent_state .state :
232+ if "functions" not in self . agent_state .state :
188233 raise ValueError (f"'functions' not found in provided AgentState" )
189234 # Store the functions schemas (this is passed as an argument to ChatCompletion)
190- self .functions = agent_state .state ["functions" ] # these are the schema
235+ self .functions = self . agent_state .state ["functions" ] # these are the schema
191236 # Link the actual python functions corresponding to the schemas
192237 self .functions_python = {k : v ["python_function" ] for k , v in link_functions (function_schemas = self .functions ).items ()}
193238 assert all ([callable (f ) for k , f in self .functions_python .items ()]), self .functions_python
194239
195240 # Initialize the memory object
196- if "persona" not in agent_state .state :
241+ if "persona" not in self . agent_state .state :
197242 raise ValueError (f"'persona' not found in provided AgentState" )
198- if "human" not in agent_state .state :
243+ if "human" not in self . agent_state .state :
199244 raise ValueError (f"'human' not found in provided AgentState" )
200- self .memory = initialize_memory (ai_notes = agent_state .state ["persona" ], human_notes = agent_state .state ["human" ])
245+ self .memory = initialize_memory (ai_notes = self . agent_state .state ["persona" ], human_notes = self . agent_state .state ["human" ])
201246
202247 # Interface must implement:
203248 # - internal_monologue
@@ -210,7 +255,7 @@ def __init__(
210255
211256 # Create the persistence manager object based on the AgentState info
212257 # TODO
213- self .persistence_manager = LocalStateManager (agent_state = agent_state )
258+ self .persistence_manager = LocalStateManager (agent_state = self . agent_state )
214259
215260 # State needed for heartbeat pausing
216261 self .pause_heartbeats_start = None
@@ -226,17 +271,17 @@ def __init__(
226271 self ._messages : List [Message ] = []
227272
228273 # Once the memory object is initialized, use it to "bake" the system message
229- if "messages" in agent_state .state and agent_state .state ["messages" ] is not None :
274+ if "messages" in self . agent_state .state and self . agent_state .state ["messages" ] is not None :
230275 # print(f"Agent.__init__ :: loading, state={agent_state.state['messages']}")
231- if not isinstance (agent_state .state ["messages" ], list ):
232- raise ValueError (f"'messages' in AgentState was bad type: { type (agent_state .state ['messages' ])} " )
233- assert all ([isinstance (msg , str ) for msg in agent_state .state ["messages" ]])
276+ if not isinstance (self . agent_state .state ["messages" ], list ):
277+ raise ValueError (f"'messages' in AgentState was bad type: { type (self . agent_state .state ['messages' ])} " )
278+ assert all ([isinstance (msg , str ) for msg in self . agent_state .state ["messages" ]])
234279
235280 # Convert to IDs, and pull from the database
236281 raw_messages = [
237- self .persistence_manager .recall_memory .storage .get (id = uuid .UUID (msg_id )) for msg_id in agent_state .state ["messages" ]
282+ self .persistence_manager .recall_memory .storage .get (id = uuid .UUID (msg_id )) for msg_id in self . agent_state .state ["messages" ]
238283 ]
239- assert all ([isinstance (msg , Message ) for msg in raw_messages ]), (raw_messages , agent_state .state ["messages" ])
284+ assert all ([isinstance (msg , Message ) for msg in raw_messages ]), (raw_messages , self . agent_state .state ["messages" ])
240285 self ._messages .extend ([cast (Message , msg ) for msg in raw_messages if msg is not None ])
241286
242287 else :
0 commit comments