from __future__ import annotations import base64 import io import json import logging import mimetypes import os import re import time from collections.abc import Generator, Mapping, Sequence from functools import reduce from pathlib import PurePosixPath from typing import TYPE_CHECKING, Any, Literal, cast from sqlalchemy import select from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.file_ref import ( adapt_schema_for_sandbox_file_paths, convert_sandbox_file_paths_in_output, detect_file_path_fields, ) from core.llm_generator.output_parser.structured_output import ( invoke_llm_with_structured_output, ) from core.model_manager import ModelInstance, ModelManager from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.tools.signature import sign_tool_file, sign_upload_file from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.tool_entities import ToolCallResult from dify_graph.enums import ( BuiltinNodeTypes, NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) from dify_graph.file import File, FileTransferMethod, FileType, file_manager from dify_graph.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentType, TextPromptMessageContent, ) from dify_graph.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, LLMStructuredOutput, LLMUsage, ) from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessageContentUnionTypes, PromptMessageRole, SystemPromptMessage, UserPromptMessage, ) from dify_graph.model_runtime.entities.model_entities import ( ModelFeature, ModelPropertyKey, ModelType, ) from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import ( AgentLogEvent, ModelInvokeCompletedEvent, NodeEventBase, NodeRunResult, RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, ThoughtChunkEvent, ToolCallChunkEvent, ToolResultChunkEvent, ) from dify_graph.node_events.node import ChunkType, ThoughtEndChunkEvent, ThoughtStartChunkEvent from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.runtime import VariablePool from dify_graph.variables import ( ArrayFileSegment, ArrayPromptMessageSegment, ArraySegment, FileSegment, NoneSegment, ObjectSegment, StringSegment, ) from extensions.ext_database import db from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from . import llm_utils from .entities import ( AgentContext, AggregatedResult, LLMGenerationData, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, LLMNodeData, LLMTraceSegment, ModelConfig, ModelTraceSegment, PromptMessageContext, StreamBuffers, ThinkTagStreamParser, ToolLogPayload, ToolOutputState, ToolTraceSegment, TraceState, ) from .exc import ( InvalidContextStructureError, InvalidVariableTypeError, LLMNodeError, MemoryRolePrefixRequiredError, ModelNotExistError, NoPromptFoundError, TemplateTypeNotSupportError, VariableNotFoundError, ) from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.agent.entities import AgentLog, AgentResult from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.base import BaseMemory from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.sandbox import Sandbox from core.skill.entities.skill_bundle import SkillBundle from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency from core.tools.__base.tool import Tool from dify_graph.file.models import File from dify_graph.runtime import GraphRuntimeState logger = logging.getLogger(__name__) class LLMNode(Node[LLMNodeData]): node_type = BuiltinNodeTypes.LLM # Compiled regex for extracting blocks (with compatibility for attributes) _THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) _llm_file_saver: LLMFileSaver def __init__( self, id: str, config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, http_client: HttpClientProtocol, credentials_provider: object | None = None, model_factory: object | None = None, model_instance: object | None = None, template_renderer: object | None = None, memory: object | None = None, llm_file_saver: LLMFileSaver | None = None, ): super().__init__( id=id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) self._file_outputs: list[File] = [] if llm_file_saver is None: dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, http_client=http_client, ) self._llm_file_saver = llm_file_saver @classmethod def version(cls) -> str: return "1" def _run(self) -> Generator: from core.sandbox.bash.session import MAX_OUTPUT_FILES node_inputs: dict[str, Any] = {} process_data: dict[str, Any] = {} clean_text = "" usage = LLMUsage.empty_usage() finish_reason = None reasoning_content = "" # Initialize as empty string for consistency clean_text = "" # Initialize clean_text to avoid UnboundLocalError variable_pool = self.graph_runtime_state.variable_pool try: # Parse prompt template to separate static messages and context references prompt_template = self.node_data.prompt_template static_messages, context_refs, template_order = self._parse_prompt_template() # fetch variables and fetch values from variable pool inputs = self._fetch_inputs(node_data=self.node_data) # fetch jinja2 inputs jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) # merge inputs inputs.update(jinja_inputs) # fetch files files = ( llm_utils.fetch_files( variable_pool=variable_pool, selector=self.node_data.vision.configs.variable_selector, ) if self.node_data.vision.enabled else [] ) if files: node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value generator = self._fetch_context(node_data=self.node_data) context = None context_files: list[File] = [] for event in generator: context = event.context context_files = event.context_files or [] yield event if context: node_inputs["#context#"] = context if context_files: node_inputs["#context_files#"] = [file.model_dump() for file in context_files] # fetch model config model_instance, model_config = LLMNode._fetch_model_config( node_data_model=self.node_data.model, tenant_id=self.tenant_id, ) resolved_completion_params = llm_utils.resolve_completion_params_variables( model_config.parameters, variable_pool, ) model_instance.parameters = resolved_completion_params model_config.parameters = resolved_completion_params self.node_data.model.completion_params = resolved_completion_params # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, app_id=self.app_id, tenant_id=self.tenant_id, node_data_memory=self.node_data.memory, model_instance=model_instance, node_id=self._node_id, ) query: str | None = None if self.node_data.memory: query = self.node_data.memory.query_prompt_template if not query and ( query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): query = query_variable.text # Get prompt messages prompt_messages: Sequence[PromptMessage] stop: Sequence[str] | None if isinstance(prompt_template, list) and context_refs: prompt_messages, stop = self._build_prompt_messages_with_context( context_refs=context_refs, template_order=template_order, static_messages=static_messages, query=query, files=files, context=context, memory=memory, model_config=model_config, context_files=context_files, ) else: prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, context=context, memory=memory, model_config=model_config, prompt_template=cast( Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, self.node_data.prompt_template, ), memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, context_files=context_files, sandbox=self.graph_runtime_state.sandbox, ) # Variables for outputs generation_data: LLMGenerationData | None = None structured_output: LLMStructuredOutput | None = None structured_output_schema: Mapping[str, Any] | None structured_output_file_paths: list[str] = [] if self.node_data.structured_output_enabled: if not self.node_data.structured_output: raise LLMNodeError("structured_output_enabled is True but structured_output is not set") raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output) if self.node_data.computer_use: raise LLMNodeError("Structured output is not supported in computer use mode.") else: if detect_file_path_fields(raw_schema): sandbox = self.graph_runtime_state.sandbox if not sandbox: raise LLMNodeError("Structured output file paths are only supported in sandbox mode.") structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths( raw_schema ) else: structured_output_schema = raw_schema else: structured_output_schema = None if self.node_data.computer_use: sandbox = self.graph_runtime_state.sandbox if not sandbox: raise LLMNodeError("computer use is enabled but no sandbox found") tool_dependencies: ToolDependencies | None = self._extract_tool_dependencies() generator = self._invoke_llm_with_sandbox( sandbox=sandbox, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, variable_pool=variable_pool, tool_dependencies=tool_dependencies, ) elif self.tool_call_enabled: generator = self._invoke_llm_with_tools( model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, files=files, variable_pool=variable_pool, node_inputs=node_inputs, process_data=process_data, ) else: # Use traditional LLM invocation generator = LLMNode.invoke_llm( node_data_model=self._node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, user_id=self.user_id, structured_output_schema=structured_output_schema, allow_file_path=bool(structured_output_file_paths), file_saver=self._llm_file_saver, file_outputs=self._file_outputs, node_id=self._node_id, node_type=self.node_type, reasoning_format=self._node_data.reasoning_format, ) ( clean_text, reasoning_content, generation_reasoning_content, generation_clean_content, usage, finish_reason, structured_output, generation_data, ) = yield from self._stream_llm_events(generator, model_instance=model_instance) if structured_output and structured_output_file_paths: sandbox = self.graph_runtime_state.sandbox if not sandbox: raise LLMNodeError("Structured output file paths are only supported in sandbox mode.") structured_output_value = structured_output.structured_output if structured_output_value is None: raise LLMNodeError("Structured output is empty") resolved_count = 0 def resolve_file(path: str) -> File: nonlocal resolved_count if resolved_count >= MAX_OUTPUT_FILES: raise LLMNodeError("Structured output files exceed the sandbox output limit") resolved_count += 1 return self._resolve_sandbox_file_path(sandbox=sandbox, path=path) converted_output, structured_output_files = convert_sandbox_file_paths_in_output( output=structured_output_value, file_path_fields=structured_output_file_paths, file_resolver=resolve_file, ) structured_output = LLMStructuredOutput(structured_output=converted_output) if structured_output_files: self._file_outputs.extend(structured_output_files) # Extract variables from generation_data if available if generation_data: clean_text = generation_data.text reasoning_content = "" usage = generation_data.usage finish_reason = generation_data.finish_reason # Unified process_data building process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( model_mode=model_config.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), "finish_reason": finish_reason, "model_provider": model_config.provider, "model_name": model_config.model, } if self.tool_call_enabled and self._node_data.tools: process_data["tools"] = [ { "type": tool.type.value if hasattr(tool.type, "value") else tool.type, "provider_name": tool.provider_name, "tool_name": tool.tool_name, } for tool in self._node_data.tools if tool.enabled ] is_sandbox = self.graph_runtime_state.sandbox is not None outputs = self._build_outputs( is_sandbox=is_sandbox, clean_text=clean_text, reasoning_content=reasoning_content, generation_reasoning_content=generation_reasoning_content, generation_clean_content=generation_clean_content, usage=usage, finish_reason=finish_reason, prompt_messages=prompt_messages, generation_data=generation_data, structured_output=structured_output, ) # Send final chunk event to indicate streaming is complete # For tool calls and sandbox, final events are already sent in _process_tool_outputs if not self.tool_call_enabled and not self.node_data.computer_use: yield StreamChunkEvent( selector=[self._node_id, "text"], chunk="", is_final=True, ) yield StreamChunkEvent( selector=[self._node_id, "generation", "content"], chunk="", is_final=True, ) yield ThoughtChunkEvent( selector=[self._node_id, "generation", "thought"], chunk="", is_final=True, ) metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, } if generation_data and generation_data.trace: metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [ segment.model_dump() for segment in generation_data.trace ] yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, process_data=process_data, outputs=outputs, metadata=metadata, llm_usage=usage, ) ) except ValueError as e: yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), inputs=node_inputs, process_data=process_data, error_type=type(e).__name__, llm_usage=usage, ) ) except Exception as e: logger.exception("error while executing llm node") yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), inputs=node_inputs, process_data=process_data, error_type=type(e).__name__, llm_usage=usage, ) ) def _build_outputs( self, *, is_sandbox: bool, clean_text: str, reasoning_content: str, generation_reasoning_content: str, generation_clean_content: str, usage: LLMUsage, finish_reason: str | None, prompt_messages: Sequence[PromptMessage], generation_data: LLMGenerationData | None, structured_output: LLMStructuredOutput | None, ) -> dict[str, Any]: """Build the outputs dictionary for the LLM node. Two runtime modes produce different output shapes: - **Classical** (is_sandbox=False): top-level ``text`` and ``reasoning_content`` are preserved for backward compatibility with existing users. - **Sandbox** (is_sandbox=True): ``text`` and ``reasoning_content`` are omitted from the top level because they duplicate fields inside ``generation``. The ``generation`` field always carries the full structured representation (content, reasoning, tool_calls, sequence) regardless of runtime mode. """ # Common outputs shared by both runtimes outputs: dict[str, Any] = { "usage": jsonable_encoder(usage), "finish_reason": finish_reason, "context": llm_utils.build_context(prompt_messages, clean_text, generation_data), } # Classical runtime keeps top-level text/reasoning_content for backward compatibility if not is_sandbox: outputs["text"] = clean_text outputs["reasoning_content"] = reasoning_content # Build generation field if generation_data: generation = { "content": generation_data.text, "reasoning_content": generation_data.reasoning_contents, "tool_calls": [self._serialize_tool_call(item) for item in generation_data.tool_calls], "sequence": generation_data.sequence, } files_to_output = list(generation_data.files) if self._file_outputs: existing_ids = {f.id for f in files_to_output} files_to_output.extend(f for f in self._file_outputs if f.id not in existing_ids) else: generation_reasoning = generation_reasoning_content or reasoning_content generation_content = generation_clean_content or clean_text sequence: list[dict[str, Any]] = [] if generation_reasoning: sequence = [ {"type": "reasoning", "index": 0}, {"type": "content", "start": 0, "end": len(generation_content)}, ] generation = { "content": generation_content, "reasoning_content": [generation_reasoning] if generation_reasoning else [], "tool_calls": [], "sequence": sequence, } files_to_output = self._file_outputs outputs["generation"] = generation if files_to_output: outputs["files"] = ArrayFileSegment(value=files_to_output) if structured_output: outputs["structured_output"] = structured_output.structured_output return outputs @staticmethod def invoke_llm( *, node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, user_id: str, structured_output_schema: Mapping[str, Any] | None, allow_file_path: bool = False, file_saver: LLMFileSaver, file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: model_schema = model_instance.model_type_instance.get_model_schema( node_data_model.name, model_instance.credentials ) if not model_schema: raise ValueError(f"Model schema not found for {node_data_model.name}") invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] if structured_output_schema: request_start_time = time.perf_counter() invoke_result = invoke_llm_with_structured_output( provider=model_instance.provider, model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, json_schema=structured_output_schema, model_parameters=node_data_model.completion_params, stop=list(stop or []), user=user_id, allow_file_path=allow_file_path, ) else: request_start_time = time.perf_counter() invoke_result = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, user=user_id, ) return LLMNode.handle_invoke_result( invoke_result=invoke_result, file_saver=file_saver, file_outputs=file_outputs, node_id=node_id, node_type=node_type, reasoning_format=reasoning_format, request_start_time=request_start_time, ) @staticmethod def handle_invoke_result( *, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], file_saver: LLMFileSaver, file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", request_start_time: float | None = None, ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: # For blocking mode if isinstance(invoke_result, LLMResult): duration = None if request_start_time is not None: duration = time.perf_counter() - request_start_time invoke_result.usage.latency = round(duration, 3) event = LLMNode.handle_blocking_result( invoke_result=invoke_result, saver=file_saver, file_outputs=file_outputs, reasoning_format=reasoning_format, request_latency=duration, ) yield event return # For streaming mode model = "" prompt_messages: list[PromptMessage] = [] usage = LLMUsage.empty_usage() finish_reason = None full_text_buffer = io.StringIO() think_parser = ThinkTagStreamParser() reasoning_chunks: list[str] = [] # Initialize streaming metrics tracking start_time = request_start_time if request_start_time is not None else time.perf_counter() first_token_time = None has_content = False collected_structured_output = None try: for result in invoke_result: if isinstance(result, LLMResultChunkWithStructuredOutput): if result.structured_output is not None: collected_structured_output = dict(result.structured_output) yield result if isinstance(result, LLMResultChunk): contents = result.delta.message.content for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( contents=contents, file_saver=file_saver, file_outputs=file_outputs, ): if text_part and not has_content: first_token_time = time.perf_counter() has_content = True full_text_buffer.write(text_part) yield StreamChunkEvent( selector=[node_id, "text"], chunk=text_part, is_final=False, ) for kind, segment in think_parser.process(text_part): if not segment: if kind not in {"thought_start", "thought_end"}: continue if kind == "thought_start": yield ThoughtStartChunkEvent( selector=[node_id, "generation", "thought"], chunk="", is_final=False, ) elif kind == "thought": reasoning_chunks.append(segment) yield ThoughtChunkEvent( selector=[node_id, "generation", "thought"], chunk=segment, is_final=False, ) elif kind == "thought_end": yield ThoughtEndChunkEvent( selector=[node_id, "generation", "thought"], chunk="", is_final=False, ) else: yield StreamChunkEvent( selector=[node_id, "generation", "content"], chunk=segment, is_final=False, ) if not model and result.model: model = result.model if len(prompt_messages) == 0: prompt_messages = list(result.prompt_messages) if usage.prompt_tokens == 0 and result.delta.usage: usage = result.delta.usage if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason except OutputParserError as e: raise LLMNodeError(f"Failed to parse structured output: {e}") for kind, segment in think_parser.flush(): if not segment and kind not in {"thought_start", "thought_end"}: continue if kind == "thought_start": yield ThoughtStartChunkEvent( selector=[node_id, "generation", "thought"], chunk="", is_final=False, ) elif kind == "thought": reasoning_chunks.append(segment) yield ThoughtChunkEvent( selector=[node_id, "generation", "thought"], chunk=segment, is_final=False, ) elif kind == "thought_end": yield ThoughtEndChunkEvent( selector=[node_id, "generation", "thought"], chunk="", is_final=False, ) else: yield StreamChunkEvent( selector=[node_id, "generation", "content"], chunk=segment, is_final=False, ) full_text = full_text_buffer.getvalue() if reasoning_format == "tagged": clean_text = full_text reasoning_content = "".join(reasoning_chunks) else: clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) if reasoning_chunks and not reasoning_content: reasoning_content = "".join(reasoning_chunks) # Calculate streaming metrics end_time = time.perf_counter() total_duration = end_time - start_time usage.latency = round(total_duration, 3) if has_content and first_token_time: gen_ai_server_time_to_first_token = first_token_time - start_time llm_streaming_time_to_generate = end_time - first_token_time usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3) usage.time_to_generate = round(llm_streaming_time_to_generate, 3) yield ModelInvokeCompletedEvent( text=clean_text if reasoning_format == "separated" else full_text, usage=usage, finish_reason=finish_reason, reasoning_content=reasoning_content, structured_output=collected_structured_output, ) @staticmethod def _image_file_to_markdown(file: File, /): text_chunk = f"![]({file.generate_url()})" return text_chunk @classmethod def _split_reasoning( cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged" ) -> tuple[str, str]: if reasoning_format == "tagged": return text, "" matches = cls._THINK_PATTERN.findall(text) reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" clean_text = cls._THINK_PATTERN.sub("", text) clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() return clean_text, reasoning_content or "" def _transform_chat_messages( self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: if isinstance(messages, LLMNodeCompletionModelPromptTemplate): if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = messages.jinja2_text return messages for message in messages: if message.edition_type == "jinja2" and message.jinja2_text: message.text = message.jinja2_text return messages def _parse_prompt_template( self, ) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]: prompt_template = self.node_data.prompt_template static_messages: list[LLMNodeChatModelMessage] = [] context_refs: list[PromptMessageContext] = [] template_order: list[tuple[int, str]] = [] if isinstance(prompt_template, list): for idx, item in enumerate(prompt_template): if isinstance(item, PromptMessageContext): context_refs.append(item) template_order.append((idx, "context")) else: static_messages.append(item) template_order.append((idx, "static")) if static_messages: self.node_data.prompt_template = self._transform_chat_messages(static_messages) return static_messages, context_refs, template_order def _build_prompt_messages_with_context( self, *, context_refs: list[PromptMessageContext], template_order: list[tuple[int, str]], static_messages: list[LLMNodeChatModelMessage], query: str | None, files: Sequence[File], context: str | None, memory: BaseMemory | None, model_config: ModelConfigWithCredentialsEntity, context_files: list[File], ) -> tuple[list[PromptMessage], Sequence[str] | None]: variable_pool = self.graph_runtime_state.variable_pool combined_messages: list[PromptMessage] = [] context_idx = 0 static_idx = 0 for _, type_ in template_order: if type_ == "context": ctx_ref = context_refs[context_idx] ctx_var = variable_pool.get(ctx_ref.value_selector) if ctx_var is None: raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found") if not isinstance(ctx_var, ArrayPromptMessageSegment): raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]") restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value) combined_messages.extend(restored_messages) context_idx += 1 else: static_msg = static_messages[static_idx] processed_msgs = LLMNode.handle_list_messages( messages=[static_msg], context=context, jinja2_variables=self.node_data.prompt_config.jinja2_variables or [], variable_pool=variable_pool, vision_detail_config=self.node_data.vision.configs.detail, sandbox=self.graph_runtime_state.sandbox, ) combined_messages.extend(processed_msgs) static_idx += 1 memory_messages = _handle_memory_chat_mode( memory=memory, memory_config=self.node_data.memory, model_config=model_config, ) combined_messages.extend(memory_messages) if query: query_message = LLMNodeChatModelMessage( text=query, role=PromptMessageRole.USER, edition_type="basic", ) query_msgs = LLMNode.handle_list_messages( messages=[query_message], context="", jinja2_variables=[], variable_pool=variable_pool, vision_detail_config=self.node_data.vision.configs.detail, ) combined_messages.extend(query_msgs) combined_messages = self._append_files_to_messages( messages=combined_messages, sys_files=files, context_files=context_files, model_config=model_config, ) combined_messages = self._filter_messages(combined_messages, model_config) stop = self._get_stop_sequences(model_config) return combined_messages, stop def _append_files_to_messages( self, *, messages: list[PromptMessage], sys_files: Sequence[File], context_files: list[File], model_config: ModelConfigWithCredentialsEntity, ) -> list[PromptMessage]: vision_enabled = self.node_data.vision.enabled vision_detail = self.node_data.vision.configs.detail if vision_enabled and sys_files: file_prompts = [ file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files ] if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list): messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content) else: messages.append(UserPromptMessage(content=file_prompts)) if vision_enabled and context_files: file_prompts = [ file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in context_files ] if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list): messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content) else: messages.append(UserPromptMessage(content=file_prompts)) return messages def _filter_messages( self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity ) -> list[PromptMessage]: filtered_messages: list[PromptMessage] = [] for message in messages: if isinstance(message.content, list): filtered_content: list[PromptMessageContentUnionTypes] = [] for content_item in message.content: if not model_config.model_schema.features: if content_item.type != PromptMessageContentType.TEXT: continue filtered_content.append(content_item) continue feature_map = { PromptMessageContentType.IMAGE: ModelFeature.VISION, PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT, PromptMessageContentType.VIDEO: ModelFeature.VIDEO, PromptMessageContentType.AUDIO: ModelFeature.AUDIO, } required_feature = feature_map.get(content_item.type) if required_feature and required_feature not in model_config.model_schema.features: continue filtered_content.append(content_item) if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT: message.content = filtered_content[0].data else: message.content = filtered_content if not message.is_empty(): filtered_messages.append(message) if not filtered_messages: raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) return filtered_messages def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None: return model_config.stop def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: variables: dict[str, Any] = {} if not node_data.prompt_config: return variables for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") def parse_dict(input_dict: Mapping[str, Any]) -> str: if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: return str(input_dict["content"]) try: return json.dumps(input_dict, ensure_ascii=False) except Exception: return str(input_dict) if isinstance(variable, ArraySegment): result = "" for item in variable.value: if isinstance(item, dict): result += parse_dict(item) else: result += str(item) result += "\n" value = result.strip() elif isinstance(variable, ObjectSegment): value = parse_dict(variable.value) else: value = variable.text variables[variable_name] = value return variables def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: inputs = {} prompt_template = node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) elif isinstance(prompt_template, CompletionModelPromptTemplate): variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() for variable_selector in variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): inputs[variable_selector.variable] = "" inputs[variable_selector.variable] = variable.to_object() memory = node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): continue inputs[variable_selector.variable] = variable.to_object() return inputs def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: return if not node_data.context.variable_selector: return context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) if context_value_variable: if isinstance(context_value_variable, StringSegment): yield RunRetrieverResourceEvent( retriever_resources=[], context=context_value_variable.value, context_files=[] ) elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource: list[RetrievalSourceMetadata] = [] context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" else: if "content" not in item: raise InvalidContextStructureError(f"Invalid context structure: {item}") if item.get("summary"): context_str += item["summary"] + "\n" context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) .where( SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, ) ).all() if attachments_with_bindings: for _, upload_file in attachments_with_bindings: attachment_info = File( id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, related_id=upload_file.id, size=upload_file.size, storage_key=upload_file.key, url=sign_upload_file(upload_file.id, upload_file.extension), ) context_files.append(attachment_info) yield RunRetrieverResourceEvent( retriever_resources=[r.model_dump() for r in original_retriever_resource], context=context_str.strip(), context_files=context_files, ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: from core.rag.entities.citation_metadata import RetrievalSourceMetadata if ( "metadata" in context_dict and "_source" in context_dict["metadata"] and context_dict["metadata"]["_source"] == "knowledge" ): metadata = context_dict.get("metadata", {}) source = RetrievalSourceMetadata( position=metadata.get("position"), dataset_id=metadata.get("dataset_id"), dataset_name=metadata.get("dataset_name"), document_id=metadata.get("document_id"), document_name=metadata.get("document_name"), data_source_type=metadata.get("data_source_type"), segment_id=metadata.get("segment_id"), retriever_from=metadata.get("retriever_from"), score=metadata.get("score"), hit_count=metadata.get("segment_hit_count"), word_count=metadata.get("segment_word_count"), segment_position=metadata.get("segment_position"), index_node_hash=metadata.get("segment_index_node_hash"), content=context_dict.get("content"), page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), files=context_dict.get("files"), summary=context_dict.get("summary"), ) return source return None @staticmethod def _fetch_model_config( *, node_data_model: ModelConfig, tenant_id: str, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: model, model_config_with_cred = llm_utils.fetch_model_config( tenant_id=tenant_id, node_data_model=node_data_model ) completion_params = model_config_with_cred.parameters model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) if not model_schema: raise ModelNotExistError(f"Model {node_data_model.name} not exist.") model_config_with_cred.parameters = completion_params node_data_model.completion_params = completion_params return model, model_config_with_cred @staticmethod def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence[File], context: str | None = None, memory: BaseMemory | None = None, model_config: ModelConfigWithCredentialsEntity, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, context_files: list[File] | None = None, sandbox: Sandbox | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): prompt_messages.extend( LLMNode.handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, variable_pool=variable_pool, vision_detail_config=vision_detail, sandbox=sandbox, ) ) memory_messages = _handle_memory_chat_mode( memory=memory, memory_config=memory_config, model_config=model_config, ) prompt_messages.extend(memory_messages) if sys_query: message = LLMNodeChatModelMessage( text=sys_query, role=PromptMessageRole.USER, edition_type="basic", ) prompt_messages.extend( LLMNode.handle_list_messages( messages=[message], context="", jinja2_variables=[], variable_pool=variable_pool, vision_detail_config=vision_detail, ) ) elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): prompt_messages.extend( _handle_completion_template( template=prompt_template, context=context, jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) ) memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_config=model_config, ) prompt_content = prompt_messages[0].content prompt_content_type = type(prompt_content) if prompt_content_type == str: prompt_content = str(prompt_content) if "#histories#" in prompt_content: prompt_content = prompt_content.replace("#histories#", memory_text) else: prompt_content = memory_text + "\n" + prompt_content prompt_messages[0].content = prompt_content elif prompt_content_type == list: prompt_content = prompt_content if isinstance(prompt_content, list) else [] for content_item in prompt_content: if content_item.type == PromptMessageContentType.TEXT: if "#histories#" in content_item.data: content_item.data = content_item.data.replace("#histories#", memory_text) else: content_item.data = memory_text + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") if sys_query: if prompt_content_type == str: prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) prompt_messages[0].content = prompt_content elif prompt_content_type == list: prompt_content = prompt_content if isinstance(prompt_content, list) else [] for content_item in prompt_content: if content_item.type == PromptMessageContentType.TEXT: content_item.data = sys_query + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") else: raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) if vision_enabled and sys_files: file_prompts = [] for file in sys_files: file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) file_prompts.append(file_prompt) if ( len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage) and isinstance(prompt_messages[-1].content, list) ): prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) if vision_enabled and context_files: file_prompts = [] for file in context_files: file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) file_prompts.append(file_prompt) if ( len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage) and isinstance(prompt_messages[-1].content, list) ): prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) filtered_prompt_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message.content, list): prompt_message_content: list[PromptMessageContentUnionTypes] = [] for content_item in prompt_message.content: if not model_config.model_schema.features: if content_item.type != PromptMessageContentType.TEXT: continue prompt_message_content.append(content_item) continue if ( ( content_item.type == PromptMessageContentType.IMAGE and ModelFeature.VISION not in model_config.model_schema.features ) or ( content_item.type == PromptMessageContentType.DOCUMENT and ModelFeature.DOCUMENT not in model_config.model_schema.features ) or ( content_item.type == PromptMessageContentType.VIDEO and ModelFeature.VIDEO not in model_config.model_schema.features ) or ( content_item.type == PromptMessageContentType.AUDIO and ModelFeature.AUDIO not in model_config.model_schema.features ) ): continue prompt_message_content.append(content_item) if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: prompt_message.content = prompt_message_content[0].data else: prompt_message.content = prompt_message_content if prompt_message.is_empty(): continue filtered_prompt_messages.append(prompt_message) if len(filtered_prompt_messages) == 0: raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) model = ModelManager().get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model, ) model_schema = model.model_type_instance.get_model_schema( model=model_config.model, credentials=model.credentials, ) if not model_schema: raise ModelNotExistError(f"Model {model_config.model} not exist.") return filtered_prompt_messages, model_config.stop @classmethod def _extract_variable_selector_to_variable_mapping( cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData, ) -> Mapping[str, Sequence[str]]: _ = graph_config typed_node_data = node_data prompt_template = typed_node_data.prompt_template variable_selectors = [] prompt_context_selectors: list[Sequence[str]] = [] if isinstance(prompt_template, list): for item in prompt_template: if isinstance(item, PromptMessageContext): if len(item.value_selector) >= 2: prompt_context_selectors.append(item.value_selector) elif isinstance(item, LLMNodeChatModelMessage): variable_template_parser = VariableTemplateParser(template=item.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() else: raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping: dict[str, Any] = {} for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector for context_selector in prompt_context_selectors: variable_key = f"#{'.'.join(context_selector)}#" variable_mapping[variable_key] = list(context_selector) memory = typed_node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector if typed_node_data.context.enabled: variable_mapping["#context#"] = typed_node_data.context.variable_selector if typed_node_data.vision.enabled: variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector if typed_node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] if typed_node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, list): for item in prompt_template: if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2": enable_jinja = True break else: enable_jinja = True if enable_jinja: for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} return variable_mapping @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "type": "llm", "config": { "prompt_templates": { "chat_model": { "prompts": [ {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} ] }, "completion_model": { "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "prompt": { "text": "Here are the chat histories between human and assistant, inside " " XML tags.\n\n\n{{" "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", "edition_type": "basic", }, "stop": ["Human:"], }, } }, } @staticmethod def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, sandbox: Sandbox | None = None, ) -> Sequence[PromptMessage]: from core.sandbox.entities.config import AppAssets from core.skill.assembler import SkillDocumentAssembler from core.skill.constants import SkillAttrs from core.skill.entities.skill_document import SkillDocument from core.skill.entities.skill_metadata import SkillMetadata prompt_messages: list[PromptMessage] = [] bundle: SkillBundle | None = None if sandbox: bundle = sandbox.attrs.get(SkillAttrs.BUNDLE) for message in messages: if message.edition_type == "jinja2": result_text = _render_jinja2_message( template=message.jinja2_text or "", jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) if bundle is not None: skill_entry = SkillDocumentAssembler(bundle).assemble_document( document=SkillDocument( skill_id="anonymous", content=result_text, metadata=SkillMetadata.model_validate(message.metadata or {}), ), base_path=AppAssets.PATH, ) result_text = skill_entry.content prompt_message = _combine_message_content_with_role( contents=[TextPromptMessageContent(data=result_text)], role=message.role ) prompt_messages.append(prompt_message) else: if context: template = message.text.replace("{#context#}", context) else: template = message.text segment_group = variable_pool.convert_template(template) file_contents = [] for segment in segment_group.value: if isinstance(segment, ArrayFileSegment): for file in segment.value: if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: file_content = file_manager.to_prompt_message_content( file, image_detail_config=vision_detail_config ) file_contents.append(file_content) elif isinstance(segment, FileSegment): file = segment.value if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: file_content = file_manager.to_prompt_message_content( file, image_detail_config=vision_detail_config ) file_contents.append(file_content) plain_text = segment_group.text if plain_text and bundle is not None: skill_entry = SkillDocumentAssembler(bundle).assemble_document( document=SkillDocument( skill_id="anonymous", content=plain_text, metadata=SkillMetadata.model_validate(message.metadata or {}), ), base_path=AppAssets.PATH, ) plain_text = skill_entry.content if plain_text: prompt_message = _combine_message_content_with_role( contents=[TextPromptMessageContent(data=plain_text)], role=message.role ) prompt_messages.append(prompt_message) if file_contents: prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) prompt_messages.append(prompt_message) return prompt_messages @staticmethod def handle_blocking_result( *, invoke_result: LLMResult | LLMResultWithStructuredOutput, saver: LLMFileSaver, file_outputs: list[File], reasoning_format: Literal["separated", "tagged"] = "tagged", request_latency: float | None = None, ) -> ModelInvokeCompletedEvent: buffer = io.StringIO() for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( contents=invoke_result.message.content, file_saver=saver, file_outputs=file_outputs, ): buffer.write(text_part) full_text = buffer.getvalue() if reasoning_format == "tagged": clean_text = full_text reasoning_content = "" else: clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) event = ModelInvokeCompletedEvent( text=clean_text if reasoning_format == "separated" else full_text, usage=invoke_result.usage, finish_reason=None, reasoning_content=reasoning_content, structured_output=getattr(invoke_result, "structured_output", None), ) if request_latency is not None: event.usage.latency = round(request_latency, 3) return event @staticmethod def save_multimodal_image_output( *, content: ImagePromptMessageContent, file_saver: LLMFileSaver, ) -> File: if content.url != "": saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) else: saved_file = file_saver.save_binary_string( data=base64.b64decode(content.base64_data), mime_type=content.mime_type, file_type=FileType.IMAGE, ) return saved_file @staticmethod def _normalize_sandbox_file_path(path: str) -> str: raw = path.strip() if not raw: raise LLMNodeError("Sandbox file path must not be empty") sandbox_path = PurePosixPath(raw) if any(part == ".." for part in sandbox_path.parts): raise LLMNodeError("Sandbox file path must not contain '..'") normalized = str(sandbox_path) if normalized in {".", ""}: raise LLMNodeError("Sandbox file path is invalid") return normalized def _resolve_sandbox_file_path(self, *, sandbox: Sandbox, path: str) -> File: from core.sandbox.bash.session import MAX_OUTPUT_FILE_SIZE from core.tools.tool_file_manager import ToolFileManager normalized_path = self._normalize_sandbox_file_path(path) filename = os.path.basename(normalized_path) if not filename: raise LLMNodeError("Sandbox file path must point to a file") try: file_content = sandbox.vm.download_file(normalized_path) except Exception as exc: raise LLMNodeError(f"Sandbox file not found: {normalized_path}") from exc file_binary = file_content.getvalue() if len(file_binary) > MAX_OUTPUT_FILE_SIZE: raise LLMNodeError(f"Sandbox file exceeds size limit: {normalized_path}") mime_type, _ = mimetypes.guess_type(filename) if not mime_type: mime_type = "application/octet-stream" tool_file_manager = ToolFileManager() tool_file = tool_file_manager.create_file_by_raw( user_id=self.user_id, tenant_id=self.tenant_id, conversation_id=None, file_binary=file_binary, mimetype=mime_type, filename=filename, ) extension = os.path.splitext(filename)[1] if "." in filename else ".bin" url = sign_tool_file(tool_file.id, extension) file_type = self._get_file_type_from_mime(mime_type) return File( id=tool_file.id, tenant_id=self.tenant_id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE, filename=filename, extension=extension, mime_type=mime_type, size=len(file_binary), related_id=tool_file.id, url=url, storage_key=tool_file.file_key, ) @staticmethod def _get_file_type_from_mime(mime_type: str) -> FileType: if mime_type.startswith("image/"): return FileType.IMAGE if mime_type.startswith("video/"): return FileType.VIDEO if mime_type.startswith("audio/"): return FileType.AUDIO if "text" in mime_type or "pdf" in mime_type: return FileType.DOCUMENT return FileType.CUSTOM @staticmethod def fetch_structured_output_schema( *, structured_output: Mapping[str, Any], ) -> dict[str, Any]: if not structured_output: raise LLMNodeError("Please provide a valid structured output schema") structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) if not structured_output_schema: raise LLMNodeError("Please provide a valid structured output schema") try: schema = json.loads(structured_output_schema) if not isinstance(schema, dict): raise LLMNodeError("structured_output_schema must be a JSON object") return schema except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") @staticmethod def _save_multimodal_output_and_convert_result_to_markdown( *, contents: str | list[PromptMessageContentUnionTypes] | None, file_saver: LLMFileSaver, file_outputs: list[File], ) -> Generator[str, None, None]: if contents is None: yield from [] return if isinstance(contents, str): yield contents else: for item in contents: if isinstance(item, TextPromptMessageContent): yield item.data elif isinstance(item, ImagePromptMessageContent): file = LLMNode.save_multimodal_image_output( content=item, file_saver=file_saver, ) file_outputs.append(file) yield LLMNode._image_file_to_markdown(file) else: logger.warning("unknown item type encountered, type=%s", type(item)) yield str(item) @property def retry(self) -> bool: return self.node_data.retry_config.retry_enabled @property def tool_call_enabled(self) -> bool: return ( self.node_data.tools is not None and len(self.node_data.tools) > 0 and all(tool.enabled for tool in self.node_data.tools) ) def _stream_llm_events( self, generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None], *, model_instance: ModelInstance, ) -> Generator[ NodeEventBase, None, tuple[ str, str, str, str, LLMUsage, str | None, LLMStructuredOutput | None, LLMGenerationData | None, ], ]: clean_text = "" reasoning_content = "" generation_reasoning_content = "" generation_clean_content = "" usage = LLMUsage.empty_usage() finish_reason: str | None = None structured_output: LLMStructuredOutput | None = None generation_data: LLMGenerationData | None = None completed = False while True: try: event = next(generator) except StopIteration as exc: if isinstance(exc.value, LLMGenerationData): generation_data = exc.value break if completed: continue match event: case StreamChunkEvent() | ThoughtChunkEvent(): yield event case ModelInvokeCompletedEvent( text=text, usage=usage_event, finish_reason=finish_reason_event, reasoning_content=reasoning_event, structured_output=structured_raw, ): clean_text = text usage = usage_event finish_reason = finish_reason_event reasoning_content = reasoning_event or "" generation_reasoning_content = reasoning_content generation_clean_content = clean_text if self.node_data.reasoning_format == "tagged": generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning( clean_text, reasoning_format="separated" ) else: clean_text, generation_reasoning_content = LLMNode._split_reasoning( clean_text, self.node_data.reasoning_format ) generation_clean_content = clean_text structured_output = ( LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None ) llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) completed = True case LLMStructuredOutput(): structured_output = event case _: continue return ( clean_text, reasoning_content, generation_reasoning_content, generation_clean_content, usage, finish_reason, structured_output, generation_data, ) def _extract_disabled_tools(self) -> dict[str, ToolDependency]: from core.skill.entities.tool_dependencies import ToolDependency tools = [ ToolDependency(type=tool.type, provider=tool.provider, tool_name=tool.tool_name) for tool in self.node_data.tool_settings if not tool.enabled ] return {tool.tool_id(): tool for tool in tools} def _extract_tool_dependencies(self) -> ToolDependencies | None: from core.sandbox.entities.config import AppAssets from core.skill.assembler import SkillDocumentAssembler from core.skill.constants import SkillAttrs from core.skill.entities.skill_document import SkillDocument from core.skill.entities.skill_metadata import SkillMetadata sandbox = self.graph_runtime_state.sandbox if not sandbox: raise LLMNodeError("Sandbox not found") bundle = sandbox.attrs.get(SkillAttrs.BUNDLE) tool_deps_list: list[ToolDependencies] = [] for prompt in self.node_data.prompt_template: if isinstance(prompt, LLMNodeChatModelMessage): skill_entry = SkillDocumentAssembler(bundle).assemble_document( document=SkillDocument( skill_id="anonymous", content=prompt.text, metadata=SkillMetadata.model_validate(prompt.metadata or {}), ), base_path=AppAssets.PATH, ) tool_deps_list.append(skill_entry.tools) if len(tool_deps_list) == 0: return None disabled_tools = self._extract_disabled_tools() tool_dependencies = reduce(lambda x, y: x.merge(y), tool_deps_list) for tool in tool_dependencies.dependencies: if tool.tool_id() in disabled_tools: tool.enabled = False return tool_dependencies def _invoke_llm_with_tools( self, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None, files: Sequence[File], variable_pool: VariablePool, node_inputs: dict[str, Any], process_data: dict[str, Any], ) -> Generator[NodeEventBase, None, LLMGenerationData]: from core.agent.entities import ExecutionContext from core.agent.patterns import StrategyFactory model_features = self._get_model_features(model_instance) tool_instances = self._prepare_tool_instances(variable_pool) prompt_files = self._extract_prompt_files(variable_pool) strategy = StrategyFactory.create_strategy( model_features=model_features, model_instance=model_instance, tools=tool_instances, files=prompt_files, max_iterations=self._node_data.max_iterations or 10, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), ) outputs = strategy.run( prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, stop=list(stop or []), stream=True, ) result = yield from self._process_tool_outputs(outputs) return result def _invoke_llm_with_sandbox( self, sandbox: Sandbox, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None, variable_pool: VariablePool, tool_dependencies: ToolDependencies | None, ) -> Generator[NodeEventBase, None, LLMGenerationData]: from core.agent.entities import AgentEntity, ExecutionContext from core.agent.patterns import StrategyFactory from core.sandbox.bash.session import SandboxBashSession result: LLMGenerationData | None = None with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session: prompt_files = self._extract_prompt_files(variable_pool) model_features = self._get_model_features(model_instance) strategy = StrategyFactory.create_strategy( model_features=model_features, model_instance=model_instance, tools=[session.bash_tool], files=prompt_files, max_iterations=self._node_data.max_iterations or 100, agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), ) outputs = strategy.run( prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, stop=list(stop or []), stream=True, ) result = yield from self._process_tool_outputs(outputs) collected_files = session.collect_output_files() if collected_files: existing_ids = {f.id for f in self._file_outputs} self._file_outputs.extend(f for f in collected_files if f.id not in existing_ids) if result is None: raise LLMNodeError("SandboxSession exited unexpectedly") return result def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]: try: model_type_instance = model_instance.model_type_instance model_schema = model_type_instance.get_model_schema( model_instance.model_name, model_instance.credentials, ) return model_schema.features if model_schema and model_schema.features else [] except Exception: logger.warning("Failed to get model schema, assuming no special features") return [] def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]: from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager tool_instances = [] if self._node_data.tools: for tool in self._node_data.tools: try: processed_settings = {} for key, value in tool.settings.items(): if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict): if "type" in value["value"] and "value" in value["value"]: processed_settings[key] = value["value"] else: processed_settings[key] = value else: processed_settings[key] = value merged_parameters = {**tool.parameters, **processed_settings} agent_tool = AgentToolEntity( provider_id=tool.provider_name, provider_type=tool.type, tool_name=tool.tool_name, tool_parameters=merged_parameters, plugin_unique_identifier=tool.plugin_unique_identifier, credential_id=tool.credential_id, ) tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=self.tenant_id, app_id=self.app_id, agent_tool=agent_tool, invoke_from=self.invoke_from, variable_pool=variable_pool, ) if tool.extra.get("description") and tool_runtime.entity.description: tool_runtime.entity.description.llm = ( tool.extra.get("description") or tool_runtime.entity.description.llm ) tool_instances.append(tool_runtime) except Exception as e: logger.warning("Failed to load tool %s: %s", tool, str(e)) continue return tool_instances def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]: from dify_graph.variables.variables import ArrayFileVariable, FileVariable files: list[File] = [] if isinstance(self._node_data.prompt_template, list): for message in self._node_data.prompt_template: if message.text: parser = VariableTemplateParser(message.text) variable_selectors = parser.extract_variable_selectors() for variable_selector in variable_selectors: variable = variable_pool.get(variable_selector.value_selector) if isinstance(variable, FileVariable) and variable.value: files.append(variable.value) elif isinstance(variable, ArrayFileVariable) and variable.value: files.extend(variable.value) return files @staticmethod def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]: def _file_to_ref(file: File) -> str | None: return file.id or file.related_id files = [] for file in tool_call.files or []: ref = _file_to_ref(file) if ref: files.append(ref) return { "id": tool_call.id, "name": tool_call.name, "arguments": tool_call.arguments, "output": tool_call.output, "files": files, "status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status, "elapsed_time": tool_call.elapsed_time, } def _generate_model_provider_icon_url(self, provider: str, dark: bool = False) -> str | None: from yarl import URL from configs import dify_config icon_type = "icon_small_dark" if dark else "icon_small" try: return str( URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / self.tenant_id / "model-providers" / provider / icon_type / "en_US" ) except Exception: return None def _emit_model_start(self, trace_state: TraceState) -> Generator[NodeEventBase, None, None]: if trace_state.model_start_emitted: return trace_state.model_start_emitted = True if trace_state.model_segment_start_time is None: trace_state.model_segment_start_time = time.perf_counter() provider = self._node_data.model.provider yield StreamChunkEvent( selector=[self._node_id, "generation", "model_start"], chunk="", chunk_type=ChunkType.MODEL_START, is_final=False, model_provider=provider, model_name=self._node_data.model.name, model_icon=self._generate_model_provider_icon_url(provider), model_icon_dark=self._generate_model_provider_icon_url(provider, dark=True), ) def _flush_model_segment( self, buffers: StreamBuffers, trace_state: TraceState, error: str | None = None, ) -> Generator[NodeEventBase, None, None]: if not buffers.pending_thought and not buffers.pending_content and not buffers.pending_tool_calls: return now = time.perf_counter() duration = now - trace_state.model_segment_start_time if trace_state.model_segment_start_time else 0.0 usage = trace_state.pending_usage provider = self._node_data.model.provider model_name = self._node_data.model.name model_icon = self._generate_model_provider_icon_url(provider) model_icon_dark = self._generate_model_provider_icon_url(provider, dark=True) trace_state.trace_segments.append( LLMTraceSegment( type="model", duration=duration, usage=usage, output=ModelTraceSegment( text="".join(buffers.pending_content) if buffers.pending_content else None, reasoning="".join(buffers.pending_thought) if buffers.pending_thought else None, tool_calls=list(buffers.pending_tool_calls), ), provider=provider, name=model_name, icon=model_icon, icon_dark=model_icon_dark, error=error, status="error" if error else "success", ) ) yield StreamChunkEvent( selector=[self._node_id, "generation", "model_end"], chunk="", chunk_type=ChunkType.MODEL_END, is_final=False, model_usage=usage, model_duration=duration, ) buffers.pending_thought.clear() buffers.pending_content.clear() buffers.pending_tool_calls.clear() trace_state.model_segment_start_time = None trace_state.model_start_emitted = False trace_state.pending_usage = None def _handle_agent_log_output( self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext ) -> Generator[NodeEventBase, None, None]: from core.agent.entities import AgentLog payload = ToolLogPayload.from_log(output) agent_log_event = AgentLogEvent( message_id=output.id, label=output.label, node_execution_id=self.id, parent_id=output.parent_id, error=output.error, status=output.status.value, data=output.data, metadata={k.value: v for k, v in output.metadata.items()}, node_id=self._node_id, ) for log in agent_context.agent_logs: if log.message_id == agent_log_event.message_id: log.data = agent_log_event.data log.status = agent_log_event.status log.error = agent_log_event.error log.label = agent_log_event.label log.metadata = agent_log_event.metadata break else: agent_context.agent_logs.append(agent_log_event) if output.log_type == AgentLog.LogType.THOUGHT and output.status == AgentLog.LogStatus.SUCCESS: llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) if output.metadata else None if llm_usage: trace_state.pending_usage = llm_usage if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START: yield from self._emit_model_start(trace_state) tool_name = payload.tool_name tool_call_id = payload.tool_call_id tool_arguments = json.dumps(payload.tool_args or {}) tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments)) yield ToolCallChunkEvent( selector=[self._node_id, "generation", "tool_calls"], chunk=tool_arguments, tool_call=ToolCall( id=tool_call_id, name=tool_name, arguments=tool_arguments, icon=tool_icon, icon_dark=tool_icon_dark, ), is_final=False, ) if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START: tool_name = payload.tool_name tool_output = payload.tool_output tool_call_id = payload.tool_call_id tool_files = payload.files if isinstance(payload.files, list) else [] tool_error = payload.tool_error tool_arguments = json.dumps(payload.tool_args or {}) if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) yield from self._flush_model_segment(buffers, trace_state) if output.status == AgentLog.LogStatus.ERROR: tool_error = output.error or payload.tool_error if not tool_error and payload.meta: tool_error = payload.meta.get("error") else: if payload.meta: meta_error = payload.meta.get("error") if meta_error: tool_error = meta_error elapsed_time = output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None tool_provider = output.metadata.get(AgentLog.LogMetadata.PROVIDER) if output.metadata else None tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None result_str = str(tool_output) if tool_output is not None else None tool_status: Literal["success", "error"] = "error" if tool_error else "success" tool_call_segment = LLMTraceSegment( type="tool", duration=elapsed_time or 0.0, usage=None, output=ToolTraceSegment( id=tool_call_id, name=tool_name, arguments=tool_arguments, output=result_str, ), provider=tool_provider, name=tool_name, icon=tool_icon, icon_dark=tool_icon_dark, error=str(tool_error) if tool_error else None, status=tool_status, ) trace_state.trace_segments.append(tool_call_segment) if tool_call_id: trace_state.tool_trace_map[tool_call_id] = tool_call_segment trace_state.model_segment_start_time = time.perf_counter() yield ToolResultChunkEvent( selector=[self._node_id, "generation", "tool_results"], chunk=result_str or "", tool_result=ToolResult( id=tool_call_id, name=tool_name, output=result_str, files=tool_files, status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS, elapsed_time=elapsed_time, icon=tool_icon, icon_dark=tool_icon_dark, provider=tool_provider, ), is_final=False, ) if buffers.current_turn_reasoning: buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) buffers.current_turn_reasoning.clear() def _handle_llm_chunk_output( self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult ) -> Generator[NodeEventBase, None, None]: message = output.delta.message if message and message.content: chunk_text = message.content if isinstance(chunk_text, list): chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text) else: chunk_text = str(chunk_text) for kind, segment in buffers.think_parser.process(chunk_text): if not segment and kind not in {"thought_start", "thought_end"}: continue yield from self._emit_model_start(trace_state) if kind == "thought_start": yield ThoughtStartChunkEvent( selector=[self._node_id, "generation", "thought"], chunk="", is_final=False, ) elif kind == "thought": buffers.current_turn_reasoning.append(segment) buffers.pending_thought.append(segment) yield ThoughtChunkEvent( selector=[self._node_id, "generation", "thought"], chunk=segment, is_final=False, ) elif kind == "thought_end": yield ThoughtEndChunkEvent( selector=[self._node_id, "generation", "thought"], chunk="", is_final=False, ) else: aggregate.text += segment buffers.pending_content.append(segment) yield StreamChunkEvent( selector=[self._node_id, "text"], chunk=segment, is_final=False, ) yield StreamChunkEvent( selector=[self._node_id, "generation", "content"], chunk=segment, is_final=False, ) if output.delta.usage: self._accumulate_usage(aggregate.usage, output.delta.usage) if output.delta.finish_reason: aggregate.finish_reason = output.delta.finish_reason def _flush_remaining_stream( self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult ) -> Generator[NodeEventBase, None, None]: for kind, segment in buffers.think_parser.flush(): if not segment and kind not in {"thought_start", "thought_end"}: continue yield from self._emit_model_start(trace_state) if kind == "thought_start": yield ThoughtStartChunkEvent( selector=[self._node_id, "generation", "thought"], chunk="", is_final=False, ) elif kind == "thought": buffers.current_turn_reasoning.append(segment) buffers.pending_thought.append(segment) yield ThoughtChunkEvent( selector=[self._node_id, "generation", "thought"], chunk=segment, is_final=False, ) elif kind == "thought_end": yield ThoughtEndChunkEvent( selector=[self._node_id, "generation", "thought"], chunk="", is_final=False, ) else: aggregate.text += segment buffers.pending_content.append(segment) yield StreamChunkEvent( selector=[self._node_id, "text"], chunk=segment, is_final=False, ) yield StreamChunkEvent( selector=[self._node_id, "generation", "content"], chunk=segment, is_final=False, ) if buffers.current_turn_reasoning: buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) if trace_state.pending_usage is None: trace_state.pending_usage = aggregate.usage yield from self._flush_model_segment(buffers, trace_state) def _close_streams(self) -> Generator[NodeEventBase, None, None]: yield StreamChunkEvent( selector=[self._node_id, "text"], chunk="", is_final=True, ) yield StreamChunkEvent( selector=[self._node_id, "generation", "content"], chunk="", is_final=True, ) yield ThoughtChunkEvent( selector=[self._node_id, "generation", "thought"], chunk="", is_final=True, ) yield ToolCallChunkEvent( selector=[self._node_id, "generation", "tool_calls"], chunk="", tool_call=ToolCall( id="", name="", arguments="", ), is_final=True, ) yield ToolResultChunkEvent( selector=[self._node_id, "generation", "tool_results"], chunk="", tool_result=ToolResult( id="", name="", output="", files=[], status=ToolResultStatus.SUCCESS, ), is_final=True, ) yield StreamChunkEvent( selector=[self._node_id, "generation", "model_start"], chunk="", is_final=True, ) yield StreamChunkEvent( selector=[self._node_id, "generation", "model_end"], chunk="", is_final=True, ) def _build_generation_data( self, trace_state: TraceState, agent_context: AgentContext, aggregate: AggregatedResult, buffers: StreamBuffers, ) -> LLMGenerationData: from core.agent.entities import AgentLog sequence: list[dict[str, Any]] = [] reasoning_index = 0 content_position = 0 tool_call_seen_index: dict[str, int] = {} for trace_segment in trace_state.trace_segments: if trace_segment.type == "thought": sequence.append({"type": "reasoning", "index": reasoning_index}) reasoning_index += 1 elif trace_segment.type == "content": segment_text = trace_segment.text or "" start = content_position end = start + len(segment_text) sequence.append({"type": "content", "start": start, "end": end}) content_position = end elif trace_segment.type == "tool_call": tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else "" if tool_id not in tool_call_seen_index: tool_call_seen_index[tool_id] = len(tool_call_seen_index) sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]}) tool_calls_for_generation: list[ToolCallResult] = [] for log in agent_context.agent_logs: payload = ToolLogPayload.from_mapping(log.data or {}) tool_call_id = payload.tool_call_id if not tool_call_id or log.status == AgentLog.LogStatus.START.value: continue tool_args = payload.tool_args log_error = payload.tool_error log_output = payload.tool_output result_text = log_output or log_error or "" status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS tool_calls_for_generation.append( ToolCallResult( id=tool_call_id, name=payload.tool_name, arguments=json.dumps(tool_args) if tool_args else "", output=result_text, status=status, elapsed_time=log.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if log.metadata else None, ) ) tool_calls_for_generation.sort( key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map)) ) return LLMGenerationData( text=aggregate.text, reasoning_contents=buffers.reasoning_per_turn, tool_calls=tool_calls_for_generation, sequence=sequence, usage=aggregate.usage, finish_reason=aggregate.finish_reason, files=aggregate.files, trace=trace_state.trace_segments, ) def _process_tool_outputs( self, outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult], ) -> Generator[NodeEventBase, None, LLMGenerationData]: from core.agent.entities import AgentLog, AgentResult state = ToolOutputState() try: for output in outputs: if isinstance(output, AgentLog): yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent) else: yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate) except StopIteration as exception: if isinstance(getattr(exception, "value", None), AgentResult): state.agent.agent_result = exception.value if state.agent.agent_result: state.aggregate.text = state.agent.agent_result.text or state.aggregate.text state.aggregate.files = state.agent.agent_result.files if state.agent.agent_result.usage: state.aggregate.usage = state.agent.agent_result.usage if state.agent.agent_result.finish_reason: state.aggregate.finish_reason = state.agent.agent_result.finish_reason yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate) yield from self._close_streams() return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream) def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None: total_usage.prompt_tokens += delta_usage.prompt_tokens total_usage.completion_tokens += delta_usage.completion_tokens total_usage.total_tokens += delta_usage.total_tokens total_usage.prompt_price += delta_usage.prompt_price total_usage.completion_price += delta_usage.completion_price total_usage.total_price += delta_usage.total_price def _combine_message_content_with_role( *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole ): match role: case PromptMessageRole.USER: return UserPromptMessage(content=contents) case PromptMessageRole.ASSISTANT: return AssistantPromptMessage(content=contents) case PromptMessageRole.SYSTEM: return SystemPromptMessage(content=contents) case _: raise NotImplementedError(f"Role {role} is not supported") def _render_jinja2_message( *, template: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ): from core.helper.code_executor import CodeExecutor, CodeLanguage if not template: return "" jinja2_inputs = {} for jinja2_variable in jinja2_variables: variable = variable_pool.get(jinja2_variable.value_selector) jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" code_execute_resp = CodeExecutor.execute_workflow_code_template( language=CodeLanguage.JINJA2, code=template, inputs=jinja2_inputs, ) result_text = code_execute_resp["result"] return result_text def _calculate_rest_token( *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity ) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( model_config.parameters.get(parameter_rule.name) or model_config.parameters.get(str(parameter_rule.use_template)) or 0 ) rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens def _handle_memory_chat_mode( *, memory: BaseMemory | None, memory_config: MemoryConfig | None, model_config: ModelConfigWithCredentialsEntity, ) -> Sequence[PromptMessage]: memory_messages: Sequence[PromptMessage] = [] if memory and memory_config: rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) memory_messages = memory.get_history_prompt_messages( max_token_limit=rest_tokens, message_limit=memory_config.window.size if memory_config.window.enabled else None, ) return memory_messages def _handle_memory_completion_mode( *, memory: BaseMemory | None, memory_config: MemoryConfig | None, model_config: ModelConfigWithCredentialsEntity, ) -> str: memory_text = "" if memory and memory_config: rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) if not memory_config.role_prefix: raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") memory_text = memory.get_history_prompt_text( max_token_limit=rest_tokens, message_limit=memory_config.window.size if memory_config.window.enabled else None, human_prefix=memory_config.role_prefix.user, ai_prefix=memory_config.role_prefix.assistant, ) return memory_text def _handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ) -> Sequence[PromptMessage]: prompt_messages = [] if template.edition_type == "jinja2": result_text = _render_jinja2_message( template=template.jinja2_text or "", jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) else: if context: template_text = template.text.replace("{#context#}", context) else: template_text = template.text result_text = variable_pool.convert_template(template_text).text prompt_message = _combine_message_content_with_role( contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER ) prompt_messages.append(prompt_message) return prompt_messages