diff --git a/comps/agent/langchain/src/agent.py b/comps/agent/langchain/src/agent.py
index 0533826c5..7ec29835d 100644
--- a/comps/agent/langchain/src/agent.py
+++ b/comps/agent/langchain/src/agent.py
@@ -1,5 +1,6 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
+from comps.cores.proto.agents import AgentConfig
 
 from .utils import load_python_prompt
 
@@ -10,6 +11,15 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
     else:
         custom_prompt = None
 
+    # Collect per-agent settings into one validated, serializable config object.
+    agent_config = AgentConfig(
+        model=args.llm_engine,
+        with_memory=with_memory,
+        custom_prompt=custom_prompt,
+        tools=args.tools,
+        enable_session_persistence=False,
+    )
+
     if strategy == "react_langchain":
         from .strategy.react import ReActAgentwithLangchain
 
@@ -22,7 +32,7 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
         print("Initializing ReAct Agent with LLAMA")
         from .strategy.react import ReActAgentLlama
 
-        return ReActAgentLlama(args, with_memory, custom_prompt=custom_prompt)
+        return ReActAgentLlama(args, agent_config)
 
     elif strategy == "plan_execute":
         from .strategy.planexec import PlanExecuteAgentWithLangGraph
diff --git a/comps/agent/langchain/src/strategy/base_agent.py b/comps/agent/langchain/src/strategy/base_agent.py
index beb4fa9f8..29f9b7f82 100644
--- a/comps/agent/langchain/src/strategy/base_agent.py
+++ b/comps/agent/langchain/src/strategy/base_agent.py
@@ -8,7 +8,7 @@ class BaseAgent:
 class BaseAgent:
-    def __init__(self, args, local_vars=None, **kwargs) -> None:
+    def __init__(self, args, local_vars=None, agent_config=None, **kwargs) -> None:
         self.llm = setup_chat_model(args)
         self.tools_descriptions = get_tools_descriptions(args.tools)
         self.app = None
@@ -18,6 +18,27 @@ def __init__(self, args, local_vars=None, **kwargs) -> None:
         adapt_custom_prompt(local_vars, kwargs.get("custom_prompt"))
         print(self.tools_descriptions)
 
+        self.agent_config = agent_config
+        self.storage = None
+
+    async def init_persistence(self):
+        """Create the KV-backed session store; must be awaited once after construction."""
+        if self.agent_config is None or not self.agent_config.enable_session_persistence:
+            return
+        # Lazy imports: llama_stack is only required when persistence is enabled.
+        import uuid
+
+        from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence
+        from llama_stack.providers.utils.kvstore import KVStoreConfig, kvstore_impl
+
+        self.id = str(uuid.uuid4())
+        self.persistence_store = await kvstore_impl(KVStoreConfig())
+        await self.persistence_store.set(
+            key=f"agent:{self.id}",
+            value=self.agent_config.model_dump_json(),
+        )
+        self.storage = AgentPersistence(self.id, self.persistence_store)
+
     @property
     def is_vllm(self):
         return self.args.llm_engine == "vllm"
@@ -38,3 +59,9 @@ def execute(self, state: dict):
 
     def non_streaming_run(self, query, config):
         raise NotImplementedError
+
+    async def create_session(self, name: str) -> str:
+        """Create a persisted session; requires init_persistence() to have been awaited."""
+        if self.storage is None:
+            raise RuntimeError("Session persistence is not enabled for this agent")
+        return await self.storage.create_session(name)
diff --git a/comps/agent/langchain/src/strategy/react/planner.py b/comps/agent/langchain/src/strategy/react/planner.py
index f574b5f65..098c18a50 100644
--- a/comps/agent/langchain/src/strategy/react/planner.py
+++ b/comps/agent/langchain/src/strategy/react/planner.py
@@ -210,8 +210,8 @@ def __call__(self, state):
 
 class ReActAgentLlama(BaseAgent):
-    def __init__(self, args, with_memory=False, **kwargs):
-        super().__init__(args, local_vars=globals(), **kwargs)
+    def __init__(self, args, agent_config=None, **kwargs):
+        super().__init__(args, local_vars=globals(), agent_config=agent_config, **kwargs)
 
         agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args)
         tool_node = ToolNode(self.tools_descriptions)
@@ -265,7 +265,39 @@ def should_continue(self, state: AgentState):
         return "continue"
 
     def prepare_initial_state(self, query):
         return {"messages": [HumanMessage(content=query)]}
+
+    async def prepare_session_state(self, query, session_id):
+        """Replay the persisted turns of *session_id* and append the new user query.
+
+        NOTE(review): callers must route this state into the graph themselves;
+        stream_generator/non_streaming_run still use prepare_initial_state.
+        """
+        import uuid
+        from datetime import datetime
+
+        from langchain_core.messages import SystemMessage
+
+        if self.storage is None:
+            raise RuntimeError("Session persistence is not enabled for this agent")
+
+        session_info = await self.storage.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        turns = await self.storage.get_session_turns(session_id)
+
+        messages = []
+        if not turns and self.agent_config and self.agent_config.instructions:
+            messages.append(SystemMessage(content=self.agent_config.instructions))
+        for turn in turns:
+            messages.extend(self.turn_to_messages(turn))
+        messages.append(HumanMessage(content=query))
+
+        self.session_id = session_id
+        self.turn_id = str(uuid.uuid4())
+        self.turn_started_at = datetime.now()
+        return {"messages": messages}
 
     async def stream_generator(self, query, config):
         initial_state = self.prepare_initial_state(query)
@@ -277,6 +309,23 @@ async def stream_generator(self, query, config):
                 if v is not None:
                     yield f"{k}: {v}\n"
 
+            # NOTE(review): persists once per streamed event; verify this belongs
+            # after the stream completes, and that Turn accepts these payloads.
+            if self.storage is not None and getattr(self, "session_id", None) is not None:
+                from datetime import datetime
+
+                from llama_stack.apis.agents import Turn
+                turn = Turn(
+                    turn_id=self.turn_id,
+                    session_id=self.session_id,
+                    input_messages=initial_state["messages"],
+                    output_message=event,
+                    started_at=self.turn_started_at,
+                    completed_at=datetime.now(),
+                    steps=[],
+                )
+                await self.storage.add_turn_to_session(self.session_id, turn)
+
             yield f"data: {repr(event)}\n\n"
             yield "data: [DONE]\n\n"
         except Exception as e:
@@ -292,6 +341,23 @@ async def non_streaming_run(self, query, config):
             else:
                 message.pretty_print()
 
+            # NOTE(review): persists once per graph step; verify placement and
+            # that Turn accepts these payloads before enabling in production.
+            if self.storage is not None and getattr(self, "session_id", None) is not None:
+                from datetime import datetime
+
+                from llama_stack.apis.agents import Turn
+                turn = Turn(
+                    turn_id=self.turn_id,
+                    session_id=self.session_id,
+                    input_messages=initial_state["messages"],
+                    output_message=s["messages"][-1],
+                    started_at=self.turn_started_at,
+                    completed_at=datetime.now(),
+                    steps=[],
+                )
+                await self.storage.add_turn_to_session(self.session_id, turn)
+
         last_message = s["messages"][-1]
         print("******Response: ", last_message.content)
         return last_message.content
diff --git a/comps/cores/proto/agents/__init__.py b/comps/cores/proto/agents/__init__.py
new file mode 100644
index 000000000..1ceec47f8
--- /dev/null
+++ b/comps/cores/proto/agents/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .agents import AgentConfig
diff --git a/comps/cores/proto/agents/agents.py b/comps/cores/proto/agents/agents.py
new file mode 100644
index 000000000..c21e2af28
--- /dev/null
+++ b/comps/cores/proto/agents/agents.py
@@ -0,0 +1,19 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class AgentConfig(BaseModel):
+    """Validated, serializable configuration shared by all agent strategies."""
+
+    model: Optional[str] = None
+    # System-prompt style instructions; empty string means "none".
+    instructions: str = ""
+    enable_session_persistence: bool = False
+    with_memory: bool = False
+    # NOTE(review): appears to be a path/URL string per args.tools usage — confirm.
+    tools: Optional[str] = None
+    custom_prompt: Optional[str] = None