
Commit a176056

feat: support separate memory
1 parent 0e3fab1 commit a176056

File tree

7 files changed: 835 additions & 12 deletions

api/core/memory/node_scoped_memory.py (new file)

Lines changed: 221 additions & 0 deletions
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID, uuid5

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.variables.segments import ObjectSegment
from core.variables.types import SegmentType
from extensions.ext_database import db
from factories import variable_factory
from models.workflow import ConversationVariable

# A stable namespace to derive deterministic IDs for node-scoped memories.
# Using uuid5 to keep (conversation_id, node_id) mapping stable across runs.
# NOTE: UUIDs must contain only hexadecimal characters; avoid letters beyond 'f'.
NODE_SCOPED_MEMORY_NS = UUID("00000000-0000-0000-0000-000000000000")


@dataclass
class _HistoryItem:
    role: str
    text: str


@dataclass
class NodeScopedMemory:
    """A per-node conversation memory persisted in ConversationVariable.

    - Keyed by (conversation_id, node_id)
    - Value is stored as a conversation variable named _llm_mem.<node_id>
    - Structure (JSON): {"version": 1, "history": [{"role": "user"|"assistant", "text": "..."}, ...]}
    """

    app_id: str
    conversation_id: str
    node_id: str
    model_instance: ModelInstance

    _loaded: bool = field(default=False, init=False)
    _history: list[_HistoryItem] = field(default_factory=list, init=False)

    @property
    def variable_name(self) -> str:
        return f"_llm_mem.{self.node_id}"

    @property
    def variable_id(self) -> str:
        # Deterministic id so we can upsert by (id, conversation_id)
        return str(uuid5(NODE_SCOPED_MEMORY_NS, f"{self.conversation_id}:{self.node_id}:llmmem"))

    # ------------ Persistence helpers ------------
    def _load_if_needed(self) -> None:
        if self._loaded:
            return
        stmt = select(ConversationVariable).where(
            ConversationVariable.id == self.variable_id,
            ConversationVariable.conversation_id == self.conversation_id,
        )
        with Session(db.engine, expire_on_commit=False) as session:
            row = session.scalar(stmt)
        if not row:
            self._history = []
            self._loaded = True
            return
        variable = row.to_variable()
        value = variable.value if isinstance(variable.value, dict) else {}
        hist = value.get("history", []) if isinstance(value, dict) else []
        parsed: list[_HistoryItem] = []
        for item in hist:
            try:
                role = str(item.get("role", ""))
                text = str(item.get("text", ""))
            except Exception:
                role, text = "", ""
            if role and text:
                parsed.append(_HistoryItem(role=role, text=text))
        self._history = parsed
        self._loaded = True

    def _dump_variable(self) -> Any:
        data = {
            "version": 1,
            "history": [{"role": item.role, "text": item.text} for item in self._history if item.text],
        }
        segment = ObjectSegment(value=data, value_type=SegmentType.OBJECT)
        variable = variable_factory.segment_to_variable(
            segment=segment,
            selector=["conversation", self.variable_name],
            id=self.variable_id,
            name=self.variable_name,
            description="LLM node-scoped memory",
        )
        return variable

    def save(self) -> None:
        variable = self._dump_variable()
        with Session(db.engine) as session:
            # Upsert by (id, conversation_id)
            existing = session.scalar(
                select(ConversationVariable).where(
                    ConversationVariable.id == self.variable_id,
                    ConversationVariable.conversation_id == self.conversation_id,
                )
            )
            if existing:
                existing.data = variable.model_dump_json()
            else:
                obj = ConversationVariable.from_variable(
                    app_id=self.app_id, conversation_id=self.conversation_id, variable=variable
                )
                session.add(obj)
            session.commit()

    # ------------ Public API expected by LLM node ------------
    def get_history_prompt_messages(
        self, *, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        self._load_if_needed()

        # Optionally limit by message count (pairs flattened)
        items: list[_HistoryItem] = list(self._history)
        if message_limit and message_limit > 0:
            # message_limit roughly means last N items (not pairs) to keep simple and efficient
            items = items[-min(message_limit, len(items)) :]

        def to_messages(hist: list[_HistoryItem]) -> list[PromptMessage]:
            msgs: list[PromptMessage] = []
            for it in hist:
                if it.role == PromptMessageRole.USER.value:
                    # Persisted node memory only stores text; inject as plain text content
                    msgs.append(UserPromptMessage(content=it.text))
                elif it.role == PromptMessageRole.ASSISTANT.value:
                    msgs.append(AssistantPromptMessage(content=it.text))
            return msgs

        messages = to_messages(items)
        # Token-based pruning from oldest
        if messages:
            tokens = self.model_instance.get_llm_num_tokens(messages)
            while tokens > max_token_limit and len(messages) > 1:
                messages.pop(0)
                tokens = self.model_instance.get_llm_num_tokens(messages)
        return messages

    def get_history_prompt_text(
        self,
        *,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str:
        self._load_if_needed()
        items: list[_HistoryItem] = list(self._history)
        if message_limit and message_limit > 0:
            items = items[-min(message_limit, len(items)) :]

        # Build messages to reuse token counting logic
        messages: list[PromptMessage] = []
        for it in items:
            role_name = (
                PromptMessageRole.USER
                if it.role == PromptMessageRole.USER.value
                else (PromptMessageRole.ASSISTANT if it.role == PromptMessageRole.ASSISTANT.value else None)
            )
            if role_name is None:
                continue
            prefix = human_prefix if role_name == PromptMessageRole.USER else ai_prefix
            messages.append(
                UserPromptMessage(content=f"{prefix}: {it.text}")
                if role_name == PromptMessageRole.USER
                else AssistantPromptMessage(content=f"{prefix}: {it.text}")
            )

        if messages:
            tokens = self.model_instance.get_llm_num_tokens(messages)
            while tokens > max_token_limit and len(messages) > 1:
                messages.pop(0)
                tokens = self.model_instance.get_llm_num_tokens(messages)

        # Convert back to the required text format
        lines: list[str] = []
        for m in messages:
            if m.role == PromptMessageRole.USER:
                prefix = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                prefix = ai_prefix
            else:
                continue
            if isinstance(m.content, list):
                # Only text content was saved in this minimal implementation
                texts = [c.data for c in m.content if isinstance(c, TextPromptMessageContent)]
                text = "\n".join(texts)
            else:
                text = str(m.content)
            lines.append(f"{prefix}: {text}")
        return "\n".join(lines)

    def append_exchange(self, *, user_text: str | None, assistant_text: str | None) -> None:
        self._load_if_needed()
        if user_text:
            self._history.append(_HistoryItem(role=PromptMessageRole.USER.value, text=user_text))
        if assistant_text:
            self._history.append(_HistoryItem(role=PromptMessageRole.ASSISTANT.value, text=assistant_text))

    def clear(self) -> None:
        self._history = []
        self._loaded = True
        self.save()
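
For orientation, a minimal usage sketch of the class above (not part of the commit). The IDs and model_instance are placeholders; inside the workflow they come from the variable pool and the node's model configuration.

# Minimal sketch, assuming a ModelInstance is already available as `model_instance`.
from core.memory.node_scoped_memory import NodeScopedMemory

memory = NodeScopedMemory(
    app_id="app-123",                 # placeholder app id
    conversation_id="conv-456",       # placeholder conversation id
    node_id="llm-node-1",             # placeholder node id
    model_instance=model_instance,    # hypothetical ModelInstance obtained elsewhere
)

# Read history pruned to the token budget (lazily loaded from ConversationVariable).
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=10)

# Record the latest turn and persist it under the _llm_mem.<node_id> conversation variable.
memory.append_exchange(user_text="What is Dify?", assistant_text="An LLM app platform.")
memory.save()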

api/core/prompt/entities/advanced_prompt_entities.py

Lines changed: 5 additions & 0 deletions
@@ -48,3 +48,8 @@ class WindowConfig(BaseModel):
     role_prefix: RolePrefix | None = None
     window: WindowConfig
     query_prompt_template: str | None = None
+    # Memory scope: shared (default, uses TokenBufferMemory), or independent
+    # (per-node, persisted in ConversationVariable)
+    scope: Literal["shared", "independent"] = "shared"
+    # If true, clear the per-node memory after this node finishes execution (only applies to independent scope)
+    clear_after_execution: bool = False
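
As a hedged illustration of how these fields might show up in an LLM node's memory configuration: only scope and clear_after_execution come from this diff; the surrounding keys are assumptions for context.

# Hypothetical node-config fragment; keys other than "scope" and
# "clear_after_execution" are illustrative assumptions.
memory_block = {
    "window": {"enabled": True, "size": 10},
    "scope": "independent",           # per-node memory persisted in ConversationVariable
    "clear_after_execution": False,   # keep this node's history across runs in the conversation
}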

api/core/workflow/nodes/llm/llm_utils.py

Lines changed: 23 additions & 0 deletions
@@ -8,6 +8,7 @@
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.provider_entities import QuotaUnit
 from core.file.models import File
+from core.memory.node_scoped_memory import NodeScopedMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -107,6 +108,28 @@ def fetch_memory(
     return memory


+def fetch_node_scoped_memory(
+    variable_pool: VariablePool,
+    *,
+    app_id: str,
+    node_id: str,
+    model_instance: ModelInstance,
+) -> NodeScopedMemory | None:
+    """Factory for per-node memory based on conversation scope.
+
+    Returns None if no conversation_id is present in the variable pool.
+    """
+    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
+    if not isinstance(conversation_id_variable, StringSegment):
+        return None
+    return NodeScopedMemory(
+        app_id=app_id,
+        conversation_id=conversation_id_variable.value,
+        node_id=node_id,
+        model_instance=model_instance,
+    )
+
+
 def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
     provider_model_bundle = model_instance.provider_model_bundle
     provider_configuration = provider_model_bundle.configuration
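
A small caller-side sketch of the contract above, under the assumption that runs outside a conversation (for example plain workflow runs) carry no sys.conversation_id, so the factory returns None and callers must guard for it.

# Sketch (assumption): guard the None case before using the per-node memory.
node_memory = llm_utils.fetch_node_scoped_memory(
    variable_pool=variable_pool,
    app_id=self.app_id,
    node_id=self._node_id,
    model_instance=model_instance,
)
history = [] if node_memory is None else node_memory.get_history_prompt_messages(max_token_limit=2000)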

api/core/workflow/nodes/llm/node.py

Lines changed: 48 additions & 12 deletions
@@ -5,14 +5,13 @@
 import re
 import time
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, Protocol

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
     ImagePromptMessageContent,
@@ -100,6 +99,26 @@
 logger = logging.getLogger(__name__)


+class ChatHistoryMemory(Protocol):
+    def get_history_prompt_messages(
+        self,
+        *,
+        max_token_limit: int = 2000,
+        message_limit: int | None = None,
+    ) -> Sequence[PromptMessage]:
+        ...
+
+    def get_history_prompt_text(
+        self,
+        *,
+        human_prefix: str = "Human",
+        ai_prefix: str = "Assistant",
+        max_token_limit: int = 2000,
+        message_limit: int | None = None,
+    ) -> str:
+        ...
+
+
 class LLMNode(Node):
     node_type = NodeType.LLM

@@ -215,13 +234,26 @@ def _run(self) -> Generator:
             tenant_id=self.tenant_id,
         )

-        # fetch memory
-        memory = llm_utils.fetch_memory(
-            variable_pool=variable_pool,
-            app_id=self.app_id,
-            node_data_memory=self._node_data.memory,
-            model_instance=model_instance,
+        # fetch memory (shared) or node-scoped (independent)
+        independent_scope = (
+            self._node_data.memory and getattr(self._node_data.memory, "scope", "shared") == "independent"
         )
+        node_memory = None
+        memory_shared = None
+        if independent_scope:
+            node_memory = llm_utils.fetch_node_scoped_memory(
+                variable_pool=variable_pool,
+                app_id=self.app_id,
+                node_id=self._node_id,
+                model_instance=model_instance,
+            )
+        else:
+            memory_shared = llm_utils.fetch_memory(
+                variable_pool=variable_pool,
+                app_id=self.app_id,
+                node_data_memory=self._node_data.memory,
+                model_instance=model_instance,
+            )

         query: str | None = None
         if self._node_data.memory:
@@ -235,7 +267,7 @@ def _run(self) -> Generator:
             sys_query=query,
             sys_files=files,
             context=context,
-            memory=memory,
+            memory=(node_memory if independent_scope else memory_shared),
             model_config=model_config,
             prompt_template=self._node_data.prompt_template,
             memory_config=self._node_data.memory,
@@ -289,6 +321,10 @@ def _run(self) -> Generator:
                     else None
                 )

+                # Persist node-scoped memory if enabled
+                if independent_scope and node_memory:
+                    node_memory.clear()
+
                 # deduct quota
                 llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                 break
@@ -756,7 +792,7 @@ def fetch_prompt_messages(
         sys_query: str | None = None,
         sys_files: Sequence["File"],
         context: str | None = None,
-        memory: TokenBufferMemory | None = None,
+        memory: ChatHistoryMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         memory_config: MemoryConfig | None = None,
@@ -1297,7 +1333,7 @@ def _calculate_rest_token(

 def _handle_memory_chat_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: ChatHistoryMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> Sequence[PromptMessage]:
@@ -1314,7 +1350,7 @@ def _handle_memory_chat_mode(

 def _handle_memory_completion_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: ChatHistoryMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> str: