221 changes: 221 additions & 0 deletions api/core/memory/node_scoped_memory.py
@@ -0,0 +1,221 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID, uuid5

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.variables.segments import ObjectSegment
from core.variables.types import SegmentType
from extensions.ext_database import db
from factories import variable_factory
from models.workflow import ConversationVariable

# A stable namespace to derive deterministic IDs for node-scoped memories.
# Using uuid5 to keep (conversation_id, node_id) mapping stable across runs.
# NOTE: UUID strings may contain only hexadecimal characters (0-9, a-f).
NODE_SCOPED_MEMORY_NS = UUID("00000000-0000-0000-0000-000000000000")


@dataclass
class _HistoryItem:
    role: str
    text: str


@dataclass
class NodeScopedMemory:
    """A per-node conversation memory persisted in ConversationVariable.

    - Keyed by (conversation_id, node_id)
    - Value is stored as a conversation variable named _llm_mem.<node_id>
    - Structure (JSON): {"version": 1, "history": [{"role": "user"|"assistant", "text": "..."}, ...]}
    """

    app_id: str
    conversation_id: str
    node_id: str
    model_instance: ModelInstance

    _loaded: bool = field(default=False, init=False)
    _history: list[_HistoryItem] = field(default_factory=list, init=False)

    @property
    def variable_name(self) -> str:
        return f"_llm_mem.{self.node_id}"

    @property
    def variable_id(self) -> str:
        # Deterministic id so we can upsert by (id, conversation_id)
        return str(uuid5(NODE_SCOPED_MEMORY_NS, f"{self.conversation_id}:{self.node_id}:llmmem"))

    # ------------ Persistence helpers ------------
    def _load_if_needed(self) -> None:
        if self._loaded:
            return
        stmt = select(ConversationVariable).where(
            ConversationVariable.id == self.variable_id,
            ConversationVariable.conversation_id == self.conversation_id,
        )
        with Session(db.engine, expire_on_commit=False) as session:
            row = session.scalar(stmt)
        if not row:
            self._history = []
            self._loaded = True
            return
        variable = row.to_variable()
        value = variable.value if isinstance(variable.value, dict) else {}
        hist = value.get("history", []) if isinstance(value, dict) else []
        parsed: list[_HistoryItem] = []
        for item in hist:
            try:
                role = str(item.get("role", ""))
                text = str(item.get("text", ""))
            except Exception:
                role, text = "", ""
            if role and text:
                parsed.append(_HistoryItem(role=role, text=text))
        self._history = parsed
        self._loaded = True

    def _dump_variable(self) -> Any:
        data = {
            "version": 1,
            "history": [{"role": item.role, "text": item.text} for item in self._history if item.text],
        }
        segment = ObjectSegment(value=data, value_type=SegmentType.OBJECT)
        variable = variable_factory.segment_to_variable(
            segment=segment,
            selector=["conversation", self.variable_name],
            id=self.variable_id,
            name=self.variable_name,
            description="LLM node-scoped memory",
        )
        return variable

    def save(self) -> None:
        variable = self._dump_variable()
        with Session(db.engine) as session:
            # Upsert by (id, conversation_id)
            existing = session.scalar(
                select(ConversationVariable).where(
                    ConversationVariable.id == self.variable_id,
                    ConversationVariable.conversation_id == self.conversation_id,
                )
            )
            if existing:
                existing.data = variable.model_dump_json()
            else:
                obj = ConversationVariable.from_variable(
                    app_id=self.app_id, conversation_id=self.conversation_id, variable=variable
                )
                session.add(obj)
            session.commit()

    # ------------ Public API expected by LLM node ------------
    def get_history_prompt_messages(
        self, *, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        self._load_if_needed()

        # Optionally limit by message count (pairs flattened)
        items: list[_HistoryItem] = list(self._history)
        if message_limit and message_limit > 0:
            # message_limit roughly means last N items (not pairs) to keep simple and efficient
            items = items[-min(message_limit, len(items)) :]

        def to_messages(hist: list[_HistoryItem]) -> list[PromptMessage]:
            msgs: list[PromptMessage] = []
            for it in hist:
                if it.role == PromptMessageRole.USER.value:
                    # Persisted node memory only stores text; inject as plain text content
                    msgs.append(UserPromptMessage(content=it.text))
                elif it.role == PromptMessageRole.ASSISTANT.value:
                    msgs.append(AssistantPromptMessage(content=it.text))
            return msgs

        messages = to_messages(items)
        # Token-based pruning from oldest
        if messages:
            tokens = self.model_instance.get_llm_num_tokens(messages)
            while tokens > max_token_limit and len(messages) > 1:
                messages.pop(0)
                tokens = self.model_instance.get_llm_num_tokens(messages)
        return messages

    def get_history_prompt_text(
        self,
        *,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str:
        self._load_if_needed()
        items: list[_HistoryItem] = list(self._history)
        if message_limit and message_limit > 0:
            items = items[-min(message_limit, len(items)) :]

        # Build messages to reuse the token counting logic; prefixes are added only when rendering,
        # otherwise they would be duplicated in the output.
        messages: list[PromptMessage] = []
        for it in items:
            if it.role == PromptMessageRole.USER.value:
                messages.append(UserPromptMessage(content=it.text))
            elif it.role == PromptMessageRole.ASSISTANT.value:
                messages.append(AssistantPromptMessage(content=it.text))

        if messages:
            tokens = self.model_instance.get_llm_num_tokens(messages)
            while tokens > max_token_limit and len(messages) > 1:
                messages.pop(0)
                tokens = self.model_instance.get_llm_num_tokens(messages)

        # Convert to the required "<prefix>: <text>" format
        lines: list[str] = []
        for m in messages:
            if m.role == PromptMessageRole.USER:
                prefix = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                prefix = ai_prefix
            else:
                continue
            if isinstance(m.content, list):
                # Only text content was saved in this minimal implementation
                texts = [c.data for c in m.content if isinstance(c, TextPromptMessageContent)]
                text = "\n".join(texts)
            else:
                text = str(m.content)
            lines.append(f"{prefix}: {text}")
        return "\n".join(lines)

    def append_exchange(self, *, user_text: str | None, assistant_text: str | None) -> None:
        self._load_if_needed()
        if user_text:
            self._history.append(_HistoryItem(role=PromptMessageRole.USER.value, text=user_text))
        if assistant_text:
            self._history.append(_HistoryItem(role=PromptMessageRole.ASSISTANT.value, text=assistant_text))

    def clear(self) -> None:
        self._history = []
        self._loaded = True
        self.save()
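For reference, a minimal sketch of the intended round trip through NodeScopedMemory, assuming a running Dify app context with the database initialized and a model_instance already resolved via ModelManager (the IDs below are placeholders):

from core.memory.node_scoped_memory import NodeScopedMemory

# Placeholder identifiers; in the LLM node these come from the node context and the variable pool.
memory = NodeScopedMemory(
    app_id="app-id",
    conversation_id="conversation-id",
    node_id="llm-node-1",
    model_instance=model_instance,  # assumed: a ModelInstance resolved elsewhere via ModelManager
)

# Record one exchange and persist it as the `_llm_mem.llm-node-1` conversation variable.
memory.append_exchange(user_text="What is Dify?", assistant_text="An open-source LLM app platform.")
memory.save()

# On a later run of the same node in the same conversation, the history can be replayed
# either as prompt messages or as plain "Human:/Assistant:" text.
history_messages = memory.get_history_prompt_messages(max_token_limit=2000)
history_text = memory.get_history_prompt_text(human_prefix="Human", ai_prefix="Assistant")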
5 changes: 5 additions & 0 deletions api/core/prompt/entities/advanced_prompt_entities.py
@@ -48,3 +48,8 @@ class WindowConfig(BaseModel):
    role_prefix: RolePrefix | None = None
    window: WindowConfig
    query_prompt_template: str | None = None
    # Memory scope: shared (default, uses TokenBufferMemory), or independent
    # (per-node, persisted in ConversationVariable)
    scope: Literal["shared", "independent"] = "shared"
    # If true, clear the per-node memory after this node finishes execution (only applies to independent scope)
    clear_after_execution: bool = False
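To illustrate how the new fields are meant to be combined, here is a hypothetical memory configuration (dict form) for an LLM node that keeps its own per-node history; the window and query_prompt_template values are assumptions based on the existing MemoryConfig schema rather than anything this diff prescribes:

# Hypothetical MemoryConfig payload for an LLM node with independent, per-node memory.
memory_config = {
    "window": {"enabled": False},  # assumed WindowConfig shape: no message-window truncation
    "query_prompt_template": "{{#sys.query#}}",  # assumed default query template
    "scope": "independent",  # per-node memory persisted as a conversation variable
    "clear_after_execution": False,  # keep the node's history across conversation turns
}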
23 changes: 23 additions & 0 deletions api/core/workflow/nodes/llm/llm_utils.py
@@ -8,6 +8,7 @@
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.file.models import File
from core.memory.node_scoped_memory import NodeScopedMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -107,6 +108,28 @@ def fetch_memory(
    return memory


def fetch_node_scoped_memory(
    variable_pool: VariablePool,
    *,
    app_id: str,
    node_id: str,
    model_instance: ModelInstance,
) -> NodeScopedMemory | None:
    """Factory for per-node memory based on conversation scope.

    Returns None if no conversation_id is present in the variable pool.
    """
    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if not isinstance(conversation_id_variable, StringSegment):
        return None
    return NodeScopedMemory(
        app_id=app_id,
        conversation_id=conversation_id_variable.value,
        node_id=node_id,
        model_instance=model_instance,
    )


def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration
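The factory deliberately returns None when sys.conversation_id is absent from the variable pool (for example, a standalone workflow run rather than a chat), so callers must handle that fallback. A minimal sketch of the expected call site, with variable_pool, app_id, node_id, and model_instance assumed to be in scope (it mirrors the node.py change further down):

node_memory = fetch_node_scoped_memory(
    variable_pool=variable_pool,
    app_id=app_id,
    node_id=node_id,
    model_instance=model_instance,
)
if node_memory is None:
    # Not running inside a conversation: behave as if the node has no memory.
    history_messages = []
else:
    history_messages = node_memory.get_history_prompt_messages(max_token_limit=2000)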
58 changes: 46 additions & 12 deletions api/core/workflow/nodes/llm/node.py
@@ -5,14 +5,13 @@
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, Protocol

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
    ImagePromptMessageContent,
@@ -100,6 +99,24 @@
logger = logging.getLogger(__name__)


class ChatHistoryMemory(Protocol):
    def get_history_prompt_messages(
        self,
        *,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]: ...

    def get_history_prompt_text(
        self,
        *,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str: ...


class LLMNode(Node):
    node_type = NodeType.LLM

@@ -215,13 +232,26 @@ def _run(self) -> Generator:
                tenant_id=self.tenant_id,
            )

            # fetch memory
            memory = llm_utils.fetch_memory(
                variable_pool=variable_pool,
                app_id=self.app_id,
                node_data_memory=self._node_data.memory,
                model_instance=model_instance,
            # fetch memory (shared) or node-scoped (independent)
            independent_scope = (
                self._node_data.memory and getattr(self._node_data.memory, "scope", "shared") == "independent"
            )
            node_memory = None
            memory_shared = None
            if independent_scope:
                node_memory = llm_utils.fetch_node_scoped_memory(
                    variable_pool=variable_pool,
                    app_id=self.app_id,
                    node_id=self._node_id,
                    model_instance=model_instance,
                )
            else:
                memory_shared = llm_utils.fetch_memory(
                    variable_pool=variable_pool,
                    app_id=self.app_id,
                    node_data_memory=self._node_data.memory,
                    model_instance=model_instance,
                )

            query: str | None = None
            if self._node_data.memory:
@@ -235,7 +265,7 @@ def _run(self) -> Generator:
                sys_query=query,
                sys_files=files,
                context=context,
                memory=memory,
                memory=(node_memory if independent_scope else memory_shared),
                model_config=model_config,
                prompt_template=self._node_data.prompt_template,
                memory_config=self._node_data.memory,
@@ -289,6 +319,10 @@ def _run(self) -> Generator:
                        else None
                    )

                    # Clear the node-scoped memory after execution when configured to do so
                    if (
                        independent_scope
                        and node_memory
                        and getattr(self._node_data.memory, "clear_after_execution", False)
                    ):
                        node_memory.clear()

                    # deduct quota
                    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                    break
@@ -756,7 +790,7 @@ def fetch_prompt_messages(
        sys_query: str | None = None,
        sys_files: Sequence["File"],
        context: str | None = None,
        memory: TokenBufferMemory | None = None,
        memory: ChatHistoryMemory | None = None,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        memory_config: MemoryConfig | None = None,
@@ -1297,7 +1331,7 @@ def _calculate_rest_token(

def _handle_memory_chat_mode(
    *,
    memory: TokenBufferMemory | None,
    memory: ChatHistoryMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1314,7 +1348,7 @@ def _handle_memory_completion_mode(

def _handle_memory_completion_mode(
    *,
    memory: TokenBufferMemory | None,
    memory: ChatHistoryMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:
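A design note on the ChatHistoryMemory protocol introduced above: typing.Protocol matches by method shape rather than by inheritance, which is why both TokenBufferMemory and the new NodeScopedMemory can be passed to fetch_prompt_messages and the _handle_memory_* helpers without modifying TokenBufferMemory. A self-contained sketch of that idea (FakeMemory is illustrative only, not part of this PR):

from typing import Protocol


class ChatHistoryMemory(Protocol):
    def get_history_prompt_text(
        self,
        *,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str: ...


class FakeMemory:
    """Illustrative stand-in: any object with a matching method shape satisfies the protocol."""

    def get_history_prompt_text(
        self,
        *,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str:
        return f"{human_prefix}: hi\n{ai_prefix}: hello"


def render_history(memory: ChatHistoryMemory) -> str:
    # Structural typing: FakeMemory, TokenBufferMemory, or NodeScopedMemory all type-check here.
    return memory.get_history_prompt_text(max_token_limit=1000)


print(render_history(FakeMemory()))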