mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 00:38:03 +08:00
- Update BaseNodeData import path to dify_graph.entities.base_node_data - Change NodeType.COMMAND/FILE_UPLOAD to BuiltinNodeTypes constants - Fix system_oauth_encryption -> system_encryption rename in commands - Remove tests for deleted agent runner modules - Fix Avatar: named import + string size API in collaboration files - Add missing skill feature deps: @monaco-editor/react, react-arborist, @tanstack/react-virtual - Fix frontend test mocks: add useUserProfile, useLeaderRestoreListener, next/navigation mock, and nodeOutputVars to expected payload Made-with: Cursor
2901 lines
121 KiB
Python
2901 lines
121 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import logging
|
|
import mimetypes
|
|
import os
|
|
import re
|
|
import time
|
|
from collections.abc import Generator, Mapping, Sequence
|
|
from functools import reduce
|
|
from pathlib import PurePosixPath
|
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
|
|
|
from sqlalchemy import select
|
|
|
|
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
|
from core.agent.patterns import StrategyFactory
|
|
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
|
from core.llm_generator.output_parser.errors import OutputParserError
|
|
from core.llm_generator.output_parser.file_ref import (
|
|
adapt_schema_for_sandbox_file_paths,
|
|
convert_sandbox_file_paths_in_output,
|
|
detect_file_path_fields,
|
|
)
|
|
from core.llm_generator.output_parser.structured_output import (
|
|
invoke_llm_with_structured_output,
|
|
)
|
|
from core.memory.base import BaseMemory
|
|
from core.model_manager import ModelInstance
|
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
|
from core.sandbox import Sandbox
|
|
from core.sandbox.bash.session import MAX_OUTPUT_FILE_SIZE, MAX_OUTPUT_FILES, SandboxBashSession
|
|
from core.sandbox.entities.config import AppAssets
|
|
from core.skill.assembler import SkillDocumentAssembler
|
|
from core.skill.constants import SkillAttrs
|
|
from core.skill.entities.skill_bundle import SkillBundle
|
|
from core.skill.entities.skill_document import SkillDocument
|
|
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
|
from core.tools.__base.tool import Tool
|
|
from core.tools.signature import sign_tool_file, sign_upload_file
|
|
from core.tools.tool_file_manager import ToolFileManager
|
|
from core.tools.tool_manager import ToolManager
|
|
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
|
from dify_graph.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
|
|
from dify_graph.entities.graph_config import NodeConfigDict
|
|
from dify_graph.entities.tool_entities import ToolCallResult
|
|
from dify_graph.enums import (
|
|
BuiltinNodeTypes,
|
|
NodeType,
|
|
SystemVariableKey,
|
|
WorkflowNodeExecutionMetadataKey,
|
|
WorkflowNodeExecutionStatus,
|
|
)
|
|
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
|
|
from dify_graph.model_runtime.entities import (
|
|
ImagePromptMessageContent,
|
|
PromptMessage,
|
|
PromptMessageContentType,
|
|
TextPromptMessageContent,
|
|
)
|
|
from dify_graph.model_runtime.entities.llm_entities import (
|
|
LLMResult,
|
|
LLMResultChunk,
|
|
LLMResultChunkWithStructuredOutput,
|
|
LLMResultWithStructuredOutput,
|
|
LLMStructuredOutput,
|
|
LLMUsage,
|
|
)
|
|
from dify_graph.model_runtime.entities.message_entities import (
|
|
AssistantPromptMessage,
|
|
PromptMessageContentUnionTypes,
|
|
PromptMessageRole,
|
|
SystemPromptMessage,
|
|
UserPromptMessage,
|
|
)
|
|
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
|
from dify_graph.model_runtime.memory import PromptMessageMemory
|
|
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
|
from dify_graph.node_events import (
|
|
AgentLogEvent,
|
|
ModelInvokeCompletedEvent,
|
|
NodeEventBase,
|
|
NodeRunResult,
|
|
RunRetrieverResourceEvent,
|
|
StreamChunkEvent,
|
|
StreamCompletedEvent,
|
|
ThoughtChunkEvent,
|
|
ToolCallChunkEvent,
|
|
ToolResultChunkEvent,
|
|
)
|
|
from dify_graph.node_events.node import ChunkType, ThoughtEndChunkEvent, ThoughtStartChunkEvent
|
|
from dify_graph.nodes.base.entities import VariableSelector
|
|
from dify_graph.nodes.base.node import Node
|
|
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
|
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
|
from dify_graph.nodes.protocols import HttpClientProtocol
|
|
from dify_graph.runtime import VariablePool
|
|
from dify_graph.variables import (
|
|
ArrayFileSegment,
|
|
ArrayPromptMessageSegment,
|
|
ArraySegment,
|
|
FileSegment,
|
|
NoneSegment,
|
|
ObjectSegment,
|
|
StringSegment,
|
|
)
|
|
from extensions.ext_database import db
|
|
from models.dataset import SegmentAttachmentBinding
|
|
from models.model import UploadFile
|
|
|
|
from . import llm_utils
|
|
from .entities import (
|
|
AgentContext,
|
|
AggregatedResult,
|
|
LLMGenerationData,
|
|
LLMNodeChatModelMessage,
|
|
LLMNodeCompletionModelPromptTemplate,
|
|
LLMNodeData,
|
|
LLMTraceSegment,
|
|
ModelTraceSegment,
|
|
PromptMessageContext,
|
|
StreamBuffers,
|
|
ThinkTagStreamParser,
|
|
ToolLogPayload,
|
|
ToolOutputState,
|
|
ToolTraceSegment,
|
|
TraceState,
|
|
)
|
|
from .exc import (
|
|
InvalidContextStructureError,
|
|
InvalidVariableTypeError,
|
|
LLMNodeError,
|
|
MemoryRolePrefixRequiredError,
|
|
NoPromptFoundError,
|
|
TemplateTypeNotSupportError,
|
|
VariableNotFoundError,
|
|
)
|
|
from .file_saver import FileSaverImpl, LLMFileSaver
|
|
|
|
if TYPE_CHECKING:
|
|
from dify_graph.file.models import File
|
|
from dify_graph.runtime import GraphRuntimeState
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMNode(Node[LLMNodeData]):
    # Registered node type for this implementation (LLM built-in node).
    node_type = BuiltinNodeTypes.LLM

    # Compiled regex for extracting <think> blocks (with compatibility for attributes,
    # e.g. <think foo="bar">); DOTALL so reasoning may span lines, IGNORECASE for tag case.
    _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

    # Instance attributes specific to LLMNode.
    # Output variable for file: multimodal files produced by the model during a run.
    _file_outputs: list[File]

    # Persists multimodal outputs (e.g. generated images) and yields File handles.
    _llm_file_saver: LLMFileSaver

    # Supplies provider credentials for model invocation.
    _credentials_provider: CredentialsProvider

    # Factory used to build model instances for this node.
    _model_factory: ModelFactory

    # The concrete model instance this node invokes.
    _model_instance: ModelInstance

    # Optional conversation memory; None when memory is disabled.
    _memory: PromptMessageMemory | None
|
def __init__(
|
|
self,
|
|
id: str,
|
|
config: NodeConfigDict,
|
|
graph_init_params: GraphInitParams,
|
|
graph_runtime_state: GraphRuntimeState,
|
|
*,
|
|
credentials_provider: CredentialsProvider,
|
|
model_factory: ModelFactory,
|
|
model_instance: ModelInstance,
|
|
http_client: HttpClientProtocol,
|
|
memory: PromptMessageMemory | None = None,
|
|
llm_file_saver: LLMFileSaver | None = None,
|
|
):
|
|
super().__init__(
|
|
id=id,
|
|
config=config,
|
|
graph_init_params=graph_init_params,
|
|
graph_runtime_state=graph_runtime_state,
|
|
)
|
|
# LLM file outputs, used for MultiModal outputs.
|
|
self._file_outputs = []
|
|
|
|
self._credentials_provider = credentials_provider
|
|
self._model_factory = model_factory
|
|
self._model_instance = model_instance
|
|
self._memory = memory
|
|
|
|
if llm_file_saver is None:
|
|
dify_ctx = self.require_dify_context()
|
|
llm_file_saver = FileSaverImpl(
|
|
user_id=dify_ctx.user_id,
|
|
tenant_id=dify_ctx.tenant_id,
|
|
http_client=http_client,
|
|
)
|
|
self._llm_file_saver = llm_file_saver
|
|
|
|
    @classmethod
    def version(cls) -> str:
        """Return the schema/implementation version string of this node type."""
        return "1"
|
def _run(self) -> Generator:
|
|
node_inputs: dict[str, Any] = {}
|
|
process_data: dict[str, Any] = {}
|
|
clean_text = ""
|
|
usage = LLMUsage.empty_usage()
|
|
finish_reason = None
|
|
reasoning_content = "" # Initialize as empty string for consistency
|
|
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
|
|
variable_pool = self.graph_runtime_state.variable_pool
|
|
|
|
try:
|
|
# Parse prompt template to separate static messages and context references
|
|
prompt_template = self.node_data.prompt_template
|
|
static_messages, context_refs, template_order = self._parse_prompt_template()
|
|
|
|
# fetch variables and fetch values from variable pool
|
|
inputs = self._fetch_inputs(node_data=self.node_data)
|
|
|
|
# fetch jinja2 inputs
|
|
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
|
|
|
|
# merge inputs
|
|
inputs.update(jinja_inputs)
|
|
|
|
# fetch files
|
|
files = (
|
|
llm_utils.fetch_files(
|
|
variable_pool=variable_pool,
|
|
selector=self.node_data.vision.configs.variable_selector,
|
|
)
|
|
if self.node_data.vision.enabled
|
|
else []
|
|
)
|
|
|
|
if files:
|
|
node_inputs["#files#"] = [file.to_dict() for file in files]
|
|
|
|
# fetch context value
|
|
generator = self._fetch_context(node_data=self.node_data)
|
|
context = None
|
|
context_files: list[File] = []
|
|
for event in generator:
|
|
context = event.context
|
|
context_files = event.context_files or []
|
|
yield event
|
|
if context:
|
|
node_inputs["#context#"] = context
|
|
|
|
if context_files:
|
|
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
|
|
|
|
# fetch model config
|
|
model_instance = self._model_instance
|
|
model_name = model_instance.model_name
|
|
model_provider = model_instance.provider
|
|
model_stop = model_instance.stop
|
|
|
|
memory = llm_utils.fetch_memory(
|
|
variable_pool=variable_pool,
|
|
app_id=self.app_id,
|
|
tenant_id=self.tenant_id,
|
|
node_data_memory=self.node_data.memory,
|
|
model_instance=model_instance,
|
|
node_id=self._node_id,
|
|
)
|
|
|
|
query: str | None = None
|
|
if self.node_data.memory:
|
|
query = self.node_data.memory.query_prompt_template
|
|
if not query and (
|
|
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
|
):
|
|
query = query_variable.text
|
|
|
|
prompt_messages: Sequence[PromptMessage]
|
|
stop: Sequence[str] | None
|
|
if isinstance(prompt_template, list) and context_refs:
|
|
prompt_messages, stop = self._build_prompt_messages_with_context(
|
|
context_refs=context_refs,
|
|
template_order=template_order,
|
|
static_messages=static_messages,
|
|
query=query,
|
|
files=files,
|
|
context=context,
|
|
memory=memory,
|
|
model_instance=model_instance,
|
|
context_files=context_files,
|
|
)
|
|
else:
|
|
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
|
sys_query=query,
|
|
sys_files=files,
|
|
context=context,
|
|
memory=memory,
|
|
model_instance=model_instance,
|
|
stop=model_stop,
|
|
prompt_template=cast(
|
|
Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
|
self.node_data.prompt_template,
|
|
),
|
|
memory_config=self.node_data.memory,
|
|
vision_enabled=self.node_data.vision.enabled,
|
|
vision_detail=self.node_data.vision.configs.detail,
|
|
variable_pool=variable_pool,
|
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
|
context_files=context_files,
|
|
sandbox=self.graph_runtime_state.sandbox,
|
|
)
|
|
|
|
# handle invoke result
|
|
generator = LLMNode.invoke_llm(
|
|
model_instance=model_instance,
|
|
prompt_messages=prompt_messages,
|
|
stop=stop,
|
|
user_id=self.require_dify_context().user_id,
|
|
structured_output_enabled=self.node_data.structured_output_enabled,
|
|
structured_output=self.node_data.structured_output,
|
|
file_saver=self._llm_file_saver,
|
|
file_outputs=self._file_outputs,
|
|
node_id=self._node_id,
|
|
node_type=self.node_type,
|
|
reasoning_format=self.node_data.reasoning_format,
|
|
)
|
|
|
|
# Variables for outputs
|
|
generation_data: LLMGenerationData | None = None
|
|
structured_output: LLMStructuredOutput | None = None
|
|
structured_output_schema: Mapping[str, Any] | None
|
|
structured_output_file_paths: list[str] = []
|
|
|
|
if self.node_data.structured_output_enabled:
|
|
if not self.node_data.structured_output:
|
|
raise LLMNodeError("structured_output_enabled is True but structured_output is not set")
|
|
raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
|
|
if self.node_data.computer_use:
|
|
raise LLMNodeError("Structured output is not supported in computer use mode.")
|
|
else:
|
|
if detect_file_path_fields(raw_schema):
|
|
sandbox = self.graph_runtime_state.sandbox
|
|
if not sandbox:
|
|
raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
|
|
structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
|
|
raw_schema
|
|
)
|
|
else:
|
|
structured_output_schema = raw_schema
|
|
else:
|
|
structured_output_schema = None
|
|
|
|
if self.node_data.computer_use:
|
|
sandbox = self.graph_runtime_state.sandbox
|
|
if not sandbox:
|
|
raise LLMNodeError("computer use is enabled but no sandbox found")
|
|
tool_dependencies: ToolDependencies | None = self._extract_tool_dependencies()
|
|
generator = self._invoke_llm_with_sandbox(
|
|
sandbox=sandbox,
|
|
model_instance=model_instance,
|
|
prompt_messages=prompt_messages,
|
|
stop=stop,
|
|
variable_pool=variable_pool,
|
|
tool_dependencies=tool_dependencies,
|
|
)
|
|
elif self.tool_call_enabled:
|
|
generator = self._invoke_llm_with_tools(
|
|
model_instance=model_instance,
|
|
prompt_messages=prompt_messages,
|
|
stop=stop,
|
|
files=files,
|
|
variable_pool=variable_pool,
|
|
node_inputs=node_inputs,
|
|
process_data=process_data,
|
|
)
|
|
else:
|
|
# Use traditional LLM invocation
|
|
generator = LLMNode.invoke_llm(
|
|
node_data_model=self._node_data.model,
|
|
model_instance=model_instance,
|
|
prompt_messages=prompt_messages,
|
|
stop=stop,
|
|
user_id=self.user_id,
|
|
structured_output_schema=structured_output_schema,
|
|
allow_file_path=bool(structured_output_file_paths),
|
|
file_saver=self._llm_file_saver,
|
|
file_outputs=self._file_outputs,
|
|
node_id=self._node_id,
|
|
node_type=self.node_type,
|
|
reasoning_format=self._node_data.reasoning_format,
|
|
)
|
|
|
|
(
|
|
clean_text,
|
|
reasoning_content,
|
|
generation_reasoning_content,
|
|
generation_clean_content,
|
|
usage,
|
|
finish_reason,
|
|
structured_output,
|
|
generation_data,
|
|
) = yield from self._stream_llm_events(generator, model_instance=model_instance)
|
|
|
|
if structured_output and structured_output_file_paths:
|
|
sandbox = self.graph_runtime_state.sandbox
|
|
if not sandbox:
|
|
raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
|
|
|
|
structured_output_value = structured_output.structured_output
|
|
if structured_output_value is None:
|
|
raise LLMNodeError("Structured output is empty")
|
|
|
|
resolved_count = 0
|
|
|
|
def resolve_file(path: str) -> File:
|
|
nonlocal resolved_count
|
|
if resolved_count >= MAX_OUTPUT_FILES:
|
|
raise LLMNodeError("Structured output files exceed the sandbox output limit")
|
|
resolved_count += 1
|
|
return self._resolve_sandbox_file_path(sandbox=sandbox, path=path)
|
|
|
|
converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
|
|
output=structured_output_value,
|
|
file_path_fields=structured_output_file_paths,
|
|
file_resolver=resolve_file,
|
|
)
|
|
structured_output = LLMStructuredOutput(structured_output=converted_output)
|
|
if structured_output_files:
|
|
self._file_outputs.extend(structured_output_files)
|
|
|
|
# Extract variables from generation_data if available
|
|
if generation_data:
|
|
clean_text = generation_data.text
|
|
reasoning_content = ""
|
|
usage = generation_data.usage
|
|
finish_reason = generation_data.finish_reason
|
|
|
|
# Unified process_data building
|
|
process_data = {
|
|
"model_mode": self.node_data.model.mode,
|
|
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
|
model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
|
|
),
|
|
"usage": jsonable_encoder(usage),
|
|
"finish_reason": finish_reason,
|
|
"model_provider": model_provider,
|
|
"model_name": model_name,
|
|
}
|
|
if self.tool_call_enabled and self._node_data.tools:
|
|
process_data["tools"] = [
|
|
{
|
|
"type": tool.type.value if hasattr(tool.type, "value") else tool.type,
|
|
"provider_name": tool.provider_name,
|
|
"tool_name": tool.tool_name,
|
|
}
|
|
for tool in self._node_data.tools
|
|
if tool.enabled
|
|
]
|
|
|
|
is_sandbox = self.graph_runtime_state.sandbox is not None
|
|
outputs = self._build_outputs(
|
|
is_sandbox=is_sandbox,
|
|
clean_text=clean_text,
|
|
reasoning_content=reasoning_content,
|
|
generation_reasoning_content=generation_reasoning_content,
|
|
generation_clean_content=generation_clean_content,
|
|
usage=usage,
|
|
finish_reason=finish_reason,
|
|
prompt_messages=prompt_messages,
|
|
generation_data=generation_data,
|
|
structured_output=structured_output,
|
|
)
|
|
|
|
# Send final chunk event to indicate streaming is complete
|
|
# For tool calls and sandbox, final events are already sent in _process_tool_outputs
|
|
if not self.tool_call_enabled and not self.node_data.computer_use:
|
|
yield StreamChunkEvent(
|
|
selector=[self._node_id, "text"],
|
|
chunk="",
|
|
is_final=True,
|
|
)
|
|
yield StreamChunkEvent(
|
|
selector=[self._node_id, "generation", "content"],
|
|
chunk="",
|
|
is_final=True,
|
|
)
|
|
yield ThoughtChunkEvent(
|
|
selector=[self._node_id, "generation", "thought"],
|
|
chunk="",
|
|
is_final=True,
|
|
)
|
|
|
|
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
|
}
|
|
|
|
if generation_data and generation_data.trace:
|
|
metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [
|
|
segment.model_dump() for segment in generation_data.trace
|
|
]
|
|
|
|
yield StreamCompletedEvent(
|
|
node_run_result=NodeRunResult(
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
inputs=node_inputs,
|
|
process_data=process_data,
|
|
outputs=outputs,
|
|
metadata=metadata,
|
|
llm_usage=usage,
|
|
)
|
|
)
|
|
except ValueError as e:
|
|
yield StreamCompletedEvent(
|
|
node_run_result=NodeRunResult(
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
error=str(e),
|
|
inputs=node_inputs,
|
|
process_data=process_data,
|
|
error_type=type(e).__name__,
|
|
llm_usage=usage,
|
|
)
|
|
)
|
|
except Exception as e:
|
|
logger.exception("error while executing llm node")
|
|
yield StreamCompletedEvent(
|
|
node_run_result=NodeRunResult(
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
error=str(e),
|
|
inputs=node_inputs,
|
|
process_data=process_data,
|
|
error_type=type(e).__name__,
|
|
llm_usage=usage,
|
|
)
|
|
)
|
|
|
|
def _build_outputs(
|
|
self,
|
|
*,
|
|
is_sandbox: bool,
|
|
clean_text: str,
|
|
reasoning_content: str,
|
|
generation_reasoning_content: str,
|
|
generation_clean_content: str,
|
|
usage: LLMUsage,
|
|
finish_reason: str | None,
|
|
prompt_messages: Sequence[PromptMessage],
|
|
generation_data: LLMGenerationData | None,
|
|
structured_output: LLMStructuredOutput | None,
|
|
) -> dict[str, Any]:
|
|
"""Build the outputs dictionary for the LLM node.
|
|
|
|
Two runtime modes produce different output shapes:
|
|
- **Classical** (is_sandbox=False): top-level ``text`` and ``reasoning_content``
|
|
are preserved for backward compatibility with existing users.
|
|
- **Sandbox** (is_sandbox=True): ``text`` and ``reasoning_content`` are omitted
|
|
from the top level because they duplicate fields inside ``generation``.
|
|
|
|
The ``generation`` field always carries the full structured representation
|
|
(content, reasoning, tool_calls, sequence) regardless of runtime mode.
|
|
|
|
Args:
|
|
is_sandbox: Whether the current runtime is sandbox mode.
|
|
clean_text: Processed text for outputs["text"]; may keep <think> tags for "tagged" format.
|
|
reasoning_content: Native model reasoning from the API response.
|
|
generation_reasoning_content: Reasoning for the generation field, extracted from <think>
|
|
tags via _split_reasoning (always tag-free). Falls back to reasoning_content
|
|
if empty (no <think> tags found).
|
|
generation_clean_content: Clean text for the generation field (always tag-free).
|
|
Differs from clean_text only when reasoning_format is "tagged".
|
|
usage: LLM usage statistics.
|
|
finish_reason: Finish reason from LLM.
|
|
prompt_messages: Prompt messages sent to the LLM.
|
|
generation_data: Multi-turn generation data from tool/sandbox invocation, or None.
|
|
structured_output: Structured output if enabled.
|
|
"""
|
|
# Common outputs shared by both runtimes
|
|
outputs: dict[str, Any] = {
|
|
"usage": jsonable_encoder(usage),
|
|
"finish_reason": finish_reason,
|
|
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
|
|
}
|
|
|
|
# Classical runtime keeps top-level text/reasoning_content for backward compatibility
|
|
if not is_sandbox:
|
|
outputs["text"] = clean_text
|
|
outputs["reasoning_content"] = reasoning_content
|
|
|
|
# Build generation field
|
|
if generation_data:
|
|
# Agent/sandbox runtime: generation_data captures multi-turn interactions
|
|
generation = {
|
|
"content": generation_data.text,
|
|
"reasoning_content": generation_data.reasoning_contents, # [thought1, thought2, ...]
|
|
"tool_calls": [self._serialize_tool_call(item) for item in generation_data.tool_calls],
|
|
"sequence": generation_data.sequence,
|
|
}
|
|
files_to_output = list(generation_data.files)
|
|
# Merge auto-collected/structured-output files from self._file_outputs
|
|
if self._file_outputs:
|
|
existing_ids = {f.id for f in files_to_output}
|
|
files_to_output.extend(f for f in self._file_outputs if f.id not in existing_ids)
|
|
else:
|
|
# Classical runtime: use pre-computed generation-specific text pair,
|
|
# falling back to native model reasoning if no <think> tags were found.
|
|
generation_reasoning = generation_reasoning_content or reasoning_content
|
|
generation_content = generation_clean_content or clean_text
|
|
sequence: list[dict[str, Any]] = []
|
|
if generation_reasoning:
|
|
sequence = [
|
|
{"type": "reasoning", "index": 0},
|
|
{"type": "content", "start": 0, "end": len(generation_content)},
|
|
]
|
|
generation = {
|
|
"content": generation_content,
|
|
"reasoning_content": [generation_reasoning] if generation_reasoning else [],
|
|
"tool_calls": [],
|
|
"sequence": sequence,
|
|
}
|
|
files_to_output = self._file_outputs
|
|
|
|
outputs["generation"] = generation
|
|
if files_to_output:
|
|
outputs["files"] = ArrayFileSegment(value=files_to_output)
|
|
if structured_output:
|
|
outputs["structured_output"] = structured_output.structured_output
|
|
|
|
return outputs
|
|
|
|
@staticmethod
|
|
def invoke_llm(
|
|
*,
|
|
model_instance: ModelInstance,
|
|
prompt_messages: Sequence[PromptMessage],
|
|
stop: Sequence[str] | None = None,
|
|
user_id: str,
|
|
structured_output_schema: Mapping[str, Any] | None,
|
|
allow_file_path: bool = False,
|
|
file_saver: LLMFileSaver,
|
|
file_outputs: list[File],
|
|
node_id: str,
|
|
node_type: NodeType,
|
|
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
|
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
|
model_parameters = model_instance.parameters
|
|
invoke_model_parameters = dict(model_parameters)
|
|
|
|
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
|
|
|
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
|
|
if structured_output_schema:
|
|
request_start_time = time.perf_counter()
|
|
|
|
invoke_result = invoke_llm_with_structured_output(
|
|
provider=model_instance.provider,
|
|
model_schema=model_schema,
|
|
model_instance=model_instance,
|
|
prompt_messages=prompt_messages,
|
|
json_schema=structured_output_schema,
|
|
model_parameters=invoke_model_parameters,
|
|
stop=list(stop or []),
|
|
user=user_id,
|
|
allow_file_path=allow_file_path,
|
|
)
|
|
else:
|
|
request_start_time = time.perf_counter()
|
|
|
|
invoke_result = model_instance.invoke_llm(
|
|
prompt_messages=list(prompt_messages),
|
|
model_parameters=invoke_model_parameters,
|
|
stop=list(stop or []),
|
|
stream=True,
|
|
user=user_id,
|
|
)
|
|
|
|
return LLMNode.handle_invoke_result(
|
|
invoke_result=invoke_result,
|
|
file_saver=file_saver,
|
|
file_outputs=file_outputs,
|
|
node_id=node_id,
|
|
node_type=node_type,
|
|
reasoning_format=reasoning_format,
|
|
request_start_time=request_start_time,
|
|
)
|
|
|
|
@staticmethod
|
|
def handle_invoke_result(
|
|
*,
|
|
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
|
|
file_saver: LLMFileSaver,
|
|
file_outputs: list[File],
|
|
node_id: str,
|
|
node_type: NodeType,
|
|
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
|
request_start_time: float | None = None,
|
|
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
|
# For blocking mode
|
|
if isinstance(invoke_result, LLMResult):
|
|
duration = None
|
|
if request_start_time is not None:
|
|
duration = time.perf_counter() - request_start_time
|
|
invoke_result.usage.latency = round(duration, 3)
|
|
event = LLMNode.handle_blocking_result(
|
|
invoke_result=invoke_result,
|
|
saver=file_saver,
|
|
file_outputs=file_outputs,
|
|
reasoning_format=reasoning_format,
|
|
request_latency=duration,
|
|
)
|
|
yield event
|
|
return
|
|
|
|
# For streaming mode
|
|
model = ""
|
|
prompt_messages: list[PromptMessage] = []
|
|
|
|
usage = LLMUsage.empty_usage()
|
|
finish_reason = None
|
|
full_text_buffer = io.StringIO()
|
|
think_parser = ThinkTagStreamParser()
|
|
reasoning_chunks: list[str] = []
|
|
|
|
# Initialize streaming metrics tracking
|
|
start_time = request_start_time if request_start_time is not None else time.perf_counter()
|
|
first_token_time = None
|
|
has_content = False
|
|
|
|
collected_structured_output = None # Collect structured_output from streaming chunks
|
|
# Consume the invoke result and handle generator exception
|
|
try:
|
|
for result in invoke_result:
|
|
if isinstance(result, LLMResultChunkWithStructuredOutput):
|
|
# Collect structured_output from the chunk
|
|
if result.structured_output is not None:
|
|
collected_structured_output = dict(result.structured_output)
|
|
yield result
|
|
if isinstance(result, LLMResultChunk):
|
|
contents = result.delta.message.content
|
|
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
|
|
contents=contents,
|
|
file_saver=file_saver,
|
|
file_outputs=file_outputs,
|
|
):
|
|
# Detect first token for TTFT calculation
|
|
if text_part and not has_content:
|
|
first_token_time = time.perf_counter()
|
|
has_content = True
|
|
|
|
full_text_buffer.write(text_part)
|
|
# Text output: always forward raw chunk (keep <think> tags intact)
|
|
yield StreamChunkEvent(
|
|
selector=[node_id, "text"],
|
|
chunk=text_part,
|
|
is_final=False,
|
|
)
|
|
|
|
# Generation output: split out thoughts, forward only non-thought content chunks
|
|
for kind, segment in think_parser.process(text_part):
|
|
if not segment:
|
|
if kind not in {"thought_start", "thought_end"}:
|
|
continue
|
|
|
|
if kind == "thought_start":
|
|
yield ThoughtStartChunkEvent(
|
|
selector=[node_id, "generation", "thought"],
|
|
chunk="",
|
|
is_final=False,
|
|
)
|
|
elif kind == "thought":
|
|
reasoning_chunks.append(segment)
|
|
yield ThoughtChunkEvent(
|
|
selector=[node_id, "generation", "thought"],
|
|
chunk=segment,
|
|
is_final=False,
|
|
)
|
|
elif kind == "thought_end":
|
|
yield ThoughtEndChunkEvent(
|
|
selector=[node_id, "generation", "thought"],
|
|
chunk="",
|
|
is_final=False,
|
|
)
|
|
else:
|
|
yield StreamChunkEvent(
|
|
selector=[node_id, "generation", "content"],
|
|
chunk=segment,
|
|
is_final=False,
|
|
)
|
|
|
|
# Update the whole metadata
|
|
if not model and result.model:
|
|
model = result.model
|
|
if len(prompt_messages) == 0:
|
|
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
|
# What's the purpose of the line below?
|
|
prompt_messages = list(result.prompt_messages)
|
|
if usage.prompt_tokens == 0 and result.delta.usage:
|
|
usage = result.delta.usage
|
|
if finish_reason is None and result.delta.finish_reason:
|
|
finish_reason = result.delta.finish_reason
|
|
except OutputParserError as e:
|
|
raise LLMNodeError(f"Failed to parse structured output: {e}")
|
|
|
|
for kind, segment in think_parser.flush():
|
|
if not segment and kind not in {"thought_start", "thought_end"}:
|
|
continue
|
|
if kind == "thought_start":
|
|
yield ThoughtStartChunkEvent(
|
|
selector=[node_id, "generation", "thought"],
|
|
chunk="",
|
|
is_final=False,
|
|
)
|
|
elif kind == "thought":
|
|
reasoning_chunks.append(segment)
|
|
yield ThoughtChunkEvent(
|
|
selector=[node_id, "generation", "thought"],
|
|
chunk=segment,
|
|
is_final=False,
|
|
)
|
|
elif kind == "thought_end":
|
|
yield ThoughtEndChunkEvent(
|
|
selector=[node_id, "generation", "thought"],
|
|
chunk="",
|
|
is_final=False,
|
|
)
|
|
else:
|
|
yield StreamChunkEvent(
|
|
selector=[node_id, "generation", "content"],
|
|
chunk=segment,
|
|
is_final=False,
|
|
)
|
|
|
|
# Extract reasoning content from <think> tags in the main text
|
|
full_text = full_text_buffer.getvalue()
|
|
|
|
if reasoning_format == "tagged":
|
|
# Keep <think> tags in text for backward compatibility
|
|
clean_text = full_text
|
|
reasoning_content = "".join(reasoning_chunks)
|
|
else:
|
|
# Extract clean text and reasoning from <think> tags
|
|
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
|
if reasoning_chunks and not reasoning_content:
|
|
reasoning_content = "".join(reasoning_chunks)
|
|
|
|
# Calculate streaming metrics
|
|
end_time = time.perf_counter()
|
|
total_duration = end_time - start_time
|
|
usage.latency = round(total_duration, 3)
|
|
if has_content and first_token_time:
|
|
gen_ai_server_time_to_first_token = first_token_time - start_time
|
|
llm_streaming_time_to_generate = end_time - first_token_time
|
|
usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3)
|
|
usage.time_to_generate = round(llm_streaming_time_to_generate, 3)
|
|
|
|
yield ModelInvokeCompletedEvent(
|
|
# Use clean_text for separated mode, full_text for tagged mode
|
|
text=clean_text if reasoning_format == "separated" else full_text,
|
|
usage=usage,
|
|
finish_reason=finish_reason,
|
|
# Reasoning content for workflow variables and downstream nodes
|
|
reasoning_content=reasoning_content,
|
|
# Pass structured output if collected from streaming chunks
|
|
structured_output=collected_structured_output,
|
|
)
|
|
|
|
    @staticmethod
    def _image_file_to_markdown(file: File, /):
        # Render a saved image file as a markdown chunk for the text stream.
        # NOTE(review): this literal looks truncated by a text-extraction
        # artifact — it is presumably a markdown image link such as
        # f"![{file.filename}]({<signed url>})"; restore from upstream before
        # relying on this output. TODO confirm.
        text_chunk = f"})"
        return text_chunk
|
@classmethod
|
|
def _split_reasoning(
|
|
cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
|
|
) -> tuple[str, str]:
|
|
"""
|
|
Split reasoning content from text based on reasoning_format strategy.
|
|
|
|
Args:
|
|
text: Full text that may contain <think> blocks
|
|
reasoning_format: Strategy for handling reasoning content
|
|
- "separated": Remove <think> tags and return clean text + reasoning_content field
|
|
- "tagged": Keep <think> tags in text, return empty reasoning_content
|
|
|
|
Returns:
|
|
tuple of (clean_text, reasoning_content)
|
|
"""
|
|
|
|
if reasoning_format == "tagged":
|
|
return text, ""
|
|
|
|
# Find all <think>...</think> blocks (case-insensitive)
|
|
matches = cls._THINK_PATTERN.findall(text)
|
|
|
|
# Extract reasoning content from all <think> blocks
|
|
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
|
|
|
|
# Remove all <think>...</think> blocks from original text
|
|
clean_text = cls._THINK_PATTERN.sub("", text)
|
|
|
|
# Clean up extra whitespace
|
|
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
|
|
|
# Separated mode: always return clean text and reasoning_content
|
|
return clean_text, reasoning_content or ""
|
|
|
|
def _transform_chat_messages(
    self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
    """Promote jinja2 template text onto the primary ``text`` field, in place.

    Messages (or the completion template) authored in jinja2 mode keep
    their template source in ``jinja2_text``; downstream rendering reads
    ``text``, so it is copied over here before rendering.
    """
    if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
        template = messages
        if template.edition_type == "jinja2" and template.jinja2_text:
            template.text = template.jinja2_text
        return template

    for chat_message in messages:
        uses_jinja = chat_message.edition_type == "jinja2"
        if uses_jinja and chat_message.jinja2_text:
            chat_message.text = chat_message.jinja2_text
    return messages
|
|
|
|
def _parse_prompt_template(
    self,
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
    """Split ``prompt_template`` into static messages and context references.

    Returns:
        ``(static_messages, context_refs, template_order)`` where
        ``template_order`` is a list of ``(index, kind)`` pairs — kind is
        ``"context"`` or ``"static"`` — preserving the original DSL order.
    """
    template = self.node_data.prompt_template
    statics: list[LLMNodeChatModelMessage] = []
    contexts: list[PromptMessageContext] = []
    order: list[tuple[int, str]] = []

    if isinstance(template, list):
        for position, entry in enumerate(template):
            if isinstance(entry, PromptMessageContext):
                contexts.append(entry)
                order.append((position, "context"))
                continue
            statics.append(entry)
            order.append((position, "static"))

    # Jinja2-authored static messages need their text promoted before use.
    if statics:
        self.node_data.prompt_template = self._transform_chat_messages(statics)

    return statics, contexts, order
|
|
|
|
def _build_prompt_messages_with_context(
    self,
    *,
    context_refs: list[PromptMessageContext],
    template_order: list[tuple[int, str]],
    static_messages: list[LLMNodeChatModelMessage],
    query: str | None,
    files: Sequence[File],
    context: str | None,
    memory: BaseMemory | None,
    model_config: ModelConfigWithCredentialsEntity,
    context_files: list[File],
) -> tuple[list[PromptMessage], Sequence[str] | None]:
    """
    Build prompt messages by combining static messages and context references in DSL order.

    ``template_order`` (as produced by ``_parse_prompt_template``) drives
    which source — a resolved context variable or a static template
    message — contributes next. Memory history, the current query, and
    files are appended afterwards, then empty/unsupported content is
    filtered out.

    Returns:
        Tuple of (prompt_messages, stop_sequences)

    Raises:
        VariableNotFoundError: if a context selector resolves to nothing.
        InvalidVariableTypeError: if a context variable is not array[message].
    """
    variable_pool = self.graph_runtime_state.variable_pool

    # Process messages in DSL order: iterate once and handle each type directly
    combined_messages: list[PromptMessage] = []
    context_idx = 0
    static_idx = 0

    for _, type_ in template_order:
        if type_ == "context":
            # Handle context reference
            ctx_ref = context_refs[context_idx]
            ctx_var = variable_pool.get(ctx_ref.value_selector)
            if ctx_var is None:
                raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
            if not isinstance(ctx_var, ArrayPromptMessageSegment):
                raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
            # Restore multimodal content (base64/url) that was truncated when saving context
            restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value)
            combined_messages.extend(restored_messages)
            context_idx += 1
        else:
            # Handle static message
            static_msg = static_messages[static_idx]
            processed_msgs = LLMNode.handle_list_messages(
                messages=[static_msg],
                context=context,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
                variable_pool=variable_pool,
                vision_detail_config=self.node_data.vision.configs.detail,
                sandbox=self.graph_runtime_state.sandbox,
            )
            combined_messages.extend(processed_msgs)
            static_idx += 1

    # Append memory messages
    memory_messages = _handle_memory_chat_mode(
        memory=memory,
        memory_config=self.node_data.memory,
        model_config=model_config,
    )
    combined_messages.extend(memory_messages)

    # Append current query if provided
    if query:
        query_message = LLMNodeChatModelMessage(
            text=query,
            role=PromptMessageRole.USER,
            edition_type="basic",
        )
        # NOTE(review): unlike the static-message path above, no sandbox is
        # passed here, so skill-bundle assembly is skipped for the query —
        # confirm this is intentional.
        query_msgs = LLMNode.handle_list_messages(
            messages=[query_message],
            context="",
            jinja2_variables=[],
            variable_pool=variable_pool,
            vision_detail_config=self.node_data.vision.configs.detail,
        )
        combined_messages.extend(query_msgs)

    # Handle files (sys_files and context_files)
    combined_messages = self._append_files_to_messages(
        messages=combined_messages,
        sys_files=files,
        context_files=context_files,
        model_config=model_config,
    )

    # Filter empty messages and get stop sequences
    combined_messages = self._filter_messages(combined_messages, model_config)
    stop = self._get_stop_sequences(model_config)

    return combined_messages, stop
|
|
|
|
def _append_files_to_messages(
|
|
self,
|
|
*,
|
|
messages: list[PromptMessage],
|
|
sys_files: Sequence[File],
|
|
context_files: list[File],
|
|
model_config: ModelConfigWithCredentialsEntity,
|
|
) -> list[PromptMessage]:
|
|
"""Append sys_files and context_files to messages."""
|
|
vision_enabled = self.node_data.vision.enabled
|
|
vision_detail = self.node_data.vision.configs.detail
|
|
|
|
# Handle sys_files (will be deprecated later)
|
|
if vision_enabled and sys_files:
|
|
file_prompts = [
|
|
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
|
|
]
|
|
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
|
|
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
|
|
else:
|
|
messages.append(UserPromptMessage(content=file_prompts))
|
|
|
|
# Handle context_files
|
|
if vision_enabled and context_files:
|
|
file_prompts = [
|
|
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
|
for file in context_files
|
|
]
|
|
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
|
|
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
|
|
else:
|
|
messages.append(UserPromptMessage(content=file_prompts))
|
|
|
|
return messages
|
|
|
|
def _filter_messages(
|
|
self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
|
) -> list[PromptMessage]:
|
|
"""Filter empty messages and unsupported content types."""
|
|
filtered_messages: list[PromptMessage] = []
|
|
|
|
for message in messages:
|
|
if isinstance(message.content, list):
|
|
filtered_content: list[PromptMessageContentUnionTypes] = []
|
|
for content_item in message.content:
|
|
# Skip non-text content if features are not defined
|
|
if not model_config.model_schema.features:
|
|
if content_item.type != PromptMessageContentType.TEXT:
|
|
continue
|
|
filtered_content.append(content_item)
|
|
continue
|
|
|
|
# Skip content if corresponding feature is not supported
|
|
feature_map = {
|
|
PromptMessageContentType.IMAGE: ModelFeature.VISION,
|
|
PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
|
|
PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
|
|
PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
|
|
}
|
|
required_feature = feature_map.get(content_item.type)
|
|
if required_feature and required_feature not in model_config.model_schema.features:
|
|
continue
|
|
filtered_content.append(content_item)
|
|
|
|
# Simplify single text content
|
|
if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
|
|
message.content = filtered_content[0].data
|
|
else:
|
|
message.content = filtered_content
|
|
|
|
if not message.is_empty():
|
|
filtered_messages.append(message)
|
|
|
|
if not filtered_messages:
|
|
raise NoPromptFoundError(
|
|
"No prompt found in the LLM configuration. "
|
|
"Please ensure a prompt is properly configured before proceeding."
|
|
)
|
|
|
|
return filtered_messages
|
|
|
|
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
|
|
"""Get stop sequences from model config."""
|
|
return model_config.stop
|
|
|
|
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
|
variables: dict[str, Any] = {}
|
|
|
|
if not node_data.prompt_config:
|
|
return variables
|
|
|
|
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
|
variable_name = variable_selector.variable
|
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
|
if variable is None:
|
|
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
|
|
|
def parse_dict(input_dict: Mapping[str, Any]) -> str:
|
|
"""
|
|
Parse dict into string
|
|
"""
|
|
# check if it's a context structure
|
|
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
|
|
return str(input_dict["content"])
|
|
|
|
# else, parse the dict
|
|
try:
|
|
return json.dumps(input_dict, ensure_ascii=False)
|
|
except Exception:
|
|
return str(input_dict)
|
|
|
|
if isinstance(variable, ArraySegment):
|
|
result = ""
|
|
for item in variable.value:
|
|
if isinstance(item, dict):
|
|
result += parse_dict(item)
|
|
else:
|
|
result += str(item)
|
|
result += "\n"
|
|
value = result.strip()
|
|
elif isinstance(variable, ObjectSegment):
|
|
value = parse_dict(variable.value)
|
|
else:
|
|
value = variable.text
|
|
|
|
variables[variable_name] = value
|
|
|
|
return variables
|
|
|
|
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
|
|
inputs = {}
|
|
prompt_template = node_data.prompt_template
|
|
|
|
variable_selectors = []
|
|
if isinstance(prompt_template, list):
|
|
for prompt in prompt_template:
|
|
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
|
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
|
elif isinstance(prompt_template, CompletionModelPromptTemplate):
|
|
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
|
variable_selectors = variable_template_parser.extract_variable_selectors()
|
|
|
|
for variable_selector in variable_selectors:
|
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
|
if variable is None:
|
|
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
|
if isinstance(variable, NoneSegment):
|
|
inputs[variable_selector.variable] = ""
|
|
inputs[variable_selector.variable] = variable.to_object()
|
|
|
|
memory = node_data.memory
|
|
if memory and memory.query_prompt_template:
|
|
query_variable_selectors = VariableTemplateParser(
|
|
template=memory.query_prompt_template
|
|
).extract_variable_selectors()
|
|
for variable_selector in query_variable_selectors:
|
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
|
if variable is None:
|
|
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
|
if isinstance(variable, NoneSegment):
|
|
continue
|
|
inputs[variable_selector.variable] = variable.to_object()
|
|
|
|
return inputs
|
|
|
|
def _fetch_context(self, node_data: LLMNodeData):
    """Yield a ``RunRetrieverResourceEvent`` for the node's configured context.

    Generator: emits at most one event. A string context variable is
    passed through verbatim; an array context is flattened to text,
    converted to retriever-resource records, and enriched with any image
    attachments bound to the originating knowledge segments.
    """
    if not node_data.context.enabled:
        return

    if not node_data.context.variable_selector:
        return

    context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
    if context_value_variable:
        if isinstance(context_value_variable, StringSegment):
            # Plain string context: no retriever resources or files.
            yield RunRetrieverResourceEvent(
                retriever_resources=[], context=context_value_variable.value, context_files=[]
            )
        elif isinstance(context_value_variable, ArraySegment):
            context_str = ""
            original_retriever_resource: list[dict[str, Any]] = []
            context_files: list[File] = []
            for item in context_value_variable.value:
                if isinstance(item, str):
                    context_str += item + "\n"
                else:
                    # Non-string items must look like knowledge chunks.
                    if "content" not in item:
                        raise InvalidContextStructureError(f"Invalid context structure: {item}")

                    # Include the summary (when present) ahead of the content.
                    if item.get("summary"):
                        context_str += item["summary"] + "\n"
                    context_str += item["content"] + "\n"

                    retriever_resource = self._convert_to_original_retriever_resource(item)
                    if retriever_resource:
                        original_retriever_resource.append(retriever_resource)
                        segment_id = retriever_resource.get("segment_id")
                        if not segment_id:
                            continue
                        # Load image attachments bound to this knowledge segment.
                        # NOTE(review): one DB query per segment — consider
                        # batching if contexts grow large.
                        attachments_with_bindings = db.session.execute(
                            select(SegmentAttachmentBinding, UploadFile)
                            .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
                            .where(
                                SegmentAttachmentBinding.segment_id == segment_id,
                            )
                        ).all()
                        if attachments_with_bindings:
                            for _, upload_file in attachments_with_bindings:
                                attachment_info = File(
                                    id=upload_file.id,
                                    filename=upload_file.name,
                                    extension="." + upload_file.extension,
                                    mime_type=upload_file.mime_type,
                                    tenant_id=self.require_dify_context().tenant_id,
                                    # Attachments are always surfaced as images here.
                                    type=FileType.IMAGE,
                                    transfer_method=FileTransferMethod.LOCAL_FILE,
                                    remote_url=upload_file.source_url,
                                    related_id=upload_file.id,
                                    size=upload_file.size,
                                    storage_key=upload_file.key,
                                    url=sign_upload_file(upload_file.id, upload_file.extension),
                                )
                                context_files.append(attachment_info)
            yield RunRetrieverResourceEvent(
                retriever_resources=original_retriever_resource,
                context=context_str.strip(),
                context_files=context_files,
            )
|
|
|
|
def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None:
|
|
if (
|
|
"metadata" in context_dict
|
|
and "_source" in context_dict["metadata"]
|
|
and context_dict["metadata"]["_source"] == "knowledge"
|
|
):
|
|
metadata = context_dict.get("metadata", {})
|
|
|
|
return {
|
|
"position": metadata.get("position"),
|
|
"dataset_id": metadata.get("dataset_id"),
|
|
"dataset_name": metadata.get("dataset_name"),
|
|
"document_id": metadata.get("document_id"),
|
|
"document_name": metadata.get("document_name"),
|
|
"data_source_type": metadata.get("data_source_type"),
|
|
"segment_id": metadata.get("segment_id"),
|
|
"retriever_from": metadata.get("retriever_from"),
|
|
"score": metadata.get("score"),
|
|
"hit_count": metadata.get("segment_hit_count"),
|
|
"word_count": metadata.get("segment_word_count"),
|
|
"segment_position": metadata.get("segment_position"),
|
|
"index_node_hash": metadata.get("segment_index_node_hash"),
|
|
"content": context_dict.get("content"),
|
|
"page": metadata.get("page"),
|
|
"doc_metadata": metadata.get("doc_metadata"),
|
|
"files": context_dict.get("files"),
|
|
"summary": context_dict.get("summary"),
|
|
}
|
|
|
|
return None
|
|
|
|
@staticmethod
def fetch_prompt_messages(
    *,
    sys_query: str | None = None,
    sys_files: Sequence[File],
    context: str | None = None,
    memory: BaseMemory | None = None,
    model_instance: ModelInstance,
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
    stop: Sequence[str] | None = None,
    memory_config: MemoryConfig | None = None,
    vision_enabled: bool = False,
    vision_detail: ImagePromptMessageContent.DETAIL,
    variable_pool: VariablePool,
    jinja2_variables: Sequence[VariableSelector],
    context_files: list[File] | None = None,
    sandbox: Sandbox | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
    """Assemble the final prompt messages for either model mode.

    Chat templates (a list of messages) are rendered and get memory
    history plus the current query appended as extra messages. Completion
    templates (a single prompt) get memory text and the query substituted
    into the prompt body via the ``#histories#`` / ``#sys.query#``
    placeholders. Files (sys_files and context_files) are then attached,
    and unsupported/empty content is filtered out.

    Returns:
        Tuple of (filtered prompt messages, *stop* passed through).

    Raises:
        TemplateTypeNotSupportError: for unknown template types.
        NoPromptFoundError: if filtering removes every message.
        ValueError: if a completion prompt has an unexpected content type.
    """
    prompt_messages: list[PromptMessage] = []
    model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

    if isinstance(prompt_template, list):
        # For chat model
        prompt_messages.extend(
            LLMNode.handle_list_messages(
                messages=prompt_template,
                context=context,
                jinja2_variables=jinja2_variables,
                variable_pool=variable_pool,
                vision_detail_config=vision_detail,
                sandbox=sandbox,
            )
        )

        # Get memory messages for chat mode
        memory_messages = _handle_memory_chat_mode(
            memory=memory,
            memory_config=memory_config,
            model_instance=model_instance,
        )
        # Extend prompt_messages with memory messages
        prompt_messages.extend(memory_messages)

        # Add current query to the prompt messages
        if sys_query:
            message = LLMNodeChatModelMessage(
                text=sys_query,
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
            prompt_messages.extend(
                LLMNode.handle_list_messages(
                    messages=[message],
                    context="",
                    jinja2_variables=[],
                    variable_pool=variable_pool,
                    vision_detail_config=vision_detail,
                )
            )

    elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
        # For completion model
        prompt_messages.extend(
            _handle_completion_template(
                template=prompt_template,
                context=context,
                jinja2_variables=jinja2_variables,
                variable_pool=variable_pool,
            )
        )

        # Get memory text for completion model
        memory_text = _handle_memory_completion_mode(
            memory=memory,
            memory_config=memory_config,
            model_instance=model_instance,
        )
        # Insert histories into the prompt.
        # NOTE(review): assumes _handle_completion_template produced at
        # least one message — confirm it cannot return an empty list.
        prompt_content = prompt_messages[0].content
        # For issue #11247 - Check if prompt content is a string or a list
        # (type(...) == str is an exact-type check; subclasses of str/list
        # would fall through to the ValueError branch).
        prompt_content_type = type(prompt_content)
        if prompt_content_type == str:
            prompt_content = str(prompt_content)
            # Replace the placeholder when present, otherwise prepend.
            if "#histories#" in prompt_content:
                prompt_content = prompt_content.replace("#histories#", memory_text)
            else:
                prompt_content = memory_text + "\n" + prompt_content
            prompt_messages[0].content = prompt_content
        elif prompt_content_type == list:
            prompt_content = prompt_content if isinstance(prompt_content, list) else []
            for content_item in prompt_content:
                if content_item.type == PromptMessageContentType.TEXT:
                    if "#histories#" in content_item.data:
                        content_item.data = content_item.data.replace("#histories#", memory_text)
                    else:
                        content_item.data = memory_text + "\n" + content_item.data
        else:
            raise ValueError("Invalid prompt content type")

        # Add current query to the prompt message
        if sys_query:
            if prompt_content_type == str:
                prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                prompt_messages[0].content = prompt_content
            elif prompt_content_type == list:
                # NOTE(review): the list path prepends the query to every
                # text item instead of substituting #sys.query# — confirm
                # this asymmetry with the string path is intentional.
                prompt_content = prompt_content if isinstance(prompt_content, list) else []
                for content_item in prompt_content:
                    if content_item.type == PromptMessageContentType.TEXT:
                        content_item.data = sys_query + "\n" + content_item.data
            else:
                raise ValueError("Invalid prompt content type")
    else:
        raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))

    # The sys_files will be deprecated later
    if vision_enabled and sys_files:
        file_prompts = []
        for file in sys_files:
            file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
            file_prompts.append(file_prompt)
        # If last prompt is a user prompt, add files into its contents,
        # otherwise append a new user prompt
        if (
            len(prompt_messages) > 0
            and isinstance(prompt_messages[-1], UserPromptMessage)
            and isinstance(prompt_messages[-1].content, list)
        ):
            prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
        else:
            prompt_messages.append(UserPromptMessage(content=file_prompts))

    # The context_files (same merge strategy as sys_files above)
    if vision_enabled and context_files:
        file_prompts = []
        for file in context_files:
            file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
            file_prompts.append(file_prompt)
        # If last prompt is a user prompt, add files into its contents,
        # otherwise append a new user prompt
        if (
            len(prompt_messages) > 0
            and isinstance(prompt_messages[-1], UserPromptMessage)
            and isinstance(prompt_messages[-1].content, list)
        ):
            prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
        else:
            prompt_messages.append(UserPromptMessage(content=file_prompts))

    # Remove empty messages and filter unsupported content
    filtered_prompt_messages = []
    for prompt_message in prompt_messages:
        if isinstance(prompt_message.content, list):
            prompt_message_content: list[PromptMessageContentUnionTypes] = []
            for content_item in prompt_message.content:
                # Skip content if features are not defined
                if not model_schema.features:
                    if content_item.type != PromptMessageContentType.TEXT:
                        continue
                    prompt_message_content.append(content_item)
                    continue

                # Skip content if corresponding feature is not supported
                if (
                    (
                        content_item.type == PromptMessageContentType.IMAGE
                        and ModelFeature.VISION not in model_schema.features
                    )
                    or (
                        content_item.type == PromptMessageContentType.DOCUMENT
                        and ModelFeature.DOCUMENT not in model_schema.features
                    )
                    or (
                        content_item.type == PromptMessageContentType.VIDEO
                        and ModelFeature.VIDEO not in model_schema.features
                    )
                    or (
                        content_item.type == PromptMessageContentType.AUDIO
                        and ModelFeature.AUDIO not in model_schema.features
                    )
                ):
                    continue
                prompt_message_content.append(content_item)
            # Collapse a single text item back to a plain string.
            if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                prompt_message.content = prompt_message_content[0].data
            else:
                prompt_message.content = prompt_message_content
        if prompt_message.is_empty():
            continue
        filtered_prompt_messages.append(prompt_message)

    if len(filtered_prompt_messages) == 0:
        raise NoPromptFoundError(
            "No prompt found in the LLM configuration. "
            "Please ensure a prompt is properly configured before proceeding."
        )

    return filtered_prompt_messages, stop
|
|
|
|
@classmethod
|
|
def _extract_variable_selector_to_variable_mapping(
|
|
cls,
|
|
*,
|
|
graph_config: Mapping[str, Any],
|
|
node_id: str,
|
|
node_data: LLMNodeData,
|
|
) -> Mapping[str, Sequence[str]]:
|
|
# graph_config is not used in this node type
|
|
_ = graph_config # Explicitly mark as unused
|
|
prompt_template = node_data.prompt_template
|
|
variable_selectors = []
|
|
prompt_context_selectors: list[Sequence[str]] = []
|
|
if isinstance(prompt_template, list):
|
|
for item in prompt_template:
|
|
# Check PromptMessageContext first (same order as _parse_prompt_template)
|
|
# This extracts value_selector which is used by variable_pool.get(ctx_ref.value_selector)
|
|
if isinstance(item, PromptMessageContext):
|
|
if len(item.value_selector) >= 2:
|
|
prompt_context_selectors.append(item.value_selector)
|
|
elif isinstance(item, LLMNodeChatModelMessage):
|
|
variable_template_parser = VariableTemplateParser(template=item.text)
|
|
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
|
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
|
if prompt_template.edition_type != "jinja2":
|
|
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
|
variable_selectors = variable_template_parser.extract_variable_selectors()
|
|
else:
|
|
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
|
|
|
|
variable_mapping: dict[str, Any] = {}
|
|
for variable_selector in variable_selectors:
|
|
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
|
|
|
for context_selector in prompt_context_selectors:
|
|
variable_key = f"#{'.'.join(context_selector)}#"
|
|
variable_mapping[variable_key] = list(context_selector)
|
|
|
|
memory = node_data.memory
|
|
if memory and memory.query_prompt_template:
|
|
query_variable_selectors = VariableTemplateParser(
|
|
template=memory.query_prompt_template
|
|
).extract_variable_selectors()
|
|
for variable_selector in query_variable_selectors:
|
|
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
|
|
|
if node_data.context.enabled:
|
|
variable_mapping["#context#"] = node_data.context.variable_selector
|
|
|
|
if node_data.vision.enabled:
|
|
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
|
|
|
if node_data.memory:
|
|
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
|
|
|
|
if node_data.prompt_config:
|
|
enable_jinja = False
|
|
|
|
if isinstance(prompt_template, list):
|
|
for item in prompt_template:
|
|
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
|
enable_jinja = True
|
|
break
|
|
else:
|
|
enable_jinja = True
|
|
|
|
if enable_jinja:
|
|
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
|
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
|
|
|
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
|
|
|
return variable_mapping
|
|
|
|
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
    """Return the default node configuration for both model modes."""
    _ = filters  # accepted for interface compatibility; not used
    chat_model = {
        "prompts": [
            {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
        ]
    }
    completion_prompt = {
        "text": "Here are the chat histories between human and assistant, inside "
        "<histories></histories> XML tags.\n\n<histories>\n{{"
        "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
        "edition_type": "basic",
    }
    completion_model = {
        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
        "prompt": completion_prompt,
        "stop": ["Human:"],
    }
    return {
        "type": "llm",
        "config": {
            "prompt_templates": {
                "chat_model": chat_model,
                "completion_model": completion_model,
            }
        },
    }
|
|
|
|
@staticmethod
def handle_list_messages(
    *,
    messages: Sequence[LLMNodeChatModelMessage],
    context: str | None,
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
    vision_detail_config: ImagePromptMessageContent.DETAIL,
    sandbox: Sandbox | None = None,
) -> Sequence[PromptMessage]:
    """Render chat template messages into concrete prompt messages.

    jinja2-mode messages are rendered from their template; basic-mode
    messages get ``{#context#}`` substitution plus variable-pool template
    conversion, and any File variables they reference are emitted as a
    separate message carrying the file contents. When *sandbox* carries a
    skill bundle, rendered text is additionally assembled through the
    skill document pipeline.
    """
    prompt_messages: list[PromptMessage] = []

    # Skill bundle (if any) attached to the sandbox drives extra assembly.
    bundle: SkillBundle | None = None
    if sandbox:
        bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)

    for message in messages:
        if message.edition_type == "jinja2":
            result_text = _render_jinja2_message(
                template=message.jinja2_text or "",
                jinja2_variables=jinja2_variables,
                variable_pool=variable_pool,
            )

            if bundle is not None:
                skill_entry = SkillDocumentAssembler(bundle).assemble_document(
                    document=SkillDocument(
                        skill_id="anonymous", content=result_text, metadata=message.metadata or {}
                    ),
                    base_path=AppAssets.PATH,
                )
                result_text = skill_entry.content

            prompt_message = _combine_message_content_with_role(
                contents=[TextPromptMessageContent(data=result_text)], role=message.role
            )
            prompt_messages.append(prompt_message)
        else:
            # Basic mode: inline context substitution, then template render.
            if context:
                template = message.text.replace("{#context#}", context)
            else:
                template = message.text
            segment_group = variable_pool.convert_template(template)

            # Collect multimodal file contents referenced by the template
            # (only image/video/audio/document files are forwarded).
            file_contents = []
            for segment in segment_group.value:
                if isinstance(segment, ArrayFileSegment):
                    for file in segment.value:
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=vision_detail_config
                            )
                            file_contents.append(file_content)
                elif isinstance(segment, FileSegment):
                    file = segment.value
                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                        file_content = file_manager.to_prompt_message_content(
                            file, image_detail_config=vision_detail_config
                        )
                        file_contents.append(file_content)

            plain_text = segment_group.text

            if plain_text and bundle is not None:
                skill_entry = SkillDocumentAssembler(bundle).assemble_document(
                    document=SkillDocument(
                        skill_id="anonymous", content=plain_text, metadata=message.metadata or {}
                    ),
                    base_path=AppAssets.PATH,
                )
                plain_text = skill_entry.content

            if plain_text:
                prompt_message = _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=plain_text)], role=message.role
                )
                prompt_messages.append(prompt_message)

            if file_contents:
                # Create message with image contents
                prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
                prompt_messages.append(prompt_message)

    return prompt_messages
|
|
|
|
@staticmethod
def handle_blocking_result(
    *,
    invoke_result: LLMResult | LLMResultWithStructuredOutput,
    saver: LLMFileSaver,
    file_outputs: list[File],
    reasoning_format: Literal["separated", "tagged"] = "tagged",
    request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
    """Convert a blocking LLM invoke result into a completion event.

    Multimodal contents are persisted via *saver* (appending to
    *file_outputs*) while being rendered to markdown text; reasoning is
    then split out according to *reasoning_format*.
    """
    parts: list[str] = []
    for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
        contents=invoke_result.message.content,
        file_saver=saver,
        file_outputs=file_outputs,
    ):
        parts.append(text_part)

    # Full rendered text, possibly containing <think> blocks.
    full_text = "".join(parts)

    if reasoning_format == "tagged":
        # Backward compatible: leave <think> tags embedded in the text.
        clean_text, reasoning_content = full_text, ""
    else:
        # Separated mode: strip <think> tags and expose the reasoning.
        clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)

    event = ModelInvokeCompletedEvent(
        # clean_text for separated mode, full_text for tagged mode.
        text=clean_text if reasoning_format == "separated" else full_text,
        usage=invoke_result.usage,
        finish_reason=None,
        # Reasoning content for workflow variables and downstream nodes.
        reasoning_content=reasoning_content,
        # Structured output is only present on LLMResultWithStructuredOutput.
        structured_output=getattr(invoke_result, "structured_output", None),
    )
    if request_latency is not None:
        event.usage.latency = round(request_latency, 3)
    return event
|
|
|
|
@staticmethod
def save_multimodal_image_output(
    *,
    content: ImagePromptMessageContent,
    file_saver: LLMFileSaver,
) -> File:
    """Persist one LLM-generated image and return the stored File.

    Two delivery forms are handled:

    - a remote URL reference, which is downloaded and then stored;
    - inline base64 data, which is decoded and stored directly.

    Currently, only image files are supported.
    """
    if content.url != "":
        # Remote reference: let the saver download and store it.
        return file_saver.save_remote_url(content.url, FileType.IMAGE)

    # Inline payload: decode the base64 body and store the raw bytes.
    return file_saver.save_binary_string(
        data=base64.b64decode(content.base64_data),
        mime_type=content.mime_type,
        file_type=FileType.IMAGE,
    )
|
|
|
|
@staticmethod
|
|
def _normalize_sandbox_file_path(path: str) -> str:
|
|
raw = path.strip()
|
|
if not raw:
|
|
raise LLMNodeError("Sandbox file path must not be empty")
|
|
|
|
sandbox_path = PurePosixPath(raw)
|
|
if any(part == ".." for part in sandbox_path.parts):
|
|
raise LLMNodeError("Sandbox file path must not contain '..'")
|
|
|
|
normalized = str(sandbox_path)
|
|
if normalized in {".", ""}:
|
|
raise LLMNodeError("Sandbox file path is invalid")
|
|
|
|
return normalized
|
|
|
|
def _resolve_sandbox_file_path(self, *, sandbox: Sandbox, path: str) -> File:
    """Download a file from the sandbox VM and register it as a tool file.

    Normalizes and validates *path*, downloads the bytes from the sandbox,
    enforces the output size limit, persists the content through
    ToolFileManager, and returns a File pointing at the stored copy.

    Raises:
        LLMNodeError: on invalid paths, missing files, or oversized files.
    """
    normalized_path = self._normalize_sandbox_file_path(path)
    filename = os.path.basename(normalized_path)
    if not filename:
        # Path like "dir/" resolves to an empty basename.
        raise LLMNodeError("Sandbox file path must point to a file")

    try:
        file_content = sandbox.vm.download_file(normalized_path)
    except Exception as exc:
        # Any VM-side failure is surfaced as "not found" with the original cause chained.
        raise LLMNodeError(f"Sandbox file not found: {normalized_path}") from exc

    file_binary = file_content.getvalue()
    if len(file_binary) > MAX_OUTPUT_FILE_SIZE:
        raise LLMNodeError(f"Sandbox file exceeds size limit: {normalized_path}")

    # Guess the MIME type from the filename; fall back to a generic binary type.
    mime_type, _ = mimetypes.guess_type(filename)
    if not mime_type:
        mime_type = "application/octet-stream"

    tool_file_manager = ToolFileManager()
    tool_file = tool_file_manager.create_file_by_raw(
        user_id=self.user_id,
        tenant_id=self.tenant_id,
        conversation_id=None,
        file_binary=file_binary,
        mimetype=mime_type,
        filename=filename,
    )

    extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
    # Signed URL so the frontend can fetch the stored tool file.
    url = sign_tool_file(tool_file.id, extension)
    file_type = self._get_file_type_from_mime(mime_type)

    return File(
        id=tool_file.id,
        tenant_id=self.tenant_id,
        type=file_type,
        transfer_method=FileTransferMethod.TOOL_FILE,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=len(file_binary),
        related_id=tool_file.id,
        url=url,
        storage_key=tool_file.file_key,
    )
|
|
|
|
@staticmethod
def _get_file_type_from_mime(mime_type: str) -> FileType:
    """Map a MIME type string onto the FileType category used by file outputs."""
    # Ordered prefix checks mirror the media categories exactly.
    prefix_map = (
        ("image/", FileType.IMAGE),
        ("video/", FileType.VIDEO),
        ("audio/", FileType.AUDIO),
    )
    for prefix, file_type in prefix_map:
        if mime_type.startswith(prefix):
            return file_type
    # Textual and PDF payloads count as documents; everything else is custom.
    if "text" in mime_type or "pdf" in mime_type:
        return FileType.DOCUMENT
    return FileType.CUSTOM
|
|
|
|
@staticmethod
|
|
def fetch_structured_output_schema(
|
|
*,
|
|
structured_output: Mapping[str, Any],
|
|
) -> dict[str, Any]:
|
|
"""
|
|
Fetch the structured output schema from the node data.
|
|
|
|
Returns:
|
|
dict[str, Any]: The structured output schema
|
|
"""
|
|
if not structured_output:
|
|
raise LLMNodeError("Please provide a valid structured output schema")
|
|
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
|
|
if not structured_output_schema:
|
|
raise LLMNodeError("Please provide a valid structured output schema")
|
|
|
|
try:
|
|
schema = json.loads(structured_output_schema)
|
|
if not isinstance(schema, dict):
|
|
raise LLMNodeError("structured_output_schema must be a JSON object")
|
|
return schema
|
|
except json.JSONDecodeError:
|
|
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
|
|
|
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
    *,
    contents: str | list[PromptMessageContentUnionTypes] | None,
    file_saver: LLMFileSaver,
    file_outputs: list[File],
) -> Generator[str, None, None]:
    """Convert intermediate prompt messages into strings and yield them to the caller.

    If the messages contain non-textual content (e.g., multimedia like images or videos),
    it will be saved separately, and the corresponding Markdown representation will
    be yielded to the caller. Saved files are appended to *file_outputs* as a
    side effect so the node can expose them later.
    """

    # NOTE(QuantumGhost): This function should yield results to the caller immediately
    # whenever new content or partial content is available. Avoid any intermediate buffering
    # of results. Additionally, do not yield empty strings; instead, yield from an empty list
    # if necessary.
    if contents is None:
        yield from []
        return
    if isinstance(contents, str):
        yield contents
    else:
        for item in contents:
            if isinstance(item, TextPromptMessageContent):
                yield item.data
            elif isinstance(item, ImagePromptMessageContent):
                # Persist the image, record it, and yield its Markdown reference.
                file = LLMNode.save_multimodal_image_output(
                    content=item,
                    file_saver=file_saver,
                )
                file_outputs.append(file)
                yield LLMNode._image_file_to_markdown(file)
            else:
                # Unknown multimodal payloads degrade to their string form.
                logger.warning("unknown item type encountered, type=%s", type(item))
                yield str(item)
|
|
|
|
@property
def retry(self) -> bool:
    """Whether retry-on-failure is enabled for this node's configuration."""
    retry_config = self.node_data.retry_config
    return retry_config.retry_enabled
|
|
|
|
@property
def tool_call_enabled(self) -> bool:
    """True when at least one tool is configured and every configured tool is enabled."""
    tools = self.node_data.tools
    # No tools (None or empty list) means tool calling is off.
    if not tools:
        return False
    return all(tool.enabled for tool in tools)
|
|
|
|
def _stream_llm_events(
    self,
    generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None],
    *,
    model_instance: ModelInstance,
) -> Generator[
    NodeEventBase,
    None,
    tuple[
        str,  # clean_text: processed text for outputs["text"]
        str,  # reasoning_content: native model reasoning
        str,  # generation_reasoning_content: reasoning for generation field (from <think> tags)
        str,  # generation_clean_content: clean text for generation field (always tag-free)
        LLMUsage,
        str | None,
        LLMStructuredOutput | None,
        LLMGenerationData | None,
    ],
]:
    """Stream events and capture generator return value in one place.

    Uses generator delegation so _run stays concise while still emitting events.

    Returns two pairs of text fields because outputs["text"] and generation["content"]
    may differ when reasoning_format is "tagged":
    - clean_text / reasoning_content: for top-level outputs (may keep <think> tags)
    - generation_clean_content / generation_reasoning_content: for the generation field
      (always tag-free, extracted via _split_reasoning with "separated" mode)
    """
    clean_text = ""
    reasoning_content = ""
    generation_reasoning_content = ""
    generation_clean_content = ""
    usage = LLMUsage.empty_usage()
    finish_reason: str | None = None
    structured_output: LLMStructuredOutput | None = None
    generation_data: LLMGenerationData | None = None
    # Set once ModelInvokeCompletedEvent is seen; later events are drained only.
    completed = False

    while True:
        try:
            event = next(generator)
        except StopIteration as exc:
            # The inner generator's return value (if any) rides on StopIteration.value.
            if isinstance(exc.value, LLMGenerationData):
                generation_data = exc.value
            break

        if completed:
            # After completion we still drain to reach StopIteration.value
            continue

        match event:
            case StreamChunkEvent() | ThoughtChunkEvent():
                # Pass-through streaming events untouched.
                yield event

            case ModelInvokeCompletedEvent(
                text=text,
                usage=usage_event,
                finish_reason=finish_reason_event,
                reasoning_content=reasoning_event,
                structured_output=structured_raw,
            ):
                clean_text = text
                usage = usage_event
                finish_reason = finish_reason_event
                reasoning_content = reasoning_event or ""
                generation_reasoning_content = reasoning_content
                generation_clean_content = clean_text

                if self.node_data.reasoning_format == "tagged":
                    # Keep tagged text for output; also extract reasoning for generation field
                    generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning(
                        clean_text, reasoning_format="separated"
                    )
                else:
                    clean_text, generation_reasoning_content = LLMNode._split_reasoning(
                        clean_text, self.node_data.reasoning_format
                    )
                    generation_clean_content = clean_text

                structured_output = (
                    LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None
                )

                # Imported lazily to avoid a circular import at module load time.
                from core.app.llm.quota import deduct_llm_quota

                deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                completed = True

            case LLMStructuredOutput():
                structured_output = event

            case _:
                continue

    return (
        clean_text,
        reasoning_content,
        generation_reasoning_content,
        generation_clean_content,
        usage,
        finish_reason,
        structured_output,
        generation_data,
    )
|
|
|
|
def _extract_disabled_tools(self) -> dict[str, ToolDependency]:
    """Collect the node's tool settings that are switched off, keyed by tool id."""
    disabled: dict[str, ToolDependency] = {}
    for setting in self.node_data.tool_settings:
        if setting.enabled:
            continue
        dependency = ToolDependency(type=setting.type, provider=setting.provider, tool_name=setting.tool_name)
        disabled[dependency.tool_id()] = dependency
    return disabled
|
|
|
|
def _extract_tool_dependencies(self) -> ToolDependencies | None:
    """Extract merged tool dependencies referenced by skill documents in the prompt template.

    Each chat message is assembled as a skill document against the sandbox
    bundle; all documents' tool dependencies are merged, and tools the node's
    tool_settings explicitly disabled are marked enabled=False.

    Returns:
        The merged ToolDependencies, or None when no chat messages were found.

    Raises:
        LLMNodeError: if the graph runtime has no sandbox attached.
    """

    sandbox = self.graph_runtime_state.sandbox
    if not sandbox:
        raise LLMNodeError("Sandbox not found")

    bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
    tool_deps_list: list[ToolDependencies] = []
    for prompt in self.node_data.prompt_template:
        if isinstance(prompt, LLMNodeChatModelMessage):
            skill_entry = SkillDocumentAssembler(bundle).assemble_document(
                document=SkillDocument(skill_id="anonymous", content=prompt.text, metadata=prompt.metadata or {}),
                base_path=AppAssets.PATH,
            )
            tool_deps_list.append(skill_entry.tools)

    if len(tool_deps_list) == 0:
        return None

    disabled_tools = self._extract_disabled_tools()
    # Fold every document's dependencies into one merged set.
    tool_dependencies = reduce(lambda x, y: x.merge(y), tool_deps_list)
    for tool in tool_dependencies.dependencies:
        if tool.tool_id() in disabled_tools:
            tool.enabled = False
    return tool_dependencies
|
|
|
|
def _invoke_llm_with_tools(
    self,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None,
    files: Sequence[File],
    variable_pool: VariablePool,
    node_inputs: dict[str, Any],
    process_data: dict[str, Any],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Invoke LLM with tools support (from Agent V2).

    Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files
    """
    # NOTE(review): files / node_inputs / process_data are accepted but not
    # read in this body — presumably kept for signature parity; confirm.

    # Get model features to determine strategy
    model_features = self._get_model_features(model_instance)

    # Prepare tool instances
    tool_instances = self._prepare_tool_instances(variable_pool)

    # Prepare prompt files (files that come from prompt variables, not vision files)
    prompt_files = self._extract_prompt_files(variable_pool)

    # Use factory to create appropriate strategy
    strategy = StrategyFactory.create_strategy(
        model_features=model_features,
        model_instance=model_instance,
        tools=tool_instances,
        files=prompt_files,
        max_iterations=self._node_data.max_iterations or 10,
        context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
    )

    # Run strategy
    outputs = strategy.run(
        prompt_messages=list(prompt_messages),
        model_parameters=self._node_data.model.completion_params,
        stop=list(stop or []),
        stream=True,
    )

    # Delegate streaming; the handler's return value is the aggregated generation data.
    result = yield from self._process_tool_outputs(outputs)
    return result
|
|
|
|
def _invoke_llm_with_sandbox(
    self,
    sandbox: Sandbox,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None,
    variable_pool: VariablePool,
    tool_dependencies: ToolDependencies | None,
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Invoke the LLM inside a sandbox bash session, exposing only the bash tool.

    Runs a function-calling agent strategy whose single tool is the session's
    bash tool, then auto-collects files the sandbox wrote to its output
    directory into self._file_outputs (deduplicated by file id).

    Raises:
        LLMNodeError: if the session exits without producing a result.
    """
    result: LLMGenerationData | None = None

    # FIXME(Mairuis): Async processing for bash session.
    with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
        prompt_files = self._extract_prompt_files(variable_pool)
        model_features = self._get_model_features(model_instance)

        strategy = StrategyFactory.create_strategy(
            model_features=model_features,
            model_instance=model_instance,
            tools=[session.bash_tool],
            files=prompt_files,
            max_iterations=self._node_data.max_iterations or 100,
            agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
            context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
        )

        outputs = strategy.run(
            prompt_messages=list(prompt_messages),
            model_parameters=self._node_data.model.completion_params,
            stop=list(stop or []),
            stream=True,
        )

        result = yield from self._process_tool_outputs(outputs)

        # Auto-collect sandbox output/ files, deduplicate by id
        collected_files = session.collect_output_files()
        if collected_files:
            existing_ids = {f.id for f in self._file_outputs}
            self._file_outputs.extend(f for f in collected_files if f.id not in existing_ids)

    if result is None:
        raise LLMNodeError("SandboxSession exited unexpectedly")

    return result
|
|
|
|
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
|
|
"""Get model schema to determine features."""
|
|
try:
|
|
model_type_instance = model_instance.model_type_instance
|
|
model_schema = model_type_instance.get_model_schema(
|
|
model_instance.model,
|
|
model_instance.credentials,
|
|
)
|
|
return model_schema.features if model_schema and model_schema.features else []
|
|
except Exception:
|
|
logger.warning("Failed to get model schema, assuming no special features")
|
|
return []
|
|
|
|
def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]:
    """Prepare tool instances from configuration.

    Converts each configured tool into an AgentToolEntity, resolves a runtime
    through ToolManager, and applies any custom description override. Tools
    that fail to load are skipped with a warning rather than failing the node.
    """
    tool_instances = []

    if self._node_data.tools:
        for tool in self._node_data.tools:
            try:
                # Process settings to extract the correct structure
                processed_settings = {}
                for key, value in tool.settings.items():
                    if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict):
                        # Extract the nested value if it has the ToolInput structure
                        if "type" in value["value"] and "value" in value["value"]:
                            processed_settings[key] = value["value"]
                        else:
                            processed_settings[key] = value
                    else:
                        processed_settings[key] = value

                # Merge parameters with processed settings (similar to Agent Node logic)
                # Settings take precedence over parameters on key collisions.
                merged_parameters = {**tool.parameters, **processed_settings}

                # Create AgentToolEntity from ToolMetadata
                agent_tool = AgentToolEntity(
                    provider_id=tool.provider_name,
                    provider_type=tool.type,
                    tool_name=tool.tool_name,
                    tool_parameters=merged_parameters,
                    plugin_unique_identifier=tool.plugin_unique_identifier,
                    credential_id=tool.credential_id,
                )

                # Get tool runtime from ToolManager
                tool_runtime = ToolManager.get_agent_tool_runtime(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    agent_tool=agent_tool,
                    invoke_from=self.invoke_from,
                    variable_pool=variable_pool,
                )

                # Apply custom description from extra field if available
                if tool.extra.get("description") and tool_runtime.entity.description:
                    tool_runtime.entity.description.llm = (
                        tool.extra.get("description") or tool_runtime.entity.description.llm
                    )

                tool_instances.append(tool_runtime)
            except Exception as e:
                # Best-effort: a single broken tool must not abort the node.
                logger.warning("Failed to load tool %s: %s", tool, str(e))
                continue

    return tool_instances
|
|
|
|
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
    """Collect File values referenced by template variables inside the prompt messages."""
    from dify_graph.variables import ArrayFileVariable, FileVariable

    collected: list[File] = []
    template = self._node_data.prompt_template
    # Only chat-style templates (a list of messages) carry per-message text.
    if not isinstance(template, list):
        return collected

    for message in template:
        if not message.text:
            continue
        selectors = VariableTemplateParser(message.text).extract_variable_selectors()
        for selector in selectors:
            variable = variable_pool.get(selector.value_selector)
            if isinstance(variable, FileVariable) and variable.value:
                collected.append(variable.value)
            elif isinstance(variable, ArrayFileVariable) and variable.value:
                collected.extend(variable.value)

    return collected
|
|
|
|
@staticmethod
|
|
def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]:
|
|
"""Convert ToolCallResult into JSON-friendly dict."""
|
|
|
|
def _file_to_ref(file: File) -> str | None:
|
|
# Align with streamed tool result events which carry file IDs
|
|
return file.id or file.related_id
|
|
|
|
files = []
|
|
for file in tool_call.files or []:
|
|
ref = _file_to_ref(file)
|
|
if ref:
|
|
files.append(ref)
|
|
|
|
return {
|
|
"id": tool_call.id,
|
|
"name": tool_call.name,
|
|
"arguments": tool_call.arguments,
|
|
"output": tool_call.output,
|
|
"files": files,
|
|
"status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status,
|
|
"elapsed_time": tool_call.elapsed_time,
|
|
}
|
|
|
|
def _generate_model_provider_icon_url(self, provider: str, dark: bool = False) -> str | None:
    """Build the console API URL for a model provider's small icon; None on failure."""
    from yarl import URL

    from configs import dify_config

    icon_type = "icon_small_dark" if dark else "icon_small"
    try:
        url = URL(dify_config.CONSOLE_API_URL or "/")
        # Append each path segment; yarl handles encoding/joining.
        for segment in (
            "console",
            "api",
            "workspaces",
            self.tenant_id,
            "model-providers",
            provider,
            icon_type,
            "en_US",
        ):
            url = url / segment
        return str(url)
    except Exception:
        # Icon URLs are cosmetic; never let them break the node.
        return None
|
|
|
|
def _emit_model_start(self, trace_state: TraceState) -> Generator[NodeEventBase, None, None]:
    """Yield a MODEL_START event with model identity info at the beginning of a model turn.

    Idempotent: only emits once per turn (guarded by trace_state.model_start_emitted).
    Also starts the model segment timer if it is not already running.
    """
    if trace_state.model_start_emitted:
        return
    trace_state.model_start_emitted = True
    if trace_state.model_segment_start_time is None:
        trace_state.model_segment_start_time = time.perf_counter()
    provider = self._node_data.model.provider
    yield StreamChunkEvent(
        selector=[self._node_id, "generation", "model_start"],
        chunk="",
        chunk_type=ChunkType.MODEL_START,
        is_final=False,
        model_provider=provider,
        model_name=self._node_data.model.name,
        model_icon=self._generate_model_provider_icon_url(provider),
        model_icon_dark=self._generate_model_provider_icon_url(provider, dark=True),
    )
|
|
|
|
def _flush_model_segment(
    self,
    buffers: StreamBuffers,
    trace_state: TraceState,
    error: str | None = None,
) -> Generator[NodeEventBase, None, None]:
    """Flush pending thought/content buffers into a single model trace segment
    and yield a MODEL_END chunk event with usage/duration metrics.

    No-op when nothing is buffered. Resets the per-turn buffers and timer state
    afterwards so the next model turn starts clean.
    """
    if not buffers.pending_thought and not buffers.pending_content and not buffers.pending_tool_calls:
        return

    now = time.perf_counter()
    duration = now - trace_state.model_segment_start_time if trace_state.model_segment_start_time else 0.0

    usage = trace_state.pending_usage

    provider = self._node_data.model.provider
    model_name = self._node_data.model.name
    model_icon = self._generate_model_provider_icon_url(provider)
    model_icon_dark = self._generate_model_provider_icon_url(provider, dark=True)

    trace_state.trace_segments.append(
        LLMTraceSegment(
            type="model",
            duration=duration,
            usage=usage,
            output=ModelTraceSegment(
                text="".join(buffers.pending_content) if buffers.pending_content else None,
                reasoning="".join(buffers.pending_thought) if buffers.pending_thought else None,
                tool_calls=list(buffers.pending_tool_calls),
            ),
            provider=provider,
            name=model_name,
            icon=model_icon,
            icon_dark=model_icon_dark,
            error=error,
            status="error" if error else "success",
        )
    )

    yield StreamChunkEvent(
        selector=[self._node_id, "generation", "model_end"],
        chunk="",
        chunk_type=ChunkType.MODEL_END,
        is_final=False,
        model_usage=usage,
        model_duration=duration,
    )

    # Reset buffered state; the next turn re-arms the timer and start marker.
    buffers.pending_thought.clear()
    buffers.pending_content.clear()
    buffers.pending_tool_calls.clear()
    trace_state.model_segment_start_time = None
    trace_state.model_start_emitted = False
    trace_state.pending_usage = None
|
|
def _handle_agent_log_output(
    self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
) -> Generator[NodeEventBase, None, None]:
    """Translate an AgentLog from the strategy into node stream events.

    Maintains agent_context.agent_logs (upsert by message id), captures LLM
    usage from completed THOUGHT logs, and for TOOL_CALL logs emits tool-call
    (START) / tool-result (non-START) chunk events while managing model trace
    segment boundaries.
    """
    payload = ToolLogPayload.from_log(output)
    agent_log_event = AgentLogEvent(
        message_id=output.id,
        label=output.label,
        node_execution_id=self.id,
        parent_id=output.parent_id,
        error=output.error,
        status=output.status.value,
        data=output.data,
        metadata={k.value: v for k, v in output.metadata.items()},
        node_id=self._node_id,
    )
    # Upsert into the running log list: update an existing entry in place,
    # otherwise append (for/else hits when no match was found).
    for log in agent_context.agent_logs:
        if log.message_id == agent_log_event.message_id:
            log.data = agent_log_event.data
            log.status = agent_log_event.status
            log.error = agent_log_event.error
            log.label = agent_log_event.label
            log.metadata = agent_log_event.metadata
            break
    else:
        agent_context.agent_logs.append(agent_log_event)

    # Handle THOUGHT log completion - capture usage for model segment
    if output.log_type == AgentLog.LogType.THOUGHT and output.status == AgentLog.LogStatus.SUCCESS:
        llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) if output.metadata else None
        if llm_usage:
            trace_state.pending_usage = llm_usage

    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
        # A tool call implies a model turn is in progress; ensure MODEL_START was sent.
        yield from self._emit_model_start(trace_state)

        tool_name = payload.tool_name
        tool_call_id = payload.tool_call_id
        tool_arguments = json.dumps(payload.tool_args or {})

        tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None
        tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None

        # Remember first-seen order of tool calls for later sorting.
        if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)

        buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))

        yield ToolCallChunkEvent(
            selector=[self._node_id, "generation", "tool_calls"],
            chunk=tool_arguments,
            tool_call=ToolCall(
                id=tool_call_id,
                name=tool_name,
                arguments=tool_arguments,
                icon=tool_icon,
                icon_dark=tool_icon_dark,
            ),
            is_final=False,
        )

    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
        tool_name = payload.tool_name
        tool_output = payload.tool_output
        tool_call_id = payload.tool_call_id
        tool_files = payload.files if isinstance(payload.files, list) else []
        tool_error = payload.tool_error
        tool_arguments = json.dumps(payload.tool_args or {})

        if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)

        # Flush model segment before tool result processing
        yield from self._flush_model_segment(buffers, trace_state)

        # Resolve the effective error: explicit log error first, then payload
        # error, then an "error" entry carried in the payload meta.
        if output.status == AgentLog.LogStatus.ERROR:
            tool_error = output.error or payload.tool_error
            if not tool_error and payload.meta:
                tool_error = payload.meta.get("error")
        else:
            if payload.meta:
                meta_error = payload.meta.get("error")
                if meta_error:
                    tool_error = meta_error

        elapsed_time = output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None
        tool_provider = output.metadata.get(AgentLog.LogMetadata.PROVIDER) if output.metadata else None
        tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None
        tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None
        result_str = str(tool_output) if tool_output is not None else None

        tool_status: Literal["success", "error"] = "error" if tool_error else "success"
        tool_call_segment = LLMTraceSegment(
            type="tool",
            duration=elapsed_time or 0.0,
            usage=None,
            output=ToolTraceSegment(
                id=tool_call_id,
                name=tool_name,
                arguments=tool_arguments,
                output=result_str,
            ),
            provider=tool_provider,
            name=tool_name,
            icon=tool_icon,
            icon_dark=tool_icon_dark,
            error=str(tool_error) if tool_error else None,
            status=tool_status,
        )
        trace_state.trace_segments.append(tool_call_segment)
        if tool_call_id:
            trace_state.tool_trace_map[tool_call_id] = tool_call_segment

        # Start new model segment tracking
        trace_state.model_segment_start_time = time.perf_counter()

        yield ToolResultChunkEvent(
            selector=[self._node_id, "generation", "tool_results"],
            chunk=result_str or "",
            tool_result=ToolResult(
                id=tool_call_id,
                name=tool_name,
                output=result_str,
                files=tool_files,
                status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
                elapsed_time=elapsed_time,
                icon=tool_icon,
                icon_dark=tool_icon_dark,
                provider=tool_provider,
            ),
            is_final=False,
        )

        # A completed tool call closes the current reasoning turn.
        if buffers.current_turn_reasoning:
            buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
            buffers.current_turn_reasoning.clear()
|
|
def _handle_llm_chunk_output(
    self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
    """Translate one streamed LLMResultChunk into node stream events.

    Runs the chunk text through the <think>-tag parser, routing thought
    segments to thought events/buffers and plain text to both the node text
    stream and the generation content stream. Also accumulates usage and the
    finish reason onto *aggregate*.
    """
    message = output.delta.message

    if message and message.content:
        chunk_text = message.content
        # Content may be a list of content parts; join their textual data.
        if isinstance(chunk_text, list):
            chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text)
        else:
            chunk_text = str(chunk_text)

        for kind, segment in buffers.think_parser.process(chunk_text):
            # Empty segments only matter for the thought start/end markers.
            if not segment and kind not in {"thought_start", "thought_end"}:
                continue

            yield from self._emit_model_start(trace_state)

            if kind == "thought_start":
                yield ThoughtStartChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk="",
                    is_final=False,
                )
            elif kind == "thought":
                buffers.current_turn_reasoning.append(segment)
                buffers.pending_thought.append(segment)
                yield ThoughtChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk=segment,
                    is_final=False,
                )
            elif kind == "thought_end":
                yield ThoughtEndChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk="",
                    is_final=False,
                )
            else:
                # Plain content: aggregate and fan out to both streams.
                aggregate.text += segment
                buffers.pending_content.append(segment)
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=segment,
                    is_final=False,
                )
                yield StreamChunkEvent(
                    selector=[self._node_id, "generation", "content"],
                    chunk=segment,
                    is_final=False,
                )

    if output.delta.usage:
        self._accumulate_usage(aggregate.usage, output.delta.usage)

    if output.delta.finish_reason:
        aggregate.finish_reason = output.delta.finish_reason
|
|
def _flush_remaining_stream(
    self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
    """Drain anything still buffered in the <think> parser and close the trace.

    Mirrors _handle_llm_chunk_output's event routing for the parser's flushed
    remainder, records the last reasoning turn, falls back to the aggregate
    usage when no per-turn usage was captured, and flushes the final model
    trace segment.
    """
    for kind, segment in buffers.think_parser.flush():
        if not segment and kind not in {"thought_start", "thought_end"}:
            continue

        yield from self._emit_model_start(trace_state)

        if kind == "thought_start":
            yield ThoughtStartChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        elif kind == "thought":
            buffers.current_turn_reasoning.append(segment)
            buffers.pending_thought.append(segment)
            yield ThoughtChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk=segment,
                is_final=False,
            )
        elif kind == "thought_end":
            yield ThoughtEndChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        else:
            aggregate.text += segment
            buffers.pending_content.append(segment)
            yield StreamChunkEvent(
                selector=[self._node_id, "text"],
                chunk=segment,
                is_final=False,
            )
            yield StreamChunkEvent(
                selector=[self._node_id, "generation", "content"],
                chunk=segment,
                is_final=False,
            )

    if buffers.current_turn_reasoning:
        buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))

    # For final flush, use aggregate.usage if pending_usage is not set
    # (e.g., for simple LLM calls without tool invocations)
    if trace_state.pending_usage is None:
        trace_state.pending_usage = aggregate.usage

    # Flush final model segment
    yield from self._flush_model_segment(buffers, trace_state)
|
|
|
|
def _close_streams(self) -> Generator[NodeEventBase, None, None]:
    """Emit a terminal (is_final) marker on every stream this node writes to."""

    def _final_chunk(*tail: str) -> StreamChunkEvent:
        # Empty final chunk on a plain stream selector.
        return StreamChunkEvent(selector=[self._node_id, *tail], chunk="", is_final=True)

    yield _final_chunk("text")
    yield _final_chunk("generation", "content")
    yield ThoughtChunkEvent(
        selector=[self._node_id, "generation", "thought"],
        chunk="",
        is_final=True,
    )
    yield ToolCallChunkEvent(
        selector=[self._node_id, "generation", "tool_calls"],
        chunk="",
        tool_call=ToolCall(id="", name="", arguments=""),
        is_final=True,
    )
    yield ToolResultChunkEvent(
        selector=[self._node_id, "generation", "tool_results"],
        chunk="",
        tool_result=ToolResult(
            id="",
            name="",
            output="",
            files=[],
            status=ToolResultStatus.SUCCESS,
        ),
        is_final=True,
    )
    yield _final_chunk("generation", "model_start")
    yield _final_chunk("generation", "model_end")
|
|
def _build_generation_data(
    self,
    trace_state: TraceState,
    agent_context: AgentContext,
    aggregate: AggregatedResult,
    buffers: StreamBuffers,
) -> LLMGenerationData:
    """Assemble the final LLMGenerationData from the accumulated streaming state.

    Builds a display *sequence* (reasoning/content/tool_call ordering) from the
    trace segments, materializes finished tool calls from the agent logs, and
    sorts them by first-seen streaming order.
    """
    sequence: list[dict[str, Any]] = []
    reasoning_index = 0
    # Running offset into the concatenated content text.
    content_position = 0
    tool_call_seen_index: dict[str, int] = {}
    for trace_segment in trace_state.trace_segments:
        if trace_segment.type == "thought":
            sequence.append({"type": "reasoning", "index": reasoning_index})
            reasoning_index += 1
        elif trace_segment.type == "content":
            segment_text = trace_segment.text or ""
            start = content_position
            end = start + len(segment_text)
            sequence.append({"type": "content", "start": start, "end": end})
            content_position = end
        elif trace_segment.type == "tool_call":
            tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
            if tool_id not in tool_call_seen_index:
                tool_call_seen_index[tool_id] = len(tool_call_seen_index)
            sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})

    tool_calls_for_generation: list[ToolCallResult] = []
    for log in agent_context.agent_logs:
        payload = ToolLogPayload.from_mapping(log.data or {})
        tool_call_id = payload.tool_call_id
        # START entries are in-flight calls; only finished calls are materialized.
        if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
            continue

        tool_args = payload.tool_args
        log_error = payload.tool_error
        log_output = payload.tool_output
        result_text = log_output or log_error or ""
        status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
        tool_calls_for_generation.append(
            ToolCallResult(
                id=tool_call_id,
                name=payload.tool_name,
                arguments=json.dumps(tool_args) if tool_args else "",
                output=result_text,
                status=status,
                elapsed_time=log.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if log.metadata else None,
            )
        )

    # Preserve streaming order; unknown ids sort to the end.
    tool_calls_for_generation.sort(
        key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
    )

    return LLMGenerationData(
        text=aggregate.text,
        reasoning_contents=buffers.reasoning_per_turn,
        tool_calls=tool_calls_for_generation,
        sequence=sequence,
        usage=aggregate.usage,
        finish_reason=aggregate.finish_reason,
        files=aggregate.files,
        trace=trace_state.trace_segments,
    )
|
|
|
|
def _process_tool_outputs(
|
|
self,
|
|
outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
|
|
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
|
"""Process strategy outputs and convert to node events."""
|
|
state = ToolOutputState()
|
|
|
|
try:
|
|
for output in outputs:
|
|
if isinstance(output, AgentLog):
|
|
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
|
|
else:
|
|
yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
|
|
except StopIteration as exception:
|
|
if isinstance(getattr(exception, "value", None), AgentResult):
|
|
state.agent.agent_result = exception.value
|
|
|
|
if state.agent.agent_result:
|
|
state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
|
|
state.aggregate.files = state.agent.agent_result.files
|
|
if state.agent.agent_result.usage:
|
|
state.aggregate.usage = state.agent.agent_result.usage
|
|
if state.agent.agent_result.finish_reason:
|
|
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
|
|
|
|
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
|
|
yield from self._close_streams()
|
|
|
|
return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
|
|
|
|
def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
|
|
"""Accumulate LLM usage statistics."""
|
|
total_usage.prompt_tokens += delta_usage.prompt_tokens
|
|
total_usage.completion_tokens += delta_usage.completion_tokens
|
|
total_usage.total_tokens += delta_usage.total_tokens
|
|
total_usage.prompt_price += delta_usage.prompt_price
|
|
total_usage.completion_price += delta_usage.completion_price
|
|
total_usage.total_price += delta_usage.total_price
|
|
|
|
    @property
    def model_instance(self) -> ModelInstance:
        """Read-only accessor for the node's resolved model instance."""
        return self._model_instance
|
|
|
|
|
|
def _combine_message_content_with_role(
    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
    """Wrap *contents* in the prompt-message class matching *role*.

    Supports USER, ASSISTANT and SYSTEM roles; any other role raises
    NotImplementedError.
    """
    if role == PromptMessageRole.USER:
        return UserPromptMessage(content=contents)
    if role == PromptMessageRole.ASSISTANT:
        return AssistantPromptMessage(content=contents)
    if role == PromptMessageRole.SYSTEM:
        return SystemPromptMessage(content=contents)
    raise NotImplementedError(f"Role {role} is not supported")
|
|
|
|
|
|
def _render_jinja2_message(
    *,
    template: str,
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
):
    """Render *template* as Jinja2 via the sandboxed code executor.

    Each selector in *jinja2_variables* is resolved against *variable_pool*;
    a selector that resolves to nothing is rendered as an empty string. An
    empty template short-circuits to "" without invoking the executor.
    """
    if not template:
        return ""

    template_inputs: dict[str, Any] = {}
    for selector in jinja2_variables:
        resolved = variable_pool.get(selector.value_selector)
        template_inputs[selector.variable] = resolved.to_object() if resolved else ""

    execution = CodeExecutor.execute_workflow_code_template(
        language=CodeLanguage.JINJA2,
        code=template,
        inputs=template_inputs,
    )
    return execution["result"]
|
|
|
|
|
|
def _calculate_rest_token(
    *,
    prompt_messages: list[PromptMessage],
    model_instance: ModelInstance,
) -> int:
    """Estimate how many tokens remain in the model's context window.

    Subtracts the configured ``max_tokens`` parameter and the tokens already
    consumed by *prompt_messages* from the model's context size, clamped to a
    minimum of 0. Falls back to 2000 when the model schema does not declare a
    context size.
    """
    schema = llm_utils.fetch_model_schema(model_instance=model_instance)
    configured_parameters = model_instance.parameters

    context_size = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if not context_size:
        return 2000

    consumed_tokens = model_instance.get_llm_num_tokens(prompt_messages)

    # Find the configured max_tokens value, matching either the rule's own
    # name or its template alias; the last matching rule wins.
    configured_max_tokens = 0
    for rule in schema.parameter_rules:
        if rule.name == "max_tokens" or rule.use_template == "max_tokens":
            configured_max_tokens = (
                configured_parameters.get(rule.name)
                or configured_parameters.get(str(rule.use_template))
                or 0
            )

    return max(context_size - configured_max_tokens - consumed_tokens, 0)
|
|
|
|
|
|
def _handle_memory_chat_mode(
    *,
    memory: BaseMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
    """Fetch conversation history as prompt messages for a chat model.

    Returns an empty sequence when either *memory* or *memory_config* is
    missing. Otherwise the history is trimmed to the remaining token budget
    and, when the memory window is enabled, to the configured window size.
    """
    if not (memory and memory_config):
        return []

    token_budget = _calculate_rest_token(
        prompt_messages=[],
        model_instance=model_instance,
    )
    window = memory_config.window
    return memory.get_history_prompt_messages(
        max_token_limit=token_budget,
        message_limit=window.size if window.enabled else None,
    )
|
|
|
|
|
|
def _handle_memory_completion_mode(
    *,
    memory: BaseMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
) -> str:
    """Fetch conversation history as plain text for a completion model.

    Returns "" when either *memory* or *memory_config* is missing.

    Raises:
        MemoryRolePrefixRequiredError: when the config lacks the role prefix
            that completion-mode history rendering requires.
    """
    if not (memory and memory_config):
        return ""

    # Keep the original evaluation order: token budget first, then the
    # role-prefix validation.
    token_budget = _calculate_rest_token(
        prompt_messages=[],
        model_instance=model_instance,
    )
    if not memory_config.role_prefix:
        raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")

    window = memory_config.window
    return llm_utils.fetch_memory_text(
        memory=memory,
        max_token_limit=token_budget,
        message_limit=window.size if window.enabled else None,
        human_prefix=memory_config.role_prefix.user,
        ai_prefix=memory_config.role_prefix.assistant,
    )
|
|
|
|
|
|
def _handle_completion_template(
    *,
    template: LLMNodeCompletionModelPromptTemplate,
    context: str | None,
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
    """Render a completion-model prompt template into a single user message.

    Args:
        template: The completion model prompt template.
        context: Optional context string substituted for the "{#context#}"
            placeholder (non-jinja2 templates only).
        jinja2_variables: Variables for jinja2 template rendering.
        variable_pool: Variable pool for template conversion.

    Returns:
        A one-element sequence containing the rendered USER prompt message.
    """
    if template.edition_type == "jinja2":
        rendered_text = _render_jinja2_message(
            template=template.jinja2_text or "",
            jinja2_variables=jinja2_variables,
            variable_pool=variable_pool,
        )
    else:
        raw_text = template.text.replace("{#context#}", context) if context else template.text
        rendered_text = variable_pool.convert_template(raw_text).text

    user_message = _combine_message_content_with_role(
        contents=[TextPromptMessageContent(data=rendered_text)], role=PromptMessageRole.USER
    )
    return [user_message]
|