90 changes: 84 additions & 6 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -62,7 +62,8 @@
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -72,7 +73,7 @@
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowNodeExecutionModel

logger = logging.getLogger(__name__)

@@ -580,7 +581,7 @@ def _handle_stop_event(

with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)

yield workflow_finish_resp
elif event.stopped_by in (
@@ -590,7 +591,7 @@
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session)
self._save_message(session=session, trace_manager=trace_manager)

yield self._message_end_to_stream_response()

@@ -599,6 +600,7 @@ def _handle_advanced_chat_message_end_event(
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@@ -616,7 +618,7 @@ def _handle_advanced_chat_message_end_event(

# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)

yield self._message_end_to_stream_response()

@@ -770,7 +772,13 @@ def _process_stream_response(
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()

def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session)

# If there are assistant files, remove markdown image links from answer
@@ -809,6 +817,14 @@ def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState

metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))

# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")

message_files = [
MessageFile(
message_id=message.id,
@@ -826,6 +842,68 @@ def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState
]
session.add_all(message_files)

# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)

def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)

if not node_execution:
return None

# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)

# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)

return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None

def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
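The provider-string parsing in `_extract_model_info_from_workflow` above handles both three-part plugin identifiers (e.g. "langgenius/deepseek/deepseek") and two-part ones. A standalone sketch of that parsing step (a hypothetical helper extracted for illustration, not part of the diff):

```python
def parse_provider_string(provider_str: str) -> dict[str, str] | None:
    """Hypothetical standalone version of the parsing logic above, for illustration."""
    parts = provider_str.split("/")
    if len(parts) >= 3:
        # Plugin-style identifier: "org/provider/model", e.g. "langgenius/deepseek/deepseek"
        return {"provider": parts[1], "model": parts[2]}
    if len(parts) == 2:
        # Short form: "provider/model"
        return {"provider": parts[0], "model": parts[1]}
    return None

print(parse_provider_string("langgenius/deepseek/deepseek"))  # {'provider': 'deepseek', 'model': 'deepseek'}
print(parse_provider_string("openai/gpt-4o"))                 # {'provider': 'openai', 'model': 'gpt-4o'}
```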
3 changes: 3 additions & 0 deletions api/core/app/entities/task_entities.py
@@ -40,6 +40,9 @@ class EasyUITaskState(TaskState):
"""

llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False


class WorkflowTaskState(TaskState):
18 changes: 18 additions & 0 deletions api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -332,6 +332,12 @@ def _process_stream_response(
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages

# Track streaming response times
if self._task_state.first_token_time is None:
self._task_state.first_token_time = time.perf_counter()
self._task_state.is_streaming_response = True
self._task_state.last_token_time = time.perf_counter()

# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
@@ -398,6 +404,18 @@ def _save_message(self, *, session: Session, trace_manager: TraceQueueManager |
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency

# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
Reviewer comment (Contributor, severity: medium):

The condition `self._task_state.is_streaming_response and self._task_state.first_token_time` is redundant. Based on the logic in `_process_stream_response`, `is_streaming_response` is always `True` when `first_token_time` is set. You can simplify this condition to make the code more concise.

Suggested change:
-            if self._task_state.is_streaming_response and self._task_state.first_token_time:
+            if self._task_state.first_token_time:

Author reply (Contributor):

I prefer to keep the explicit check because the dual condition provides better protection.
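One subtlety neither side of the thread raises: both forms gate on truthiness, so a `first_token_time` of exactly `0.0` would skip the metrics block, whereas an explicit `is not None` check would not. In practice `time.perf_counter()` essentially never returns `0.0`, so this is a theoretical edge case; a minimal illustration:

```python
first_token_time: float | None = 0.0  # pathological but type-valid value

if first_token_time:              # truthiness: 0.0 is falsy, so this branch is skipped
    print("truthy check ran")

if first_token_time is not None:  # explicit: runs whenever a timestamp was recorded
    print("explicit check ran")   # -> printed
```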

start_time = self.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)

# Update metadata with the complete usage info
self._task_state.metadata.usage = usage

message.message_metadata = self._task_state.metadata.model_dump_json()

if trace_manager:
Expand Down
53 changes: 53 additions & 0 deletions api/core/ops/tencent_trace/span_builder.py
@@ -222,6 +222,59 @@ def build_message_span(
links=links,
)

@staticmethod
def build_message_llm_span(
trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
) -> SpanData:
"""Build LLM span for message traces with detailed LLM attributes."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)

# Extract model information from `metadata`` or `message_data`
trace_metadata = trace_info.metadata or {}
message_data = trace_info.message_data or {}

model_provider = trace_metadata.get("ls_provider") or (
message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
)
model_name = trace_metadata.get("ls_model_name") or (
message_data.get("model_id", "") if isinstance(message_data, dict) else ""
)
Reviewer comment on lines +238 to +243 (Contributor, severity: medium):

The logic for extracting model_provider and model_name is a bit dense and could be refactored for better readability. Also, there's a small typo in the comment on line 234 ("`metadata``" should be "`metadata`"). Consider breaking down the logic into more explicit steps.

Suggested change:
-        model_provider = trace_metadata.get("ls_provider") or (
-            message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
-        )
-        model_name = trace_metadata.get("ls_model_name") or (
-            message_data.get("model_id", "") if isinstance(message_data, dict) else ""
-        )
+        model_provider = trace_metadata.get("ls_provider", "")
+        if not model_provider and isinstance(message_data, dict):
+            model_provider = message_data.get("model_provider", "")
+        model_name = trace_metadata.get("ls_model_name", "")
+        if not model_name and isinstance(message_data, dict):
+            model_name = message_data.get("model_id", "")

Author reply (Contributor):

Since this is purely a style preference I'd prefer to keep it as-is unless there's a functional issue.


inputs_str = str(trace_info.inputs or "")
outputs_str = str(trace_info.outputs or "")

attributes = {
GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: str(model_name),
GEN_AI_PROVIDER: str(model_provider),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
GEN_AI_PROMPT: inputs_str,
GEN_AI_COMPLETION: outputs_str,
INPUT_VALUE: inputs_str,
OUTPUT_VALUE: outputs_str,
}

if trace_info.is_streaming_request:
attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
)

@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""
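Note that the span id is derived deterministically from `(message_id, "llm")`, which is what lets `message_trace` (next file) recompute the same value as a `parent_span_id` without passing it around. A sketch of one plausible derivation; `TencentTraceUtils.convert_to_span_id` is not shown in this diff, so this is purely illustrative:

```python
import hashlib

def convert_to_span_id(message_id: str, suffix: str) -> int:
    """Hypothetical derivation of a stable 64-bit span id; the real
    TencentTraceUtils implementation may differ."""
    digest = hashlib.sha256(f"{message_id}:{suffix}".encode()).digest()
    return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF  # keep it non-negative

# Determinism is the property the tracing code relies on:
assert convert_to_span_id("msg-123", "llm") == convert_to_span_id("msg-123", "llm")
```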
6 changes: 5 additions & 1 deletion api/core/ops/tencent_trace/tencent_trace.py
@@ -107,9 +107,13 @@ def message_trace(self, trace_info: MessageTraceInfo) -> None:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))

message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)

self.trace_client.add_span(message_span)

# Add LLM child span with detailed attributes
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
self.trace_client.add_span(llm_span)

self._record_message_llm_metrics(trace_info)

# Record trace duration for entry span
8 changes: 8 additions & 0 deletions api/models/model.py
@@ -1244,9 +1244,13 @@ def to_dict(self) -> dict[str, Any]:
"id": self.id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"model_provider": self.model_provider,
"model_id": self.model_id,
"inputs": self.inputs,
"query": self.query,
"message_tokens": self.message_tokens,
"answer_tokens": self.answer_tokens,
"provider_response_latency": self.provider_response_latency,
"total_price": self.total_price,
"message": self.message,
"answer": self.answer,
@@ -1268,8 +1272,12 @@ def from_dict(cls, data: dict[str, Any]) -> "Message":
id=data["id"],
app_id=data["app_id"],
conversation_id=data["conversation_id"],
model_provider=data.get("model_provider"),
model_id=data["model_id"],
inputs=data["inputs"],
message_tokens=data.get("message_tokens", 0),
answer_tokens=data.get("answer_tokens", 0),
provider_response_latency=data.get("provider_response_latency", 0.0),
total_price=data["total_price"],
query=data["query"],
message=data["message"],
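One asymmetry worth flagging in the new serialization fields: `from_dict` reads `model_provider` with `.get()` but `model_id` with direct indexing, so a payload missing `model_id` raises `KeyError` while a missing `model_provider` is tolerated. A minimal illustration with hypothetical payloads:

```python
payload = {"model_provider": "openai"}  # hypothetical payload with no "model_id" key

provider = payload.get("model_provider")  # -> "openai" (tolerant lookup, as in the diff)

try:
    model_id = payload["model_id"]  # mirrors from_dict's data["model_id"]
except KeyError:
    print("from_dict would raise KeyError for payloads missing model_id")
```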